--- a
+++ b/unimol/losses/cross_entropy.py
@@ -0,0 +1,946 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+import torch.nn.functional as F
+import pandas as pd
+from unicore import metrics
+from unicore.losses import UnicoreLoss, register_loss
+from unicore.losses.cross_entropy import CrossEntropyLoss
+from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
+import numpy as np
+import warnings
+from sklearn.metrics import top_k_accuracy_score
+from rdkit.ML.Scoring.Scoring import CalcBEDROC
+import scipy.stats as stats
+
+
+def calculate_bedroc(y_true, y_score, alpha):
+    """
+    Calculate BEDROC score.
+
+    Parameters:
+    - y_true: true binary labels (0 or 1)
+    - y_score: predicted scores or probabilities
+    - alpha: parameter controlling the degree of early retrieval emphasis
+
+    Returns:
+    - BEDROC score
+    """
+    # concatenate scores and true labels into a single (N, 2) array
+    scores = np.expand_dims(y_score, axis=1)
+    y_true = np.expand_dims(y_true, axis=1)
+    scores = np.concatenate((scores, y_true), axis=1)
+    # sort rows by score in descending order, as CalcBEDROC expects
+    scores = scores[scores[:, 0].argsort()[::-1]]
+    bedroc = CalcBEDROC(scores, 1, alpha)
+    return bedroc
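+
+# Minimal usage sketch for calculate_bedroc (illustrative toy values only):
+#     y_true = np.array([1, 0, 0, 1, 0])
+#     y_score = np.array([0.9, 0.8, 0.4, 0.35, 0.1])
+#     calculate_bedroc(y_true, y_score, alpha=80.5)  # BEDROC score in [0, 1]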
+
+@register_loss("decoder_loss")
+class DecoderLoss(CrossEntropyLoss):
+    def __init__(self, task):
+        super().__init__(task)
+
+    def forward(self, model, sample, reduce=True, fix_encoder=False):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(
+            **sample["net_input"],
+            features_only=True,
+            classification_head_name=None,
+            fix_encoder=fix_encoder
+        )
+        loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+        targets = sample["net_input"]["selfie_tokens"]
+        sample_size = targets.size(0)
+
+        lprobs = net_output[:, :, : targets.shape[-1]]
+        if not self.training:
+            logging_output = {
+                "loss": loss.data,
+                "prob": lprobs.data,
+                "target": targets.data,
+                "smi_name": sample["smi_name"],
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        else:
+            logging_output = {
+                "loss": loss.data,
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
+        lprobs = net_output
+        targets = sample["net_input"]["selfie_tokens"]
+        lprobs = lprobs[:, :, : targets.shape[-1]]
+        nll_loss = F.nll_loss(
+            lprobs,
+            targets,
+            reduction="sum" if reduce else "none",
+        ) / lprobs.shape[-1]
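+        # the summed NLL is divided by the sequence length (lprobs.shape[-1]) so the
+        # reconstruction term stays on a roughly per-token scale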
+
+        loss = nll_loss
+        return loss
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss = sum(log.get("loss", 0).float() for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        # we divide by log(2) to convert the loss from base e to base 2
+        metrics.log_scalar(
+            "loss_all", loss / sample_size / math.log(2), sample_size, round=3
+        )
+
+        if "valid" in split or "test" in split:
+
+            pred_list = []
+            target_list = []
+            for log in logging_outputs:
+                pred = log.get("prob").argmax(dim=1)
+                pred = pred.flatten()
+                pred_list.append(pred)
+                target = log.get("target")
+                target = target.flatten()
+                target_list.append(target)
+
+            preds = torch.cat(pred_list, dim=0)
+            targets = torch.cat(target_list, dim=0)
+            acc = (preds == targets).float().mean(dim=-1)
+            metrics.log_scalar(
+                f"{split}_acc", acc, sample_size, round=3
+            )
+
+    @staticmethod
+    def logging_outputs_can_be_summed(is_train) -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True improves distributed training speed.
+        """
+        return is_train
+
+
+@register_loss("decoder_vae_loss")
+class DecoderVAELoss(CrossEntropyLoss):
+    def __init__(self, task):
+        super().__init__(task)
+
+    def forward(self, model, sample, candidate_reps, candidate_embs, candidate_smiles, reduce=True, fix_encoder=False):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(
+            **sample["net_input"],
+            candidate_reps=candidate_reps,
+            candidate_embs=candidate_embs,
+            candidate_smiles=candidate_smiles,
+            features_only=True,
+            classification_head_name=None,
+            fix_encoder=fix_encoder
+        )
+        loss, nll_loss, kld_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+        targets = sample["net_input"]["selfie_tokens"]
+        sample_size = targets.size(0)
+
+        lprobs = net_output[0][:, :, : targets.shape[-1]]
+        if not self.training:
+            logging_output = {
+                "loss": loss.data,
+                "kld_loss": kld_loss.data,
+                "nll_loss": nll_loss.data,
+                "prob": lprobs.data,
+                "target": targets.data,
+                "smi_name": sample["smi_name"],
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        else:
+            logging_output = {
+                "loss": loss.data,
+                "kld_loss": kld_loss.data,
+                "nll_loss": nll_loss.data,
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
+        out, z, mu, log_var = net_output
+        lprobs = out
+        targets = sample["net_input"]["selfie_tokens"]
+        lprobs = lprobs[:, :, : targets.shape[-1]]
+        nll_loss = F.nll_loss(
+            lprobs,
+            targets,
+            reduction="sum" if reduce else "none",
+        ) / lprobs.shape[-1]
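+        # closed-form KL divergence between the approximate posterior N(mu, sigma^2)
+        # and a standard normal prior (the usual VAE regularizer)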
+        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / lprobs.shape[1]
+        p = 0.2
+        loss = p * kld_loss + (1 - p) * nll_loss
+        return loss, nll_loss, kld_loss
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_all = sum(log.get("loss", 0).float() for log in logging_outputs)
+        loss_kld = sum(log.get("kld_loss", 0).float() for log in logging_outputs)
+        loss_nll = sum(log.get("nll_loss", 0).float() for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        # we divide by log(2) to convert the loss from base e to base 2
+        metrics.log_scalar(
+            "loss_all", loss_all / sample_size / math.log(2), sample_size, round=3
+        )
+        metrics.log_scalar(
+            "loss_kld", loss_kld / sample_size, sample_size, round=3
+        )
+        metrics.log_scalar(
+            "loss_nll", loss_nll / sample_size, sample_size, round=3
+        )
+        if "valid" in split or "test" in split:
+            
+            prob_list = []
+            pred_list = []
+            target_list = []
+            for log in logging_outputs:
+                prob = log.get("prob")
+                prob = torch.transpose(prob, 1, 2)
+                prob = prob.reshape((-1, prob.shape[-1]))
+                prob_list.append(prob)
+                pred = log.get("prob").argmax(dim=1)
+                pred = pred.flatten()
+                pred_list.append(pred)
+                target = log.get("target")
+                target = target.flatten()
+                target_list.append(target)
+
+            probs = torch.cat(prob_list, dim=0)
+            preds = torch.cat(pred_list, dim=0)
+            targets = torch.cat(target_list, dim=0)
+            acc = (preds == targets).float().mean(dim=-1)
+            metrics.log_scalar(
+                f"{split}_acc", acc, sample_size, round=3
+            )
+            '''
+            # smi_list = [
+            #     item for log in logging_outputs for item in log.get("smi_name")
+            # ]
+            probs = torch.exp(probs)
+            #prob_flat = prob_flat.reshape((-1, prob_flat.shape[-1]))
+            print(probs.shape)
+
+            #targets = targets.squeeze(dim=-1)
+            auc = roc_auc_score(targets.cpu(), probs.cpu(), multi_class="ovo", labels=torch.arange(probs.shape[-1]))
+            #df = df.groupby("smi").mean()
+            #agg_auc = roc_auc_score(df["targets"], df["probs"])
+            agg_auc = auc
+            
+            metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3)
+            metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4)
+            '''
+        
+    @staticmethod
+    def logging_outputs_can_be_summed(is_train) -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True improves distributed training speed.
+        """
+        return is_train
+
+
+@register_loss("finetune_cross_entropy")
+class FinetuneCrossEntropyLoss(CrossEntropyLoss):
+    def __init__(self, task):
+        super().__init__(task)
+
+    def forward(self, model, sample, reduce=True, fix_encoder=False):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(
+            **sample["net_input"],
+            features_only=True,
+            classification_head_name=self.args.classification_head_name,
+            fix_encoder=fix_encoder
+        )
+        logit_output = net_output[0]
+        loss = self.compute_loss(model, logit_output, sample, reduce=reduce)
+        sample_size = sample["target"]["finetune_target"].size(0)
+        if not self.training:
+            probs = F.softmax(logit_output.float(), dim=-1).view(
+                -1, logit_output.size(-1)
+            )
+            logging_output = {
+                "loss": loss.data,
+                "prob": probs.data,
+                "target": sample["target"]["finetune_target"].view(-1).data,
+                "smi_name": sample["smi_name"],
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        else:
+            logging_output = {
+                "loss": loss.data,
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
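+        # log_softmax followed by nll_loss below is equivalent to cross-entropy on the raw logits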
+        lprobs = F.log_softmax(net_output.float(), dim=-1)
+        lprobs = lprobs.view(-1, lprobs.size(-1))
+        targets = sample["target"]["finetune_target"].view(-1)
+        # print("111", lprobs.shape, targets.shape, sample["target"]["finetune_target"].shape)
+        loss = F.nll_loss(
+            lprobs,
+            targets,
+            reduction="sum" if reduce else "none",
+        )
+        return loss
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        # we divide by log(2) to convert the loss from base e to base 2
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        if "valid" in split or "test" in split:
+            acc_sum = sum(
+                sum(log.get("prob").argmax(dim=-1) == log.get("target"))
+                for log in logging_outputs
+            )
+            probs = torch.cat([log.get("prob") for log in logging_outputs], dim=0)
+            metrics.log_scalar(
+                f"{split}_acc", acc_sum / sample_size, sample_size, round=3
+            )
+            if probs.size(-1) == 2:
+                # binary classification task, add auc score
+                targets = torch.cat(
+                    [log.get("target", 0) for log in logging_outputs], dim=0
+                )
+                smi_list = [
+                    item for log in logging_outputs for item in log.get("smi_name")
+                ]
+                df = pd.DataFrame(
+                    {
+                        "probs": probs[:, 1].cpu(),
+                        "targets": targets.cpu(),
+                        "smi": smi_list,
+                    }
+                )
+                auc = roc_auc_score(df["targets"], df["probs"])
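+                # aggregate rows sharing a SMILES (presumably multiple conformers of the
+                # same molecule) by averaging before recomputing AUC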
+                df = df.groupby("smi").mean()
+                agg_auc = roc_auc_score(df["targets"], df["probs"])
+                metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3)
+                metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4)
+
+    @staticmethod
+    def logging_outputs_can_be_summed(is_train) -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True improves distributed training speed.
+        """
+        return is_train
+
+@register_loss("ce")
+class CEntropyLoss(CrossEntropyLoss):
+    def __init__(self, task):
+        super().__init__(task)
+
+    def forward(self, model, sample, reduce=True, fix_encoder=False):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(
+            **sample["net_input"],
+            smi_list = sample["smi_name"],
+            pocket_list = sample["pocket_name"],
+            features_only=True,
+            fix_encoder=fix_encoder
+        )
+        logit_output = net_output
+        loss = self.compute_loss(model, logit_output, sample, reduce=reduce)
+        sample_size = sample["target"]["finetune_target"].size(0)
+        if not self.training:
+            probs = torch.sigmoid(logit_output.float())
+            logging_output = {
+                "loss": loss.data,
+                "prob": probs.data,
+                "target": sample["target"]["finetune_target"].view(-1).data,
+                "smi_name": sample["smi_name"],
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        else:
+            logging_output = {
+                "loss": loss.data,
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
+        targets = sample["target"]["finetune_target"].view(-1)
+        loss = F.binary_cross_entropy_with_logits(
+            net_output.float(),
+            targets,
+            reduction="sum" if reduce else "none",
+        )
+        return loss
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        # we divide by log(2) to convert the loss from base e to base 2
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        if "valid" in split or "test" in split:
+
+            # get acc
+            acc_sum = 0
+            for log in logging_outputs:
+                probs = torch.sigmoid(log.get("prob"))
+                targets = log.get("target")
+                probs = probs > 0.5
+                # convert to int
+                probs = probs.long()
+            
+            
+            logs = [log.get("prob") for log in logging_outputs]
+            targets = torch.cat(
+                [log.get("target", 0) for log in logging_outputs], dim=0
+            )
+            probs = torch.cat([log.get("prob") for log in logging_outputs], dim=0)
+            #probs = torch.sigmoid(probs)
+            print(probs.shape, targets.shape)
+            print(probs[:10], targets[:10])
+            preds = probs > 0.5
+            # convert to int
+            preds = preds.long()
+            acc_sum = (preds == targets).sum()
+            metrics.log_scalar(
+                f"{split}_acc", acc_sum / sample_size, sample_size, round=3
+            )
+            # binary classification task, add auc score
+            
+            
+            smi_list = [
+                item for log in logging_outputs for item in log.get("smi_name")
+            ]
+            df = pd.DataFrame(
+                {
+                    "probs": probs.cpu(),
+                    "targets": targets.cpu(),
+                    "smi": smi_list,
+                }
+            )
+
+            auc = roc_auc_score(df["targets"].values, df["probs"].values)
+            df = df.groupby("smi").mean()
+            agg_auc = roc_auc_score(df["targets"], df["probs"])
+            metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3)
+            metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4)
+
+    @staticmethod
+    def logging_outputs_can_be_summed(is_train) -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True improves distributed training speed.
+        """
+        return is_train
+
+
+@register_loss("in_batch_softmax")
+class IBSLoss(CrossEntropyLoss):
+    def __init__(self, task):
+        super().__init__(task)
+
+    def forward(self, model, sample, reduce=True, fix_encoder=False):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(
+            **sample["net_input"],
+            smi_list = sample["smi_name"],
+            pocket_list = sample["pocket_name"],
+            features_only=True,
+            fix_encoder=fix_encoder,
+            is_train = self.training
+        )
+        
+        logit_output = net_output[0]
+        loss = self.compute_loss(model, logit_output, sample, reduce=reduce)
+        sample_size = logit_output.size(0)
+        targets = torch.arange(sample_size, dtype=torch.long).cuda()
+        affinities = sample["target"]["finetune_target"].view(-1)
+        if not self.training:
+            logit_output = logit_output[:,:sample_size]
+            probs = F.softmax(logit_output.float(), dim=-1).view(
+                -1, logit_output.size(-1)
+            )
+            logging_output = {
+                "loss": loss.data,
+                "prob": probs.data,
+                "target": targets,
+                "smi_name": sample["smi_name"],
+                "sample_size": sample_size,
+                "bsz": targets.size(0),
+                "scale": net_output[1].data,
+                "affinity": affinities,
+            }
+        else:
+            logging_output = {
+                "loss": loss.data,
+                "sample_size": sample_size,
+                "bsz": targets.size(0),
+                "scale": net_output[1].data
+            }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
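+        # in-batch softmax: net_output is a pocket-by-mol similarity matrix in which the
+        # matched pair for row i sits at column i, so the target class of each row is its
+        # own batch index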
+        lprobs_pocket = F.log_softmax(net_output.float(), dim=-1)
+        lprobs_pocket = lprobs_pocket.view(-1, lprobs_pocket.size(-1))
+        sample_size = lprobs_pocket.size(0)
+        targets = torch.arange(sample_size, dtype=torch.long).view(-1).cuda()
+
+        # pocket retrieve mol
+        loss_pocket = F.nll_loss(
+            lprobs_pocket,
+            targets,
+            reduction="sum" if reduce else "none",
+        )
+        
+        lprobs_mol = F.log_softmax(torch.transpose(net_output.float(), 0, 1), dim=-1)
+        lprobs_mol = lprobs_mol.view(-1, lprobs_mol.size(-1))
+        lprobs_mol = lprobs_mol[:sample_size]
+
+        # mol retrieve pocket
+        loss_mol = F.nll_loss(
+            lprobs_mol,
+            targets,
+            reduction="sum" if reduce else "none",
+        )
+        
+        loss = 0.5 * loss_pocket + 0.5 * loss_mol
+        return loss
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        metrics.log_scalar("scale", logging_outputs[0].get("scale"), round=3)
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        # we divide by log(2) to convert the loss from base e to base 2
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        if "valid" in split or "test" in split:
+            acc_sum = sum(
+                sum(log.get("prob").argmax(dim=-1) == log.get("target"))
+                for log in logging_outputs
+            )
+            
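+            # keep every logging output except the last one, presumably because the final
+            # batch can be smaller than the shared in-batch candidate set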
+            prob_list = []
+            if len(logging_outputs) == 1:
+                prob_list.append(logging_outputs[0].get("prob"))
+            else:
+                for i in range(len(logging_outputs)-1):
+                    prob_list.append(logging_outputs[i].get("prob"))
+            probs = torch.cat(prob_list, dim=0)
+            
+            metrics.log_scalar(
+                f"{split}_acc", acc_sum / sample_size, sample_size, round=3
+            )
+
+            metrics.log_scalar(
+                "valid_neg_loss", -loss_sum / sample_size / math.log(2), sample_size, round=3
+            )
+
+            targets = torch.cat(
+                [log.get("target", 0) for log in logging_outputs], dim=0
+            )
+
+            targets = targets[:len(probs)]
+            bedroc_list = []
+            auc_list = []
+            for i in range(len(probs)):
+                prob = probs[i]
+                target = targets[i]
+                label = torch.zeros_like(prob)
+                label[target] = 1.0
+                cur_auc = roc_auc_score(label.cpu(), prob.cpu())
+                auc_list.append(cur_auc)
+                bedroc = calculate_bedroc(label.cpu(), prob.cpu(), 80.5)
+                bedroc_list.append(bedroc)
+            bedroc = np.mean(bedroc_list)
+            auc = np.mean(auc_list)
+            
+            top_k_acc = top_k_accuracy_score(targets.cpu(), probs.cpu(), k=3, normalize=True)
+            metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3)
+            metrics.log_scalar(f"{split}_bedroc", bedroc, sample_size, round=3)
+            metrics.log_scalar(f"{split}_top3_acc", top_k_acc, sample_size, round=3)
+
+
+    @staticmethod
+    def logging_outputs_can_be_summed(is_train) -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True improves distributed training speed.
+        """
+        return is_train
+
+
+@register_loss("multi_task_BCE")
+class MultiTaskBCELoss(CrossEntropyLoss):
+    def __init__(self, task):
+        super().__init__(task)
+
+    def forward(self, model, sample, reduce=True, fix_encoder=False):
+        """Compute the loss for the given sample.
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(
+            **sample["net_input"],
+            masked_tokens=None,
+            features_only=True,
+            classification_head_name=self.args.classification_head_name,
+            fix_encoder=fix_encoder
+        )
+        logit_output = net_output[0]
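+        # entries at or below -0.5 encode missing labels and are excluded from the loss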
+        is_valid = sample["target"]["finetune_target"] > -0.5
+        loss = self.compute_loss(
+            model, logit_output, sample, reduce=reduce, is_valid=is_valid
+        )
+        sample_size = sample["target"]["finetune_target"].size(0)
+        if not self.training:
+            probs = torch.sigmoid(logit_output.float()).view(-1, logit_output.size(-1))
+            logging_output = {
+                "loss": loss.data,
+                "prob": probs.data,
+                "target": sample["target"]["finetune_target"].view(-1).data,
+                "num_task": self.args.num_classes,
+                "sample_size": sample_size,
+                "conf_size": self.args.conf_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        else:
+            logging_output = {
+                "loss": loss.data,
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True, is_valid=None):
+        pred = net_output[is_valid].float()
+        targets = sample["target"]["finetune_target"][is_valid].float()
+        loss = F.binary_cross_entropy_with_logits(
+            pred,
+            targets,
+            reduction="sum" if reduce else "none",
+        )
+        return loss
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        # we divide by log(2) to convert the loss from base e to base 2
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        if "valid" in split or "test" in split:
+            agg_auc_list = []
+            num_task = logging_outputs[0].get("num_task", 0)
+            conf_size = logging_outputs[0].get("conf_size", 0)
+            y_true = (
+                torch.cat([log.get("target", 0) for log in logging_outputs], dim=0)
+                .view(-1, conf_size, num_task)
+                .cpu()
+                .numpy()
+                .mean(axis=1)
+            )
+            y_pred = (
+                torch.cat([log.get("prob") for log in logging_outputs], dim=0)
+                .view(-1, conf_size, num_task)
+                .cpu()
+                .numpy()
+                .mean(axis=1)
+            )
+            # [test_size, num_classes] = [test_size * conf_size, num_classes].mean(axis=1)
+            for i in range(y_true.shape[1]):
+                # AUC is only defined when there is at least one positive data.
+                if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0:
+                    # ignore nan values
+                    is_labeled = y_true[:, i] > -0.5
+                    agg_auc_list.append(
+                        roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i])
+                    )
+            if len(agg_auc_list) < y_true.shape[1]:
+                warnings.warn("Some target is missing!")
+            if len(agg_auc_list) == 0:
+                raise RuntimeError(
+                    "No positively labeled data available. Cannot compute ROC-AUC."
+                )
+            agg_auc = sum(agg_auc_list) / len(agg_auc_list)
+            metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4)
+
+    @staticmethod
+    def logging_outputs_can_be_summed(is_train) -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True improves distributed training speed.
+        """
+        return is_train
+
+@register_loss("BCE")
+class BCELoss(CrossEntropyLoss):
+    def __init__(self, task):
+        super().__init__(task)
+
+    def forward(self, model, sample, reduce=True, fix_encoder=False):
+        """Compute the loss for the given sample.
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(
+            **sample["net_input"],
+            smi_list = sample["smi_name"],
+            pocket_list = sample["pocket_name"],
+            features_only=True,
+            fix_encoder=fix_encoder
+        )
+        logit_output = net_output
+        loss = self.compute_loss(
+            model, logit_output, sample, reduce=reduce
+        )
+        sample_size = sample["target"]["finetune_target"].size(0)
+
+        if not self.training:
+            probs = torch.sigmoid(logit_output.float())
+            logging_output = {
+                "loss": loss.data,
+                "prob": probs.data,
+                "target": sample["target"]["finetune_target"].view(-1).data,
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        else:
+            logging_output = {
+                "loss": loss.data,
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True, is_valid=None):
+        pred = net_output.float()
+        targets = sample["target"]["finetune_target"].float()
+        loss = F.binary_cross_entropy_with_logits(
+            pred,
+            targets,
+            reduction="sum" if reduce else "none",
+        )
+        return loss
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        # we divide by log(2) to convert the loss from base e to base 2
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        if "valid" in split or "test" in split:
+            y_true_list = [log.get("target", 0) for log in logging_outputs]
+            y_pred_list = [log.get("prob") for log in logging_outputs]
+            y_true = (
+                torch.cat(y_true_list, dim=0)
+                .cpu()
+                .numpy()
+            )
+            y_pred = (
+                torch.cat(y_pred_list, dim=0)
+                .cpu()
+                .numpy()
+            )
+
+            auc = roc_auc_score(y_true, y_pred)
+
+            agg_auc = auc
+            metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4)
+
+    @staticmethod
+    def logging_outputs_can_be_summed(is_train) -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True improves distributed training speed.
+        """
+        return is_train
+
+@register_loss("finetune_cross_entropy_pocket")
+class FinetuneCrossEntropyPocketLoss(FinetuneCrossEntropyLoss):
+    def __init__(self, task):
+        super().__init__(task)
+
+    def forward(self, model, sample, reduce=True):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(
+            **sample["net_input"],
+            features_only=True,
+            classification_head_name=self.args.classification_head_name,
+        )
+        logit_output = net_output[0]
+        loss = self.compute_loss(model, logit_output, sample, reduce=reduce)
+        sample_size = sample["target"]["finetune_target"].size(0)
+        if not self.training:
+            probs = F.softmax(logit_output.float(), dim=-1).view(
+                -1, logit_output.size(-1)
+            )
+            logging_output = {
+                "loss": loss.data,
+                "prob": probs.data,
+                "target": sample["target"]["finetune_target"].view(-1).data,
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        else:
+            logging_output = {
+                "loss": loss.data,
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].size(0),
+            }
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        # we divide by log(2) to convert the loss from base e to base 2
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        if "valid" in split or "test" in split:
+            acc_sum = sum(
+                sum(log.get("prob").argmax(dim=-1) == log.get("target"))
+                for log in logging_outputs
+            )
+            metrics.log_scalar(
+                f"{split}_acc", acc_sum / sample_size, sample_size, round=3
+            )
+            preds = (
+                torch.cat(
+                    [log.get("prob").argmax(dim=-1) for log in logging_outputs], dim=0
+                )
+                .cpu()
+                .numpy()
+            )
+            targets = (
+                torch.cat([log.get("target", 0) for log in logging_outputs], dim=0)
+                .cpu()
+                .numpy()
+            )
+            metrics.log_scalar(f"{split}_pre", precision_score(targets, preds), round=3)
+            metrics.log_scalar(f"{split}_rec", recall_score(targets, preds), round=3)
+            metrics.log_scalar(
+                f"{split}_f1", f1_score(targets, preds), sample_size, round=3
+            )