--- a
+++ b/unimol/losses/unimol.py
@@ -0,0 +1,230 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from unicore import metrics
+from unicore.losses import UnicoreLoss, register_loss
+
+
+@register_loss("unimol")
+class UniMolLoss(UnicoreLoss):
+    """Uni-Mol pre-training loss.
+
+    The total loss is the sum of the following terms, each scaled by the
+    ``self.args`` coefficient of the same name:
+
+    * ``masked_token_loss``: NLL of the token predictions at masked positions.
+    * ``masked_coord_loss``: smooth-L1 on coordinates at masked positions
+      (only when the model returns coordinates).
+    * ``masked_dist_loss``: smooth-L1 on standardized distances for masked
+      rows (only when the model returns distances).
+    * ``x_norm_loss`` / ``delta_pair_repr_norm_loss``: optional regularizers,
+      added when their coefficient is positive and the model returns them.
+    """
+
+    def __init__(self, task):
+        super().__init__(task)
+        self.padding_idx = task.dictionary.pad()
+        self.seed = task.seed
+        # Used to standardize distance targets in ``cal_dist_loss``.
+        # NOTE(review): presumably pre-training data statistics — confirm.
+        self.dist_mean = 6.312581655060595
+        self.dist_std = 3.3899264663911888
+
+    def forward(self, model, sample, reduce=True):
+        """Compute the weighted pre-training loss for one batch.
+
+        Args:
+            model: called with ``sample["net_input"]`` plus
+                ``encoder_masked_tokens``; must return the 5-tuple
+                ``(logits, distance, coord, x_norm, delta_pair_repr_norm)``.
+            sample: batch dict; the non-padding entries of
+                ``sample["target"]["tokens_target"]`` mark the positions
+                scored by every masked loss.
+            reduce: unused; kept for interface compatibility.
+
+        Returns:
+            ``(loss, 1, logging_output)``; the sample size is always 1 and
+            ``reduce_metrics`` divides summed logs by summed sample sizes.
+        """
+        input_key = "net_input"
+        target_key = "target"
+        tokens_target = sample[target_key]["tokens_target"]
+        masked_tokens = tokens_target.ne(self.padding_idx)
+        masked_cnt = masked_tokens.long().sum()
+        (
+            logits_encoder,
+            encoder_distance,
+            encoder_coord,
+            x_norm,
+            delta_encoder_pair_rep_norm,
+        ) = model(**sample[input_key], encoder_masked_tokens=masked_tokens)
+        # The logits are scored against the masked targets only, so the model
+        # must return one row of logits per masked position.
+        target = tokens_target[masked_tokens]
+        masked_token_loss = F.nll_loss(
+            F.log_softmax(logits_encoder, dim=-1, dtype=torch.float32),
+            target,
+            ignore_index=self.padding_idx,
+            reduction="mean",
+        )
+        masked_pred = logits_encoder.argmax(dim=-1)
+        masked_hit = (masked_pred == target).long().sum()
+        loss = masked_token_loss * self.args.masked_token_loss
+        logging_output = {
+            "sample_size": 1,
+            "bsz": tokens_target.size(0),
+            "seq_len": tokens_target.size(1) * tokens_target.size(0),
+            "masked_token_loss": masked_token_loss.data,
+            "masked_token_hit": masked_hit.data,
+            "masked_token_cnt": masked_cnt,
+        }
+
+        if encoder_coord is not None:
+            coord_target = sample[target_key]["coord_target"]
+            # Compute in float32 regardless of the prediction/target dtypes.
+            masked_coord_loss = F.smooth_l1_loss(
+                encoder_coord[masked_tokens].view(-1, 3).float(),
+                coord_target[masked_tokens].view(-1, 3).float(),
+                reduction="mean",
+                beta=1.0,
+            )
+            loss = loss + masked_coord_loss * self.args.masked_coord_loss
+            logging_output["masked_coord_loss"] = masked_coord_loss.data
+
+        if encoder_distance is not None:
+            masked_dist_loss = self.cal_dist_loss(
+                sample, encoder_distance, masked_tokens, target_key, normalize=True
+            )
+            loss = loss + masked_dist_loss * self.args.masked_dist_loss
+            logging_output["masked_dist_loss"] = masked_dist_loss.data
+
+        if self.args.x_norm_loss > 0 and x_norm is not None:
+            loss = loss + self.args.x_norm_loss * x_norm
+            logging_output["x_norm_loss"] = x_norm.data
+
+        if (
+            self.args.delta_pair_repr_norm_loss > 0
+            and delta_encoder_pair_rep_norm is not None
+        ):
+            loss = (
+                loss + self.args.delta_pair_repr_norm_loss * delta_encoder_pair_rep_norm
+            )
+            logging_output[
+                "delta_pair_repr_norm_loss"
+            ] = delta_encoder_pair_rep_norm.data
+
+        logging_output["loss"] = loss.data
+        return loss, 1, logging_output
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training.
+
+        Loss entries are divided by the summed ``sample_size`` (1 per batch).
+        Optional losses are logged only when their sum is positive, and
+        nothing is logged when there is no data to average. ``split`` is unused.
+        """
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        if sample_size == 0:
+            # Nothing to average (e.g. empty ``logging_outputs``).
+            return
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        bsz = sum(log.get("bsz", 0) for log in logging_outputs)
+        seq_len = sum(log.get("seq_len", 0) for log in logging_outputs)
+        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
+        if bsz > 0:
+            metrics.log_scalar("seq_len", seq_len / bsz, 1, round=3)
+
+        masked_loss = sum(log.get("masked_token_loss", 0) for log in logging_outputs)
+        metrics.log_scalar(
+            "masked_token_loss", masked_loss / sample_size, sample_size, round=3
+        )
+
+        masked_hit = sum(log.get("masked_token_hit", 0) for log in logging_outputs)
+        masked_cnt = sum(log.get("masked_token_cnt", 0) for log in logging_outputs)
+        if masked_cnt > 0:
+            metrics.log_scalar(
+                "masked_acc", masked_hit / masked_cnt, sample_size, round=3
+            )
+
+        # Optional terms are only present when ``forward`` computed them.
+        for key in (
+            "masked_coord_loss",
+            "masked_dist_loss",
+            "x_norm_loss",
+            "delta_pair_repr_norm_loss",
+        ):
+            total = sum(log.get(key, 0) for log in logging_outputs)
+            if total > 0:
+                metrics.log_scalar(key, total / sample_size, sample_size, round=3)
+
+    @staticmethod
+    def logging_outputs_can_be_summed(is_train) -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True improves distributed training speed.
+        """
+        return True
+
+    def cal_dist_loss(self, sample, dist, masked_tokens, target_key, normalize=False):
+        """Smooth-L1 loss between predicted and target pairwise distances.
+
+        Rows of ``dist`` and of ``sample[target_key]["distance_target"]`` are
+        selected with ``masked_tokens``; within them only entries with a
+        strictly positive target distance are scored. With ``normalize=True``
+        the target is first standardized with ``self.dist_mean`` and
+        ``self.dist_std``, so ``dist`` is compared in standardized units.
+        """
+        masked_distance = dist[masked_tokens, :]
+        masked_distance_target = sample[target_key]["distance_target"][masked_tokens]
+        # Select scored entries before normalizing, while 0 still means 0.
+        non_pad_pos = masked_distance_target > 0
+        masked_distance_target = masked_distance_target.float()
+        if normalize:
+            masked_distance_target = (
+                masked_distance_target - self.dist_mean
+            ) / self.dist_std
+        return F.smooth_l1_loss(
+            masked_distance[non_pad_pos].view(-1).float(),
+            masked_distance_target[non_pad_pos].view(-1),
+            reduction="mean",
+            beta=1.0,
+        )
+
+
+@register_loss("unimol_infer")
+class UniMolInferLoss(UnicoreLoss):
+    """Feature-extraction "loss": returns representations instead of a loss."""
+
+    def __init__(self, task):
+        super().__init__(task)
+        self.padding_idx = task.dictionary.pad()
+
+    def forward(self, model, sample, reduce=True):
+        """Run ``model`` with ``features_only=True`` and collect its outputs.
+
+        Returns ``(0, batch_size, logging_output)``; the loss is always 0 and
+        ``logging_output`` holds numpy arrays of the position-0 (CLS) embedding
+        (``"mol_repr_cls"``) and per-sample pair representations with padding
+        rows/columns removed (``"pair_repr"``), plus ``"smi_name"``, ``"bsz"``.
+        """
+        net_input = sample["net_input"]
+        batch_size = net_input["src_tokens"].size(0)
+        # True at real (non-padding) token positions.
+        keep = net_input["src_tokens"].ne(self.padding_idx)
+        encoder_rep, encoder_pair_rep = model(**net_input, features_only=True)
+        pair_reprs = [
+            encoder_pair_rep[b][keep[b], :][:, keep[b]].data.cpu().numpy()
+            for b in range(batch_size)
+        ]
+        logging_output = {
+            "mol_repr_cls": encoder_rep[:, 0, :].data.cpu().numpy(),
+            "pair_repr": pair_reprs,
+            "smi_name": sample["target"]["smi_name"],
+            "bsz": batch_size,
+        }
+        return 0, batch_size, logging_output