# unimol/losses/unimol.py
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from unicore import metrics
from unicore.losses import UnicoreLoss, register_loss
@register_loss("unimol")
class UniMolLoss(UnicoreLoss):
    """Uni-Mol pretraining loss.

    Combines a masked-token prediction loss with optional coordinate,
    pairwise-distance, and representation-norm regularization terms, each
    weighted by the corresponding ``self.args`` coefficient.
    """

    def __init__(self, task):
        super().__init__(task)
        self.padding_idx = task.dictionary.pad()
        self.seed = task.seed
        # Empirical mean/std of pairwise distances, used to normalize the
        # distance-regression target in cal_dist_loss.
        self.dist_mean = 6.312581655060595
        self.dist_std = 3.3899264663911888

    def forward(self, model, sample, reduce=True):
        """Run ``model`` on ``sample`` and sum the enabled loss terms.

        Returns ``(loss, 1, logging_output)``; the constant sample size of 1
        means the loss is already batch-averaged.
        """
        input_key = "net_input"
        target_key = "target"

        token_target = sample[target_key]["tokens_target"]
        masked_tokens = token_target.ne(self.padding_idx)
        n_masked = masked_tokens.long().sum()

        (
            logits_encoder,
            encoder_distance,
            encoder_coord,
            x_norm,
            delta_encoder_pair_rep_norm,
        ) = model(**sample[input_key], encoder_masked_tokens=masked_tokens)

        target = token_target
        if masked_tokens is not None:
            target = target[masked_tokens]

        # Masked-token (MLM-style) cross entropy, computed in fp32 for stability.
        masked_token_loss = F.nll_loss(
            F.log_softmax(logits_encoder, dim=-1, dtype=torch.float32),
            target,
            ignore_index=self.padding_idx,
            reduction="mean",
        )
        hit_count = (logits_encoder.argmax(dim=-1) == target).long().sum()
        loss = masked_token_loss * self.args.masked_token_loss

        logging_output = {
            "sample_size": 1,
            "bsz": token_target.size(0),
            "seq_len": token_target.size(1) * token_target.size(0),
            "masked_token_loss": masked_token_loss.data,
            "masked_token_hit": hit_count.data,
            "masked_token_cnt": n_masked,
        }

        if encoder_coord is not None:
            # real = mask + delta
            coord_target = sample[target_key]["coord_target"]
            masked_coord_loss = F.smooth_l1_loss(
                encoder_coord[masked_tokens].view(-1, 3).float(),
                coord_target[masked_tokens].view(-1, 3),
                reduction="mean",
                beta=1.0,
            )
            loss = loss + masked_coord_loss * self.args.masked_coord_loss
            # restore the scale of loss for displaying
            logging_output["masked_coord_loss"] = masked_coord_loss.data

        if encoder_distance is not None:
            masked_dist_loss = self.cal_dist_loss(
                sample, encoder_distance, masked_tokens, target_key, normalize=True
            )
            loss = loss + masked_dist_loss * self.args.masked_dist_loss
            logging_output["masked_dist_loss"] = masked_dist_loss.data

        if self.args.x_norm_loss > 0 and x_norm is not None:
            loss = loss + self.args.x_norm_loss * x_norm
            logging_output["x_norm_loss"] = x_norm.data

        if (
            self.args.delta_pair_repr_norm_loss > 0
            and delta_encoder_pair_rep_norm is not None
        ):
            loss = (
                loss + self.args.delta_pair_repr_norm_loss * delta_encoder_pair_rep_norm
            )
            logging_output[
                "delta_pair_repr_norm_loss"
            ] = delta_encoder_pair_rep_norm.data

        logging_output["loss"] = loss.data
        return loss, 1, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training."""

        def total(key):
            # Sum one logging key across all workers' outputs.
            return sum(log.get(key, 0) for log in logging_outputs)

        sample_size = total("sample_size")
        metrics.log_scalar("loss", total("loss") / sample_size, sample_size, round=3)
        metrics.log_scalar("seq_len", total("seq_len") / total("bsz"), 1, round=3)
        metrics.log_scalar(
            "masked_token_loss",
            total("masked_token_loss") / sample_size,
            sample_size,
            round=3,
        )
        masked_acc = total("masked_token_hit") / total("masked_token_cnt")
        metrics.log_scalar("masked_acc", masked_acc, sample_size, round=3)

        # Optional terms are only logged when they were actually computed.
        for key in (
            "masked_coord_loss",
            "masked_dist_loss",
            "x_norm_loss",
            "delta_pair_repr_norm_loss",
        ):
            summed = total(key)
            if summed > 0:
                metrics.log_scalar(key, summed / sample_size, sample_size, round=3)

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return True

    def cal_dist_loss(self, sample, dist, masked_tokens, target_key, normalize=False):
        """Smooth-L1 loss between predicted and target pairwise distances
        over masked rows; zero-valued targets (padding) are excluded."""
        pred = dist[masked_tokens, :]
        dist_target = sample[target_key]["distance_target"][masked_tokens]
        # Valid positions are chosen BEFORE normalization shifts zeros.
        valid = dist_target > 0
        if normalize:
            dist_target = (dist_target.float() - self.dist_mean) / self.dist_std
        return F.smooth_l1_loss(
            pred[valid].view(-1).float(),
            dist_target[valid].view(-1),
            reduction="mean",
            beta=1.0,
        )
@register_loss("unimol_infer")
class UniMolInferLoss(UnicoreLoss):
    """Inference-only 'loss' that extracts molecule representations
    (CLS embedding and per-molecule pair representations) for downstream use.
    """

    def __init__(self, task):
        super().__init__(task)
        self.padding_idx = task.dictionary.pad()

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        input_key = "net_input"
        target_key = "target"

        src_tokens = sample[input_key]["src_tokens"]
        token_mask = src_tokens.ne(self.padding_idx)  # True at non-padding positions
        encoder_rep, encoder_pair_rep = model(**sample[input_key], features_only=True)
        sample_size = src_tokens.size(0)

        # Strip padding rows/columns from each molecule's pair representation.
        pair_reprs = [
            encoder_pair_rep[i][token_mask[i], :][:, token_mask[i]].data.cpu().numpy()
            for i in range(sample_size)
        ]
        logging_output = {
            # First position is the CLS token embedding.
            "mol_repr_cls": encoder_rep[:, 0, :].data.cpu().numpy(),
            "pair_repr": pair_reprs,
            "smi_name": sample[target_key]["smi_name"],
            "bsz": src_tokens.size(0),
        }
        return 0, sample_size, logging_output