Diff of /unimol/tasks/drugclip.py [000000] .. [b40915]

--- a
+++ b/unimol/tasks/drugclip.py
@@ -0,0 +1,1007 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+import os
+import numpy as np
+import torch
+import pickle
+from tqdm import tqdm
+from unicore import checkpoint_utils
+import unicore
+from unicore.data import (AppendTokenDataset, Dictionary, EpochShuffleDataset,
+                          FromNumpyDataset, NestedDictionaryDataset,
+                          PrependTokenDataset, RawArrayDataset, LMDBDataset, RawLabelDataset,
+                          RightPadDataset, RightPadDataset2D, TokenizeDataset, SortDataset,
+                          data_utils)
+from unicore.tasks import UnicoreTask, register_task
+from unimol.data import (AffinityDataset, CroppingPocketDataset,
+                         CrossDistanceDataset, DistanceDataset,
+                         EdgeTypeDataset, KeyDataset, LengthDataset,
+                         NormalizeDataset, NormalizeDockingPoseDataset,
+                         PrependAndAppend2DDataset, RemoveHydrogenDataset,
+                         RemoveHydrogenPocketDataset, RightPadDatasetCoord,
+                         RightPadDatasetCross2D, TTADockingPoseDataset, AffinityTestDataset,
+                         AffinityValidDataset, AffinityMolDataset, AffinityPocketDataset,
+                         ResamplingDataset)
+from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment
+logger = logging.getLogger(__name__)
+
+
+def re_new(y_true, y_score, ratio):
+    """Enrichment at `ratio`: TPR/FPR once a `ratio` fraction of decoys is retrieved."""
+    fp = 0
+    tp = 0
+    p = sum(y_true)
+    n = len(y_true) - p
+    num = ratio * n
+    # walk down the score-ranked list until `num` decoys have been seen
+    sort_index = np.argsort(y_score)[::-1]
+    for i in range(len(sort_index)):
+        index = sort_index[i]
+        if y_true[index] == 1:
+            tp += 1
+        else:
+            fp += 1
+            if fp >= num:
+                break
+    # guard against degenerate inputs with no actives or no retrieved decoys
+    if p == 0 or fp == 0:
+        return 0.0
+    return (tp * n) / (p * fp)
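+
+# For example, RE@0.005 == 20 means that by the time the top-scored 0.5% of the
+# decoys has been screened, 10% of the actives have been recovered (0.10 / 0.005).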
+
+
+def calc_re(y_true, y_score, ratio_list):
+    """Compute the enrichment (RE) at each decoy fraction in `ratio_list`."""
+    res2 = {}
+    for ratio in ratio_list:
+        res2[str(ratio)] = re_new(y_true, y_score, ratio)
+    return res2
+
+def cal_metrics(y_true, y_score, alpha):
+    """
+    Calculate virtual-screening metrics.
+
+    Parameters:
+    - y_true: true binary labels (0 or 1)
+    - y_score: predicted scores or probabilities
+    - alpha: BEDROC parameter controlling the degree of early-retrieval emphasis
+
+    Returns:
+    - AUC, BEDROC, and dictionaries of enrichment factors (EF) and
+      enrichment (RE) at 0.5%, 1%, 2% and 5% of the ranked list
+    """
+    # pair each score with its label, then sort rows by score, descending
+    scores = np.stack((y_score, y_true), axis=1)
+    scores = scores[scores[:, 0].argsort()[::-1]]
+    bedroc = CalcBEDROC(scores, 1, alpha)
+    auc = CalcAUC(scores, 1)
+    ef_list = CalcEnrichment(scores, 1, [0.005, 0.01, 0.02, 0.05])
+    ef = {
+        "0.005": ef_list[0],
+        "0.01": ef_list[1],
+        "0.02": ef_list[2],
+        "0.05": ef_list[3],
+    }
+    re_list = calc_re(y_true, y_score, [0.005, 0.01, 0.02, 0.05])
+    return auc, bedroc, ef, re_list
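+
+# Illustrative usage (synthetic data, not part of the pipeline):
+#
+#   rng = np.random.default_rng(0)
+#   y_true = rng.integers(0, 2, size=1000)
+#   y_score = rng.random(1000)
+#   auc, bedroc, ef, re = cal_metrics(y_true, y_score, 80.5)
+#   print(auc, bedroc, ef["0.005"], re["0.005"])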
+
+
+
+@register_task("drugclip")
+class DrugCLIP(UnicoreTask):
+    """Task for training transformer auto-encoder models."""
+
+    @staticmethod
+    def add_args(parser):
+        """Add task-specific arguments to the parser."""
+        parser.add_argument(
+            "data",
+            help="downstream data path",
+        )
+        parser.add_argument(
+            "--finetune-mol-model",
+            default=None,
+            type=str,
+            help="pretrained molecular model path",
+        )
+        parser.add_argument(
+            "--finetune-pocket-model",
+            default=None,
+            type=str,
+            help="pretrained pocket model path",
+        )
+        parser.add_argument(
+            "--dist-threshold",
+            type=float,
+            default=6.0,
+            help="threshold for the distance between the molecule and the pocket",
+        )
+        parser.add_argument(
+            "--max-pocket-atoms",
+            type=int,
+            default=256,
+            help="selected maximum number of atoms in a pocket",
+        )
+        parser.add_argument(
+            "--test-model",
+            action="store_true",
+            help="whether to test the model",
+        )
+        parser.add_argument("--reg", action="store_true", help="regression task")
+
+    def __init__(self, args, dictionary, pocket_dictionary):
+        super().__init__(args)
+        self.dictionary = dictionary
+        self.pocket_dictionary = pocket_dictionary
+        self.seed = args.seed
+        # add mask token
+        self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)
+        self.pocket_mask_idx = pocket_dictionary.add_symbol("[MASK]", is_special=True)
+        self.mol_reps = None
+        self.keys = None
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        mol_dictionary = Dictionary.load(os.path.join(args.data, "dict_mol.txt"))
+        pocket_dictionary = Dictionary.load(os.path.join(args.data, "dict_pkt.txt"))
+        logger.info("ligand dictionary: {} types".format(len(mol_dictionary)))
+        logger.info("pocket dictionary: {} types".format(len(pocket_dictionary)))
+        return cls(args, mol_dictionary, pocket_dictionary)
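+
+    # Note: setup_task assumes `dict_mol.txt` and `dict_pkt.txt` sit directly
+    # under args.data; both are plain-text Unicore Dictionary files.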
+
+    def load_dataset(self, split, **kwargs):
+        """Load a given dataset split.
+        'smi','pocket','atoms','coordinates','pocket_atoms','pocket_coordinates'
+        Args:
+            split (str): name of the data scoure (e.g., bppp)
+        """
+        data_path = os.path.join(self.args.data, split + ".lmdb")
+        dataset = LMDBDataset(data_path)
+        if split.startswith("train"):
+            smi_dataset = KeyDataset(dataset, "smi")
+            poc_dataset = KeyDataset(dataset, "pocket")
+            
+            dataset = AffinityDataset(
+                dataset,
+                self.args.seed,
+                "atoms",
+                "coordinates",
+                "pocket_atoms",
+                "pocket_coordinates",
+                "label",
+                True,
+            )
+            tgt_dataset = KeyDataset(dataset, "affinity")
+        else:
+            dataset = AffinityDataset(
+                dataset,
+                self.args.seed,
+                "atoms",
+                "coordinates",
+                "pocket_atoms",
+                "pocket_coordinates",
+                "label",
+            )
+            tgt_dataset = KeyDataset(dataset, "affinity")
+            smi_dataset = KeyDataset(dataset, "smi")
+            poc_dataset = KeyDataset(dataset, "pocket")
+
+
+        def PrependAndAppend(dataset, pre_token, app_token):
+            dataset = PrependTokenDataset(dataset, pre_token)
+            return AppendTokenDataset(dataset, app_token)
+
+        dataset = RemoveHydrogenPocketDataset(
+            dataset,
+            "pocket_atoms",
+            "pocket_coordinates",
+            True,
+            True,
+        )
+        dataset = CroppingPocketDataset(
+            dataset,
+            self.seed,
+            "pocket_atoms",
+            "pocket_coordinates",
+            self.args.max_pocket_atoms,
+        )
+
+        dataset = RemoveHydrogenDataset(dataset, "atoms", "coordinates", True, True)
+
+        apo_dataset = NormalizeDataset(dataset, "coordinates")
+        apo_dataset = NormalizeDataset(apo_dataset, "pocket_coordinates")
+
+        src_dataset = KeyDataset(apo_dataset, "atoms")
+        mol_len_dataset = LengthDataset(src_dataset)
+        src_dataset = TokenizeDataset(
+            src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
+        )
+        coord_dataset = KeyDataset(apo_dataset, "coordinates")
+        src_dataset = PrependAndAppend(
+            src_dataset, self.dictionary.bos(), self.dictionary.eos()
+        )
+        edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
+        coord_dataset = FromNumpyDataset(coord_dataset)
+        distance_dataset = DistanceDataset(coord_dataset)
+        coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
+        distance_dataset = PrependAndAppend2DDataset(distance_dataset, 0.0)
+
+        src_pocket_dataset = KeyDataset(apo_dataset, "pocket_atoms")
+        pocket_len_dataset = LengthDataset(src_pocket_dataset)
+        src_pocket_dataset = TokenizeDataset(
+            src_pocket_dataset,
+            self.pocket_dictionary,
+            max_seq_len=self.args.max_seq_len,
+        )
+        coord_pocket_dataset = KeyDataset(apo_dataset, "pocket_coordinates")
+        src_pocket_dataset = PrependAndAppend(
+            src_pocket_dataset,
+            self.pocket_dictionary.bos(),
+            self.pocket_dictionary.eos(),
+        )
+        pocket_edge_type = EdgeTypeDataset(
+            src_pocket_dataset, len(self.pocket_dictionary)
+        )
+        coord_pocket_dataset = FromNumpyDataset(coord_pocket_dataset)
+        distance_pocket_dataset = DistanceDataset(coord_pocket_dataset)
+        coord_pocket_dataset = PrependAndAppend(coord_pocket_dataset, 0.0, 0.0)
+        distance_pocket_dataset = PrependAndAppend2DDataset(
+            distance_pocket_dataset, 0.0
+        )
+
+        nest_dataset = NestedDictionaryDataset(
+            {
+                "net_input": {
+                    "mol_src_tokens": RightPadDataset(
+                        src_dataset,
+                        pad_idx=self.dictionary.pad(),
+                    ),
+                    "mol_src_distance": RightPadDataset2D(
+                        distance_dataset,
+                        pad_idx=0,
+                    ),
+                    "mol_src_edge_type": RightPadDataset2D(
+                        edge_type,
+                        pad_idx=0,
+                    ),
+                    "pocket_src_tokens": RightPadDataset(
+                        src_pocket_dataset,
+                        pad_idx=self.pocket_dictionary.pad(),
+                    ),
+                    "pocket_src_distance": RightPadDataset2D(
+                        distance_pocket_dataset,
+                        pad_idx=0,
+                    ),
+                    "pocket_src_edge_type": RightPadDataset2D(
+                        pocket_edge_type,
+                        pad_idx=0,
+                    ),
+                    "pocket_src_coord": RightPadDatasetCoord(
+                        coord_pocket_dataset,
+                        pad_idx=0,
+                    ),
+                    "mol_len": RawArrayDataset(mol_len_dataset),
+                    "pocket_len": RawArrayDataset(pocket_len_dataset)
+                },
+                "target": {
+                    "finetune_target": RawLabelDataset(tgt_dataset),
+                },
+                "smi_name": RawArrayDataset(smi_dataset),
+                "pocket_name": RawArrayDataset(poc_dataset),
+            },
+        )
+        if split == "train":
+            with data_utils.numpy_seed(self.args.seed):
+                shuffle = np.random.permutation(len(src_dataset))
+
+            self.datasets[split] = SortDataset(
+                nest_dataset,
+                sort_order=[shuffle],
+            )
+            self.datasets[split] = ResamplingDataset(
+                self.datasets[split]
+            )
+        else:
+            self.datasets[split] = nest_dataset
+
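+    # A collated batch from this dataset looks roughly like (B = batch size,
+    # L = padded length; shapes illustrative):
+    #   sample["net_input"]["mol_src_tokens"]     (B, L_mol)         padded token ids
+    #   sample["net_input"]["mol_src_distance"]   (B, L_mol, L_mol)  pairwise distances
+    #   sample["net_input"]["pocket_src_tokens"]  (B, L_pkt)
+    #   sample["net_input"]["pocket_src_coord"]   (B, L_pkt, 3)
+    #   sample["target"]["finetune_target"]       (B,)               binding labels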
+
+    def load_mols_dataset(self, data_path, atoms, coords, **kwargs):
+        dataset = LMDBDataset(data_path)
+        label_dataset = KeyDataset(dataset, "label")
+        dataset = AffinityMolDataset(
+            dataset,
+            self.args.seed,
+            atoms,
+            coords,
+            False,
+        )
+        
+        smi_dataset = KeyDataset(dataset, "smi")
+
+        def PrependAndAppend(dataset, pre_token, app_token):
+            dataset = PrependTokenDataset(dataset, pre_token)
+            return AppendTokenDataset(dataset, app_token)
+
+        dataset = RemoveHydrogenDataset(dataset, "atoms", "coordinates", True, True)
+
+        apo_dataset = NormalizeDataset(dataset, "coordinates")
+
+        src_dataset = KeyDataset(apo_dataset, "atoms")
+        len_dataset = LengthDataset(src_dataset)
+        src_dataset = TokenizeDataset(
+            src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
+        )
+        coord_dataset = KeyDataset(apo_dataset, "coordinates")
+        src_dataset = PrependAndAppend(
+            src_dataset, self.dictionary.bos(), self.dictionary.eos()
+        )
+        edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
+        coord_dataset = FromNumpyDataset(coord_dataset)
+        distance_dataset = DistanceDataset(coord_dataset)
+        coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
+        distance_dataset = PrependAndAppend2DDataset(distance_dataset, 0.0)
+
+        nest_dataset = NestedDictionaryDataset(
+            {
+                "net_input": {
+                    "mol_src_tokens": RightPadDataset(
+                        src_dataset,
+                        pad_idx=self.dictionary.pad(),
+                    ),
+                    "mol_src_distance": RightPadDataset2D(
+                        distance_dataset,
+                        pad_idx=0,
+                    ),
+                    "mol_src_edge_type": RightPadDataset2D(
+                        edge_type,
+                        pad_idx=0,
+                    ),
+                },
+                "smi_name": RawArrayDataset(smi_dataset),
+                "target":  RawArrayDataset(label_dataset),
+                "mol_len": RawArrayDataset(len_dataset),
+            },
+        )
+        return nest_dataset
+    
+
+    def load_retrieval_mols_dataset(self, data_path, atoms, coords, **kwargs):
+        dataset = LMDBDataset(data_path)
+        dataset = AffinityMolDataset(
+            dataset,
+            self.args.seed,
+            atoms,
+            coords,
+            False,
+        )
+        
+        smi_dataset = KeyDataset(dataset, "smi")
+
+        def PrependAndAppend(dataset, pre_token, app_token):
+            dataset = PrependTokenDataset(dataset, pre_token)
+            return AppendTokenDataset(dataset, app_token)
+
+        dataset = RemoveHydrogenDataset(dataset, "atoms", "coordinates", True, True)
+
+        apo_dataset = NormalizeDataset(dataset, "coordinates")
+
+        src_dataset = KeyDataset(apo_dataset, "atoms")
+        len_dataset = LengthDataset(src_dataset)
+        src_dataset = TokenizeDataset(
+            src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
+        )
+        coord_dataset = KeyDataset(apo_dataset, "coordinates")
+        src_dataset = PrependAndAppend(
+            src_dataset, self.dictionary.bos(), self.dictionary.eos()
+        )
+        edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
+        coord_dataset = FromNumpyDataset(coord_dataset)
+        distance_dataset = DistanceDataset(coord_dataset)
+        coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
+        distance_dataset = PrependAndAppend2DDataset(distance_dataset, 0.0)
+
+        nest_dataset = NestedDictionaryDataset(
+            {
+                "net_input": {
+                    "mol_src_tokens": RightPadDataset(
+                        src_dataset,
+                        pad_idx=self.dictionary.pad(),
+                    ),
+                    "mol_src_distance": RightPadDataset2D(
+                        distance_dataset,
+                        pad_idx=0,
+                    ),
+                    "mol_src_edge_type": RightPadDataset2D(
+                        edge_type,
+                        pad_idx=0,
+                    ),
+                },
+                "smi_name": RawArrayDataset(smi_dataset),
+                "mol_len": RawArrayDataset(len_dataset),
+            },
+        )
+        return nest_dataset
+
+    def load_pockets_dataset(self, data_path, **kwargs):
+        dataset = LMDBDataset(data_path)
+        dataset = AffinityPocketDataset(
+            dataset,
+            self.args.seed,
+            "pocket_atoms",
+            "pocket_coordinates",
+            False,
+            "pocket"
+        )
+        poc_dataset = KeyDataset(dataset, "pocket")
+
+        def PrependAndAppend(dataset, pre_token, app_token):
+            dataset = PrependTokenDataset(dataset, pre_token)
+            return AppendTokenDataset(dataset, app_token)
+
+        dataset = RemoveHydrogenPocketDataset(
+            dataset,
+            "pocket_atoms",
+            "pocket_coordinates",
+            True,
+            True,
+        )
+        dataset = CroppingPocketDataset(
+            dataset,
+            self.seed,
+            "pocket_atoms",
+            "pocket_coordinates",
+            self.args.max_pocket_atoms,
+        )
+
+        apo_dataset = NormalizeDataset(dataset, "pocket_coordinates")
+
+        src_pocket_dataset = KeyDataset(apo_dataset, "pocket_atoms")
+        len_dataset = LengthDataset(src_pocket_dataset)
+        src_pocket_dataset = TokenizeDataset(
+            src_pocket_dataset,
+            self.pocket_dictionary,
+            max_seq_len=self.args.max_seq_len,
+        )
+        coord_pocket_dataset = KeyDataset(apo_dataset, "pocket_coordinates")
+        src_pocket_dataset = PrependAndAppend(
+            src_pocket_dataset,
+            self.pocket_dictionary.bos(),
+            self.pocket_dictionary.eos(),
+        )
+        pocket_edge_type = EdgeTypeDataset(
+            src_pocket_dataset, len(self.pocket_dictionary)
+        )
+        coord_pocket_dataset = FromNumpyDataset(coord_pocket_dataset)
+        distance_pocket_dataset = DistanceDataset(coord_pocket_dataset)
+        coord_pocket_dataset = PrependAndAppend(coord_pocket_dataset, 0.0, 0.0)
+        distance_pocket_dataset = PrependAndAppend2DDataset(
+            distance_pocket_dataset, 0.0
+        )
+
+        nest_dataset = NestedDictionaryDataset(
+            {
+                "net_input": {
+                    "pocket_src_tokens": RightPadDataset(
+                        src_pocket_dataset,
+                        pad_idx=self.pocket_dictionary.pad(),
+                    ),
+                    "pocket_src_distance": RightPadDataset2D(
+                        distance_pocket_dataset,
+                        pad_idx=0,
+                    ),
+                    "pocket_src_edge_type": RightPadDataset2D(
+                        pocket_edge_type,
+                        pad_idx=0,
+                    ),
+                    "pocket_src_coord": RightPadDatasetCoord(
+                        coord_pocket_dataset,
+                        pad_idx=0,
+                    ),
+                },
+                "pocket_name": RawArrayDataset(poc_dataset),
+                "pocket_len": RawArrayDataset(len_dataset),
+            },
+        )
+        return nest_dataset
+
+    
+
+    def build_model(self, args):
+        from unicore import models
+
+        model = models.build_model(args, self)
+
+        if args.finetune_mol_model is not None:
+            logger.info("loading pretrained molecular encoder from %s", args.finetune_mol_model)
+            state = checkpoint_utils.load_checkpoint_to_cpu(
+                args.finetune_mol_model,
+            )
+            # strict=False: tolerate parameters present in only one of the
+            # checkpoint and this encoder
+            model.mol_model.load_state_dict(state["model"], strict=False)
+
+        if args.finetune_pocket_model is not None:
+            logger.info("loading pretrained pocket encoder from %s", args.finetune_pocket_model)
+            state = checkpoint_utils.load_checkpoint_to_cpu(
+                args.finetune_pocket_model,
+            )
+            model.pocket_model.load_state_dict(state["model"], strict=False)
+
+        return model
+
+    def train_step(
+        self, sample, model, loss, optimizer, update_num, ignore_grad=False
+    ):
+        """
+        Do forward and backward, and return the loss as computed by *loss*
+        for the given *model* and *sample*.
+
+        Args:
+            sample (dict): the mini-batch. The format is defined by the
+                :class:`~unicore.data.UnicoreDataset`.
+            model (~unicore.models.BaseUnicoreModel): the model
+            loss (~unicore.losses.UnicoreLoss): the loss
+            optimizer (~unicore.optim.UnicoreOptimizer): the optimizer
+            update_num (int): the current update
+            ignore_grad (bool): multiply loss by 0 if this is set to True
+
+        Returns:
+            tuple:
+                - the loss
+                - the sample size, which is used as the denominator for the
+                  gradient
+                - logging outputs to display while training
+        """
+
+        model.train()
+        model.set_num_updates(update_num)
+        with torch.autograd.profiler.record_function("forward"):
+            loss, sample_size, logging_output = loss(model, sample)
+        if ignore_grad:
+            loss *= 0
+        with torch.autograd.profiler.record_function("backward"):
+            optimizer.backward(loss)
+        return loss, sample_size, logging_output
+    
+    def valid_step(self, sample, model, loss, test=False):
+        model.eval()
+        with torch.no_grad():
+            loss, sample_size, logging_output = loss(model, sample)
+        return loss, sample_size, logging_output
+
+
+    def test_pcba_target(self, name, model, **kwargs):
+        """Evaluate virtual screening on a single LIT-PCBA target."""
+        data_path = "./data/lit_pcba/" + name + "/mols.lmdb"
+        mol_dataset = self.load_mols_dataset(data_path, "atoms", "coordinates")
+        bsz = 64
+        mol_reps = []
+        mol_names = []
+        labels = []
+
+        # encode the molecule library
+        mol_data = torch.utils.data.DataLoader(mol_dataset, batch_size=bsz, collate_fn=mol_dataset.collater)
+        for sample in tqdm(mol_data):
+            sample = unicore.utils.move_to_cuda(sample)
+            dist = sample["net_input"]["mol_src_distance"]
+            et = sample["net_input"]["mol_src_edge_type"]
+            st = sample["net_input"]["mol_src_tokens"]
+            mol_padding_mask = st.eq(model.mol_model.padding_idx)
+            mol_x = model.mol_model.embed_tokens(st)
+            n_node = dist.size(-1)
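+            # Gaussian-basis features of atom-pair distances, typed by edge type and
+            # projected per attention head, act as the encoder's attention bias.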
+            gbf_feature = model.mol_model.gbf(dist, et)
+
+            gbf_result = model.mol_model.gbf_proj(gbf_feature)
+            graph_attn_bias = gbf_result
+            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
+            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
+            mol_outputs = model.mol_model.encoder(
+                mol_x, padding_mask=mol_padding_mask, attn_mask=graph_attn_bias
+            )
+            mol_encoder_rep = mol_outputs[0][:, 0, :]  # global representation from the [CLS]/BOS position
+            mol_emb = model.mol_project(mol_encoder_rep)
+            mol_emb = mol_emb / mol_emb.norm(dim=1, keepdim=True)
+            mol_emb = mol_emb.detach().cpu().numpy()
+            mol_reps.append(mol_emb)
+            mol_names.extend(sample["smi_name"])
+            labels.extend(sample["target"].detach().cpu().numpy())
+        mol_reps = np.concatenate(mol_reps, axis=0)
+        labels = np.array(labels, dtype=np.int32)
+        # generate pocket data
+        data_path = "./data/lit_pcba/" + name + "/pockets.lmdb"
+        pocket_dataset = self.load_pockets_dataset(data_path)
+        pocket_data = torch.utils.data.DataLoader(pocket_dataset, batch_size=bsz, collate_fn=pocket_dataset.collater)
+        pocket_reps = []
+
+        for sample in tqdm(pocket_data):
+            sample = unicore.utils.move_to_cuda(sample)
+            dist = sample["net_input"]["pocket_src_distance"]
+            et = sample["net_input"]["pocket_src_edge_type"]
+            st = sample["net_input"]["pocket_src_tokens"]
+            pocket_padding_mask = st.eq(model.pocket_model.padding_idx)
+            pocket_x = model.pocket_model.embed_tokens(st)
+            n_node = dist.size(-1)
+            gbf_feature = model.pocket_model.gbf(dist, et)
+            gbf_result = model.pocket_model.gbf_proj(gbf_feature)
+            graph_attn_bias = gbf_result
+            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
+            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
+            pocket_outputs = model.pocket_model.encoder(
+                pocket_x, padding_mask=pocket_padding_mask, attn_mask=graph_attn_bias
+            )
+            pocket_encoder_rep = pocket_outputs[0][:,0,:]
+            pocket_emb = model.pocket_project(pocket_encoder_rep)
+            pocket_emb = pocket_emb / pocket_emb.norm(dim=1, keepdim=True)
+            pocket_emb = pocket_emb.detach().cpu().numpy()
+            pocket_reps.append(pocket_emb)
+        pocket_reps = np.concatenate(pocket_reps, axis=0)
+
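+        # Score every molecule against every pocket of this target; each molecule's
+        # final score is its best (max) similarity across the pockets.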
+        res = pocket_reps @ mol_reps.T
+        res_single = res.max(axis=0)
+        auc, bedroc, ef_list, re_list = cal_metrics(labels, res_single, 80.5)
+
+        return auc, bedroc, ef_list, re_list
+
+    def test_pcba(self, model, **kwargs):
+        """Run virtual-screening evaluation over all LIT-PCBA targets."""
+        targets = os.listdir("./data/lit_pcba/")
+        auc_list = []
+        bedroc_list = []
+
+        re_list = {
+            "0.005": [],
+            "0.01": [],
+            "0.02": [],
+            "0.05": []
+        }
+        ef_list = {
+            "0.005": [],
+            "0.01": [],
+            "0.02": [],
+            "0.05": []
+        }
+        for target in targets:
+            auc, bedroc, ef, re = self.test_pcba_target(target, model)
+            auc_list.append(auc)
+            bedroc_list.append(bedroc)
+            for key in ef:
+                ef_list[key].append(ef[key])
+            # print("re", re)
+            # print("ef", ef)
+            for key in re:
+                re_list[key].append(re[key])
+        print(auc_list)
+        print(ef_list)
+        print("auc 25%", np.percentile(auc_list, 25))
+        print("auc 50%", np.percentile(auc_list, 50))
+        print("auc 75%", np.percentile(auc_list, 75))
+        print("auc mean", np.mean(auc_list))
+        print("bedroc 25%", np.percentile(bedroc_list, 25))
+        print("bedroc 50%", np.percentile(bedroc_list, 50))
+        print("bedroc 75%", np.percentile(bedroc_list, 75))
+        print("bedroc mean", np.mean(bedroc_list))
+        for key in ef_list:
+            print("ef", key, "25%", np.percentile(ef_list[key], 25))
+            print("ef", key, "50%", np.percentile(ef_list[key], 50))
+            print("ef", key, "75%", np.percentile(ef_list[key], 75))
+            print("ef", key, "mean", np.mean(ef_list[key]))
+        for key in re_list:
+            print("re", key, "25%", np.percentile(re_list[key], 25))
+            print("re", key, "50%", np.percentile(re_list[key], 50))
+            print("re", key, "75%", np.percentile(re_list[key], 75))
+            print("re", key, "mean", np.mean(re_list[key]))
+
+        return
+
+    def test_dude_target(self, target, model, **kwargs):
+        """Evaluate virtual screening on a single DUD-E target."""
+        data_path = "./data/DUD-E/raw/all/" + target + "/mols.lmdb"
+        mol_dataset = self.load_mols_dataset(data_path, "atoms", "coordinates")
+        bsz = 64
+        mol_reps = []
+        mol_names = []
+        labels = []
+
+        # encode the molecule library
+        mol_data = torch.utils.data.DataLoader(mol_dataset, batch_size=bsz, collate_fn=mol_dataset.collater)
+        for sample in tqdm(mol_data):
+            sample = unicore.utils.move_to_cuda(sample)
+            dist = sample["net_input"]["mol_src_distance"]
+            et = sample["net_input"]["mol_src_edge_type"]
+            st = sample["net_input"]["mol_src_tokens"]
+            mol_padding_mask = st.eq(model.mol_model.padding_idx)
+            mol_x = model.mol_model.embed_tokens(st)
+            n_node = dist.size(-1)
+            gbf_feature = model.mol_model.gbf(dist, et)
+            gbf_result = model.mol_model.gbf_proj(gbf_feature)
+            graph_attn_bias = gbf_result
+            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
+            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
+            mol_outputs = model.mol_model.encoder(
+                mol_x, padding_mask=mol_padding_mask, attn_mask=graph_attn_bias
+            )
+            mol_encoder_rep = mol_outputs[0][:, 0, :]
+            mol_emb = model.mol_project(mol_encoder_rep)
+            mol_emb = mol_emb / mol_emb.norm(dim=-1, keepdim=True)
+            mol_emb = mol_emb.detach().cpu().numpy()
+            mol_reps.append(mol_emb)
+            mol_names.extend(sample["smi_name"])
+            labels.extend(sample["target"].detach().cpu().numpy())
+        mol_reps = np.concatenate(mol_reps, axis=0)
+        labels = np.array(labels, dtype=np.int32)
+        # generate pocket data
+        data_path = "./data/DUD-E/raw/all/" + target + "/pocket.lmdb"
+        pocket_dataset = self.load_pockets_dataset(data_path)
+        pocket_data = torch.utils.data.DataLoader(pocket_dataset, batch_size=bsz, collate_fn=pocket_dataset.collater)
+        pocket_reps = []
+
+        for sample in tqdm(pocket_data):
+            sample = unicore.utils.move_to_cuda(sample)
+            dist = sample["net_input"]["pocket_src_distance"]
+            et = sample["net_input"]["pocket_src_edge_type"]
+            st = sample["net_input"]["pocket_src_tokens"]
+            pocket_padding_mask = st.eq(model.pocket_model.padding_idx)
+            pocket_x = model.pocket_model.embed_tokens(st)
+            n_node = dist.size(-1)
+            gbf_feature = model.pocket_model.gbf(dist, et)
+            gbf_result = model.pocket_model.gbf_proj(gbf_feature)
+            graph_attn_bias = gbf_result
+            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
+            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
+            pocket_outputs = model.pocket_model.encoder(
+                pocket_x, padding_mask=pocket_padding_mask, attn_mask=graph_attn_bias
+            )
+            pocket_encoder_rep = pocket_outputs[0][:, 0, :]
+            pocket_emb = model.pocket_project(pocket_encoder_rep)
+            pocket_emb = pocket_emb / pocket_emb.norm(dim=-1, keepdim=True)
+            pocket_emb = pocket_emb.detach().cpu().numpy()
+            pocket_reps.append(pocket_emb)
+        pocket_reps = np.concatenate(pocket_reps, axis=0)
+        res = pocket_reps @ mol_reps.T
+
+        res_single = res.max(axis=0)
+
+        auc, bedroc, ef_list, re_list = cal_metrics(labels, res_single, 80.5)
+
+        print(target)
+        print("actives:", np.sum(labels), "decoys:", len(labels) - np.sum(labels))
+
+        return auc, bedroc, ef_list, re_list, res_single, labels
+
+    def test_dude(self, model, **kwargs):
+        """Run virtual-screening evaluation over all DUD-E targets."""
+        targets = os.listdir("./data/DUD-E/raw/all/")
+        auc_list = []
+        bedroc_list = []
+        res_list = []
+        labels_list = []
+        re_list = {
+            "0.005": [],
+            "0.01": [],
+            "0.02": [],
+            "0.05": [],
+        }
+        ef_list = {
+            "0.005": [],
+            "0.01": [],
+            "0.02": [],
+            "0.05": [],
+        }
+        for target in targets:
+            auc, bedroc, ef, re, res_single, labels = self.test_dude_target(target, model)
+            auc_list.append(auc)
+            bedroc_list.append(bedroc)
+            for key in ef:
+                ef_list[key].append(ef[key])
+            for key in re_list:
+                re_list[key].append(re[key])
+            res_list.append(res_single)
+            labels_list.append(labels)
+        res = np.concatenate(res_list, axis=0)
+        labels = np.concatenate(labels_list, axis=0)
+        print("auc mean", np.mean(auc_list))
+        print("bedroc mean", np.mean(bedroc_list))
+
+        for key in ef_list:
+            print("ef", key, "mean", np.mean(ef_list[key]))
+
+        for key in re_list:
+            print("re", key, "mean", np.mean(re_list[key]))
+
+        # (results are printed only; nothing is saved to disk)
+        return
+
+    def encode_mols_once(self, model, data_path, emb_dir, atoms, coords, **kwargs):
+        """Encode a molecule library once and cache the embeddings on disk."""
+        # cache path is emb_dir/<lmdb file name>.pkl; note the cache key is only
+        # the file name, so distinct libraries sharing a name would collide
+        cache_path = os.path.join(emb_dir, os.path.basename(data_path) + ".pkl")
+
+        if os.path.exists(cache_path):
+            with open(cache_path, "rb") as f:
+                mol_reps, mol_names = pickle.load(f)
+            return mol_reps, mol_names
+
+        mol_dataset = self.load_retrieval_mols_dataset(data_path, atoms, coords)
+        mol_reps = []
+        mol_names = []
+        bsz = 32
+        mol_data = torch.utils.data.DataLoader(mol_dataset, batch_size=bsz, collate_fn=mol_dataset.collater)
+        for sample in tqdm(mol_data):
+            sample = unicore.utils.move_to_cuda(sample)
+            dist = sample["net_input"]["mol_src_distance"]
+            et = sample["net_input"]["mol_src_edge_type"]
+            st = sample["net_input"]["mol_src_tokens"]
+            mol_padding_mask = st.eq(model.mol_model.padding_idx)
+            mol_x = model.mol_model.embed_tokens(st)
+            n_node = dist.size(-1)
+            gbf_feature = model.mol_model.gbf(dist, et)
+            gbf_result = model.mol_model.gbf_proj(gbf_feature)
+            graph_attn_bias = gbf_result
+            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
+            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
+            mol_outputs = model.mol_model.encoder(
+                mol_x, padding_mask=mol_padding_mask, attn_mask=graph_attn_bias
+            )
+            mol_encoder_rep = mol_outputs[0][:,0,:]
+            mol_emb = model.mol_project(mol_encoder_rep)
+            mol_emb = mol_emb / mol_emb.norm(dim=-1, keepdim=True)
+            mol_emb = mol_emb.detach().cpu().numpy()
+            mol_reps.append(mol_emb)
+            mol_names.extend(sample["smi_name"])
+
+        mol_reps = np.concatenate(mol_reps, axis=0)
+
+        # save the results
+        
+        with open(cache_path, "wb") as f:
+            pickle.dump([mol_reps, mol_names], f)
+
+        return mol_reps, mol_names
+    
+    def retrieve_mols(self, model, mol_path, pocket_path, emb_dir, k, **kwargs):
+        """Rank a molecule library against query pockets and return the top-k hits."""
+        os.makedirs(emb_dir, exist_ok=True)
+        mol_reps, mol_names = self.encode_mols_once(model, mol_path, emb_dir, "atoms", "coordinates")
+        
+        pocket_dataset = self.load_pockets_dataset(pocket_path)
+        pocket_data = torch.utils.data.DataLoader(pocket_dataset, batch_size=16, collate_fn=pocket_dataset.collater)
+        pocket_reps = []
+        pocket_names = []
+        for sample in tqdm(pocket_data):
+            sample = unicore.utils.move_to_cuda(sample)
+            dist = sample["net_input"]["pocket_src_distance"]
+            et = sample["net_input"]["pocket_src_edge_type"]
+            st = sample["net_input"]["pocket_src_tokens"]
+            pocket_padding_mask = st.eq(model.pocket_model.padding_idx)
+            pocket_x = model.pocket_model.embed_tokens(st)
+            n_node = dist.size(-1)
+            gbf_feature = model.pocket_model.gbf(dist, et)
+            gbf_result = model.pocket_model.gbf_proj(gbf_feature)
+            graph_attn_bias = gbf_result
+            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
+            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
+            pocket_outputs = model.pocket_model.encoder(
+                pocket_x, padding_mask=pocket_padding_mask, attn_mask=graph_attn_bias
+            )
+            pocket_encoder_rep = pocket_outputs[0][:,0,:]
+            pocket_emb = model.pocket_project(pocket_encoder_rep)
+            pocket_emb = pocket_emb / pocket_emb.norm(dim=-1, keepdim=True)
+            pocket_emb = pocket_emb.detach().cpu().numpy()
+            pocket_reps.append(pocket_emb)
+            pocket_names.extend(sample["pocket_name"])
+        pocket_reps = np.concatenate(pocket_reps, axis=0)
+        
+        # each molecule's score is its best similarity over all query pockets
+        res = pocket_reps @ mol_reps.T
+        res = res.max(axis=0)
+
+        # get top-k results
+        top_k = np.argsort(res)[::-1][:k]
+
+        # return names and scores
+        return [mol_names[i] for i in top_k], res[top_k]
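+
+    # Illustrative call (paths hypothetical):
+    #
+    #   names, scores = task.retrieve_mols(model, "./data/mols.lmdb",
+    #                                      "./data/pockets.lmdb", "./embs", k=100)
+    #
+    # returns the SMILES strings of the top-k molecules with their similarity scores.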