# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import pickle

import numpy as np
import torch
from tqdm import tqdm

import unicore
from unicore import checkpoint_utils
from unicore.data import (AppendTokenDataset, Dictionary, EpochShuffleDataset,
                          FromNumpyDataset, NestedDictionaryDataset,
                          PrependTokenDataset, RawArrayDataset, LMDBDataset,
                          RawLabelDataset, RightPadDataset, RightPadDataset2D,
                          TokenizeDataset, SortDataset, data_utils)
from unicore.tasks import UnicoreTask, register_task
from unimol.data import (AffinityDataset, CroppingPocketDataset,
                         CrossDistanceDataset, DistanceDataset,
                         EdgeTypeDataset, KeyDataset, LengthDataset,
                         NormalizeDataset, NormalizeDockingPoseDataset,
                         PrependAndAppend2DDataset, RemoveHydrogenDataset,
                         RemoveHydrogenPocketDataset, RightPadDatasetCoord,
                         RightPadDatasetCross2D, TTADockingPoseDataset,
                         AffinityTestDataset, AffinityValidDataset,
                         AffinityMolDataset, AffinityPocketDataset,
                         ResamplingDataset)
from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment

logger = logging.getLogger(__name__)


def re_new(y_true, y_score, ratio):
    """Enrichment-style ratio: TPR divided by FPR at the point where the
    number of false positives in the score-ranked list first reaches
    `ratio` * (number of negatives)."""
    fp = 0
    tp = 0
    p = sum(y_true)
    n = len(y_true) - p
    num = ratio * n
    sort_index = np.argsort(y_score)[::-1]
    for i in range(len(sort_index)):
        index = sort_index[i]
        if y_true[index] == 1:
            tp += 1
        else:
            fp += 1
            if fp >= num:
                break
    return (tp * n) / (p * fp)


def calc_re(y_true, y_score, ratio_list):
    res = {}
    for ratio in ratio_list:
        res[str(ratio)] = re_new(y_true, y_score, ratio)
    return res


def cal_metrics(y_true, y_score, alpha):
    """
    Calculate virtual-screening metrics.

    Parameters:
    - y_true: true binary labels (0 or 1)
    - y_score: predicted scores or probabilities
    - alpha: parameter controlling the degree of early-retrieval emphasis in BEDROC

    Returns:
    - AUC, BEDROC score, enrichment factors (dict), and enrichment ratios (dict)
    """
    # Concatenate scores and labels, then sort by score in descending order;
    # this [score, label] layout is the input format the RDKit helpers expect.
    scores = np.expand_dims(y_score, axis=1)
    labels = np.expand_dims(y_true, axis=1)
    scores = np.concatenate((scores, labels), axis=1)
    scores = scores[scores[:, 0].argsort()[::-1]]
    bedroc = CalcBEDROC(scores, 1, alpha)
    auc = CalcAUC(scores, 1)
    ef_list = CalcEnrichment(scores, 1, [0.005, 0.01, 0.02, 0.05])
    ef = {
        "0.005": ef_list[0],
        "0.01": ef_list[1],
        "0.02": ef_list[2],
        "0.05": ef_list[3],
    }
    re_list = calc_re(y_true, y_score, [0.005, 0.01, 0.02, 0.05])
    return auc, bedroc, ef, re_list
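
# Minimal usage sketch for cal_metrics (toy arrays, purely illustrative and
# not part of the pipeline; 80.5 is the alpha value used by the callers below):
#
#   y_true = np.array([1, 0, 1, 0, 0, 0, 0, 0, 1, 0])
#   y_score = np.linspace(1.0, 0.0, num=10)
#   auc, bedroc, ef, re = cal_metrics(y_true, y_score, 80.5)
#   print(auc, bedroc, ef["0.01"], re["0.01"])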


@register_task("drugclip")
class DrugCLIP(UnicoreTask):
    """Task for training DrugCLIP pocket-ligand contrastive models."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="downstream data path",
        )
        parser.add_argument(
            "--finetune-mol-model",
            default=None,
            type=str,
            help="pretrained molecular model path",
        )
        parser.add_argument(
            "--finetune-pocket-model",
            default=None,
            type=str,
            help="pretrained pocket model path",
        )
        parser.add_argument(
            "--dist-threshold",
            type=float,
            default=6.0,
            help="threshold for the distance between the molecule and the pocket",
        )
        parser.add_argument(
            "--max-pocket-atoms",
            type=int,
            default=256,
            help="maximum number of atoms kept in a pocket",
        )
        parser.add_argument(
            "--test-model",
            default=False,
            type=bool,
            help="whether to run the model in test mode",
        )
        parser.add_argument("--reg", action="store_true", help="regression task")
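
    # Hypothetical invocation sketch (the flag names come from add_args above;
    # the entry point and the remaining required unicore flags are assumptions):
    #
    #   $ unicore-train ./data --task drugclip \
    #       --finetune-mol-model mol_pre.pt --finetune-pocket-model pocket_pre.pt \
    #       --max-pocket-atoms 256 --dist-threshold 6.0
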
    def __init__(self, args, dictionary, pocket_dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.pocket_dictionary = pocket_dictionary
        self.seed = args.seed
        # add mask token
        self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)
        self.pocket_mask_idx = pocket_dictionary.add_symbol("[MASK]", is_special=True)
        self.mol_reps = None
        self.keys = None

    @classmethod
    def setup_task(cls, args, **kwargs):
        mol_dictionary = Dictionary.load(os.path.join(args.data, "dict_mol.txt"))
        pocket_dictionary = Dictionary.load(os.path.join(args.data, "dict_pkt.txt"))
        logger.info("ligand dictionary: {} types".format(len(mol_dictionary)))
        logger.info("pocket dictionary: {} types".format(len(pocket_dictionary)))
        return cls(args, mol_dictionary, pocket_dictionary)

    def load_dataset(self, split, **kwargs):
        """Load a given dataset split.

        Each LMDB record provides the keys 'smi', 'pocket', 'atoms',
        'coordinates', 'pocket_atoms', and 'pocket_coordinates'.

        Args:
            split (str): name of the data source (e.g., bppp)
        """
        data_path = os.path.join(self.args.data, split + ".lmdb")
        dataset = LMDBDataset(data_path)
        if split.startswith("train"):
            smi_dataset = KeyDataset(dataset, "smi")
            poc_dataset = KeyDataset(dataset, "pocket")
            dataset = AffinityDataset(
                dataset,
                self.args.seed,
                "atoms",
                "coordinates",
                "pocket_atoms",
                "pocket_coordinates",
                "label",
                True,
            )
            tgt_dataset = KeyDataset(dataset, "affinity")
        else:
            dataset = AffinityDataset(
                dataset,
                self.args.seed,
                "atoms",
                "coordinates",
                "pocket_atoms",
                "pocket_coordinates",
                "label",
            )
            tgt_dataset = KeyDataset(dataset, "affinity")
            smi_dataset = KeyDataset(dataset, "smi")
            poc_dataset = KeyDataset(dataset, "pocket")

        def PrependAndAppend(dataset, pre_token, app_token):
            dataset = PrependTokenDataset(dataset, pre_token)
            return AppendTokenDataset(dataset, app_token)

        # Strip hydrogens, then crop each pocket to at most max_pocket_atoms atoms.
        dataset = RemoveHydrogenPocketDataset(
            dataset,
            "pocket_atoms",
            "pocket_coordinates",
            True,
            True,
        )
        dataset = CroppingPocketDataset(
            dataset,
            self.seed,
            "pocket_atoms",
            "pocket_coordinates",
            self.args.max_pocket_atoms,
        )
        dataset = RemoveHydrogenDataset(dataset, "atoms", "coordinates", True, True)

        apo_dataset = NormalizeDataset(dataset, "coordinates")
        apo_dataset = NormalizeDataset(apo_dataset, "pocket_coordinates")

        # Molecule branch: tokens, lengths, pairwise distances, and edge types.
        src_dataset = KeyDataset(apo_dataset, "atoms")
        mol_len_dataset = LengthDataset(src_dataset)
        src_dataset = TokenizeDataset(
            src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
        )
        coord_dataset = KeyDataset(apo_dataset, "coordinates")
        src_dataset = PrependAndAppend(
            src_dataset, self.dictionary.bos(), self.dictionary.eos()
        )
        edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
        coord_dataset = FromNumpyDataset(coord_dataset)
        distance_dataset = DistanceDataset(coord_dataset)
        coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
        distance_dataset = PrependAndAppend2DDataset(distance_dataset, 0.0)

        # Pocket branch: the same featurization with the pocket dictionary.
        src_pocket_dataset = KeyDataset(apo_dataset, "pocket_atoms")
        pocket_len_dataset = LengthDataset(src_pocket_dataset)
        src_pocket_dataset = TokenizeDataset(
            src_pocket_dataset,
            self.pocket_dictionary,
            max_seq_len=self.args.max_seq_len,
        )
        coord_pocket_dataset = KeyDataset(apo_dataset, "pocket_coordinates")
        src_pocket_dataset = PrependAndAppend(
            src_pocket_dataset,
            self.pocket_dictionary.bos(),
            self.pocket_dictionary.eos(),
        )
        pocket_edge_type = EdgeTypeDataset(
            src_pocket_dataset, len(self.pocket_dictionary)
        )
        coord_pocket_dataset = FromNumpyDataset(coord_pocket_dataset)
        distance_pocket_dataset = DistanceDataset(coord_pocket_dataset)
        coord_pocket_dataset = PrependAndAppend(coord_pocket_dataset, 0.0, 0.0)
        distance_pocket_dataset = PrependAndAppend2DDataset(
            distance_pocket_dataset, 0.0
        )

        nest_dataset = NestedDictionaryDataset(
            {
                "net_input": {
                    "mol_src_tokens": RightPadDataset(
                        src_dataset,
                        pad_idx=self.dictionary.pad(),
                    ),
                    "mol_src_distance": RightPadDataset2D(
                        distance_dataset,
                        pad_idx=0,
                    ),
                    "mol_src_edge_type": RightPadDataset2D(
                        edge_type,
                        pad_idx=0,
                    ),
                    "pocket_src_tokens": RightPadDataset(
                        src_pocket_dataset,
                        pad_idx=self.pocket_dictionary.pad(),
                    ),
                    "pocket_src_distance": RightPadDataset2D(
                        distance_pocket_dataset,
                        pad_idx=0,
                    ),
                    "pocket_src_edge_type": RightPadDataset2D(
                        pocket_edge_type,
                        pad_idx=0,
                    ),
                    "pocket_src_coord": RightPadDatasetCoord(
                        coord_pocket_dataset,
                        pad_idx=0,
                    ),
                    "mol_len": RawArrayDataset(mol_len_dataset),
                    "pocket_len": RawArrayDataset(pocket_len_dataset),
                },
                "target": {
                    "finetune_target": RawLabelDataset(tgt_dataset),
                },
                "smi_name": RawArrayDataset(smi_dataset),
                "pocket_name": RawArrayDataset(poc_dataset),
            },
        )
        if split == "train":
            with data_utils.numpy_seed(self.args.seed):
                shuffle = np.random.permutation(len(src_dataset))

            self.datasets[split] = SortDataset(
                nest_dataset,
                sort_order=[shuffle],
            )
            self.datasets[split] = ResamplingDataset(
                self.datasets[split]
            )
        else:
            self.datasets[split] = nest_dataset
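
    # Usage sketch (hypothetical driver code; in practice the unicore training
    # entry point builds the task and calls load_dataset for each split):
    #
    #   task = DrugCLIP.setup_task(args)
    #   task.load_dataset("train")
    #   train_set = task.dataset("train")
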
    def load_mols_dataset(self, data_path, atoms, coords, **kwargs):
        dataset = LMDBDataset(data_path)
        label_dataset = KeyDataset(dataset, "label")
        dataset = AffinityMolDataset(
            dataset,
            self.args.seed,
            atoms,
            coords,
            False,
        )

        smi_dataset = KeyDataset(dataset, "smi")

        def PrependAndAppend(dataset, pre_token, app_token):
            dataset = PrependTokenDataset(dataset, pre_token)
            return AppendTokenDataset(dataset, app_token)

        dataset = RemoveHydrogenDataset(dataset, "atoms", "coordinates", True, True)

        apo_dataset = NormalizeDataset(dataset, "coordinates")

        src_dataset = KeyDataset(apo_dataset, "atoms")
        len_dataset = LengthDataset(src_dataset)
        src_dataset = TokenizeDataset(
            src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
        )
        coord_dataset = KeyDataset(apo_dataset, "coordinates")
        src_dataset = PrependAndAppend(
            src_dataset, self.dictionary.bos(), self.dictionary.eos()
        )
        edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
        coord_dataset = FromNumpyDataset(coord_dataset)
        distance_dataset = DistanceDataset(coord_dataset)
        coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
        distance_dataset = PrependAndAppend2DDataset(distance_dataset, 0.0)

        nest_dataset = NestedDictionaryDataset(
            {
                "net_input": {
                    "mol_src_tokens": RightPadDataset(
                        src_dataset,
                        pad_idx=self.dictionary.pad(),
                    ),
                    "mol_src_distance": RightPadDataset2D(
                        distance_dataset,
                        pad_idx=0,
                    ),
                    "mol_src_edge_type": RightPadDataset2D(
                        edge_type,
                        pad_idx=0,
                    ),
                },
                "smi_name": RawArrayDataset(smi_dataset),
                "target": RawArrayDataset(label_dataset),
                "mol_len": RawArrayDataset(len_dataset),
            },
        )
        return nest_dataset

    def load_retrieval_mols_dataset(self, data_path, atoms, coords, **kwargs):
        dataset = LMDBDataset(data_path)
        dataset = AffinityMolDataset(
            dataset,
            self.args.seed,
            atoms,
            coords,
            False,
        )

        smi_dataset = KeyDataset(dataset, "smi")

        def PrependAndAppend(dataset, pre_token, app_token):
            dataset = PrependTokenDataset(dataset, pre_token)
            return AppendTokenDataset(dataset, app_token)

        dataset = RemoveHydrogenDataset(dataset, "atoms", "coordinates", True, True)

        apo_dataset = NormalizeDataset(dataset, "coordinates")

        src_dataset = KeyDataset(apo_dataset, "atoms")
        len_dataset = LengthDataset(src_dataset)
        src_dataset = TokenizeDataset(
            src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
        )
        coord_dataset = KeyDataset(apo_dataset, "coordinates")
        src_dataset = PrependAndAppend(
            src_dataset, self.dictionary.bos(), self.dictionary.eos()
        )
        edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
        coord_dataset = FromNumpyDataset(coord_dataset)
        distance_dataset = DistanceDataset(coord_dataset)
        coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
        distance_dataset = PrependAndAppend2DDataset(distance_dataset, 0.0)

        nest_dataset = NestedDictionaryDataset(
            {
                "net_input": {
                    "mol_src_tokens": RightPadDataset(
                        src_dataset,
                        pad_idx=self.dictionary.pad(),
                    ),
                    "mol_src_distance": RightPadDataset2D(
                        distance_dataset,
                        pad_idx=0,
                    ),
                    "mol_src_edge_type": RightPadDataset2D(
                        edge_type,
                        pad_idx=0,
                    ),
                },
                "smi_name": RawArrayDataset(smi_dataset),
                "mol_len": RawArrayDataset(len_dataset),
            },
        )
        return nest_dataset

    def load_pockets_dataset(self, data_path, **kwargs):
        dataset = LMDBDataset(data_path)
        dataset = AffinityPocketDataset(
            dataset,
            self.args.seed,
            "pocket_atoms",
            "pocket_coordinates",
            False,
            "pocket",
        )
        poc_dataset = KeyDataset(dataset, "pocket")

        def PrependAndAppend(dataset, pre_token, app_token):
            dataset = PrependTokenDataset(dataset, pre_token)
            return AppendTokenDataset(dataset, app_token)

        dataset = RemoveHydrogenPocketDataset(
            dataset,
            "pocket_atoms",
            "pocket_coordinates",
            True,
            True,
        )
        dataset = CroppingPocketDataset(
            dataset,
            self.seed,
            "pocket_atoms",
            "pocket_coordinates",
            self.args.max_pocket_atoms,
        )

        apo_dataset = NormalizeDataset(dataset, "pocket_coordinates")

        src_pocket_dataset = KeyDataset(apo_dataset, "pocket_atoms")
        len_dataset = LengthDataset(src_pocket_dataset)
        src_pocket_dataset = TokenizeDataset(
            src_pocket_dataset,
            self.pocket_dictionary,
            max_seq_len=self.args.max_seq_len,
        )
        coord_pocket_dataset = KeyDataset(apo_dataset, "pocket_coordinates")
        src_pocket_dataset = PrependAndAppend(
            src_pocket_dataset,
            self.pocket_dictionary.bos(),
            self.pocket_dictionary.eos(),
        )
        pocket_edge_type = EdgeTypeDataset(
            src_pocket_dataset, len(self.pocket_dictionary)
        )
        coord_pocket_dataset = FromNumpyDataset(coord_pocket_dataset)
        distance_pocket_dataset = DistanceDataset(coord_pocket_dataset)
        coord_pocket_dataset = PrependAndAppend(coord_pocket_dataset, 0.0, 0.0)
        distance_pocket_dataset = PrependAndAppend2DDataset(
            distance_pocket_dataset, 0.0
        )

        nest_dataset = NestedDictionaryDataset(
            {
                "net_input": {
                    "pocket_src_tokens": RightPadDataset(
                        src_pocket_dataset,
                        pad_idx=self.pocket_dictionary.pad(),
                    ),
                    "pocket_src_distance": RightPadDataset2D(
                        distance_pocket_dataset,
                        pad_idx=0,
                    ),
                    "pocket_src_edge_type": RightPadDataset2D(
                        pocket_edge_type,
                        pad_idx=0,
                    ),
                    "pocket_src_coord": RightPadDatasetCoord(
                        coord_pocket_dataset,
                        pad_idx=0,
                    ),
                },
                "pocket_name": RawArrayDataset(poc_dataset),
                "pocket_len": RawArrayDataset(len_dataset),
            },
        )
        return nest_dataset

    def build_model(self, args):
        from unicore import models

        model = models.build_model(args, self)

        if args.finetune_mol_model is not None:
            logger.info(
                "loading pretrained molecular model weights from {}".format(
                    args.finetune_mol_model
                )
            )
            state = checkpoint_utils.load_checkpoint_to_cpu(
                args.finetune_mol_model,
            )
            model.mol_model.load_state_dict(state["model"], strict=False)

        if args.finetune_pocket_model is not None:
            logger.info(
                "loading pretrained pocket model weights from {}".format(
                    args.finetune_pocket_model
                )
            )
            state = checkpoint_utils.load_checkpoint_to_cpu(
                args.finetune_pocket_model,
            )
            model.pocket_model.load_state_dict(state["model"], strict=False)

        return model

    def train_step(
        self, sample, model, loss, optimizer, update_num, ignore_grad=False
    ):
        """
        Do forward and backward, and return the loss as computed by *loss*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~unicore.data.UnicoreDataset`.
            model (~unicore.models.BaseUnicoreModel): the model
            loss (~unicore.losses.UnicoreLoss): the loss
            optimizer (~unicore.optim.UnicoreOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            loss, sample_size, logging_output = loss(model, sample)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, loss, test=False):
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = loss(model, sample)
        return loss, sample_size, logging_output

    def test_pcba_target(self, name, model, **kwargs):
        """Encode one LIT-PCBA target's molecules and pockets, then score them."""
        data_path = "./data/lit_pcba/" + name + "/mols.lmdb"
        mol_dataset = self.load_mols_dataset(data_path, "atoms", "coordinates")
        bsz = 64
        mol_reps = []
        mol_names = []
        labels = []

        # encode the molecules
        mol_data = torch.utils.data.DataLoader(mol_dataset, batch_size=bsz, collate_fn=mol_dataset.collater)
        for _, sample in enumerate(tqdm(mol_data)):
            sample = unicore.utils.move_to_cuda(sample)
            dist = sample["net_input"]["mol_src_distance"]
            et = sample["net_input"]["mol_src_edge_type"]
            st = sample["net_input"]["mol_src_tokens"]
            mol_padding_mask = st.eq(model.mol_model.padding_idx)
            mol_x = model.mol_model.embed_tokens(st)

            # Gaussian-basis pair features projected to a per-head attention bias
            n_node = dist.size(-1)
            gbf_feature = model.mol_model.gbf(dist, et)
            gbf_result = model.mol_model.gbf_proj(gbf_feature)
            graph_attn_bias = gbf_result
            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
            mol_outputs = model.mol_model.encoder(
                mol_x, padding_mask=mol_padding_mask, attn_mask=graph_attn_bias
            )
            # CLS-token representation, projected and L2-normalized
            mol_encoder_rep = mol_outputs[0][:, 0, :]
            mol_emb = model.mol_project(mol_encoder_rep)
            mol_emb = mol_emb / mol_emb.norm(dim=1, keepdim=True)
            mol_emb = mol_emb.detach().cpu().numpy()
            mol_reps.append(mol_emb)
            mol_names.extend(sample["smi_name"])
            labels.extend(sample["target"].detach().cpu().numpy())
        mol_reps = np.concatenate(mol_reps, axis=0)
        labels = np.array(labels, dtype=np.int32)

        # encode the pockets
        data_path = "./data/lit_pcba/" + name + "/pockets.lmdb"
        pocket_dataset = self.load_pockets_dataset(data_path)
        pocket_data = torch.utils.data.DataLoader(pocket_dataset, batch_size=bsz, collate_fn=pocket_dataset.collater)
        pocket_reps = []

        for _, sample in enumerate(tqdm(pocket_data)):
            sample = unicore.utils.move_to_cuda(sample)
            dist = sample["net_input"]["pocket_src_distance"]
            et = sample["net_input"]["pocket_src_edge_type"]
            st = sample["net_input"]["pocket_src_tokens"]
            pocket_padding_mask = st.eq(model.pocket_model.padding_idx)
            pocket_x = model.pocket_model.embed_tokens(st)
            n_node = dist.size(-1)
            gbf_feature = model.pocket_model.gbf(dist, et)
            gbf_result = model.pocket_model.gbf_proj(gbf_feature)
            graph_attn_bias = gbf_result
            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
            pocket_outputs = model.pocket_model.encoder(
                pocket_x, padding_mask=pocket_padding_mask, attn_mask=graph_attn_bias
            )
            pocket_encoder_rep = pocket_outputs[0][:, 0, :]
            pocket_emb = model.pocket_project(pocket_encoder_rep)
            pocket_emb = pocket_emb / pocket_emb.norm(dim=1, keepdim=True)
            pocket_emb = pocket_emb.detach().cpu().numpy()
            pocket_reps.append(pocket_emb)
        pocket_reps = np.concatenate(pocket_reps, axis=0)

        # similarity between every pocket and every molecule; each molecule is
        # scored by its best-matching pocket
        res = pocket_reps @ mol_reps.T
        res_single = res.max(axis=0)
        auc, bedroc, ef_list, re_list = cal_metrics(labels, res_single, 80.5)

        return auc, bedroc, ef_list, re_list
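
    # Shape sketch of the attention-bias construction used in the encoder
    # passes above (a reading aid; K and `heads` denote the Gaussian-basis
    # kernel count and attention-head count, both model hyperparameters):
    #
    #   dist:            [bsz, n, n]          pairwise atom distances
    #   gbf(dist, et):   [bsz, n, n, K]       Gaussian basis features
    #   gbf_proj(...):   [bsz, n, n, heads]   per-head pair bias
    #   permute + view:  [bsz * heads, n, n]  attn_mask fed to the encoder
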
    def test_pcba(self, model, **kwargs):
        """Run virtual screening on every LIT-PCBA target and print summary statistics."""
        targets = os.listdir("./data/lit_pcba/")

        auc_list = []
        bedroc_list = []
        re_list = {
            "0.005": [],
            "0.01": [],
            "0.02": [],
            "0.05": [],
        }
        ef_list = {
            "0.005": [],
            "0.01": [],
            "0.02": [],
            "0.05": [],
        }
        for target in targets:
            auc, bedroc, ef, re = self.test_pcba_target(target, model)
            auc_list.append(auc)
            bedroc_list.append(bedroc)
            for key in ef:
                ef_list[key].append(ef[key])
            for key in re:
                re_list[key].append(re[key])
        print(auc_list)
        print(ef_list)
        print("auc 25%", np.percentile(auc_list, 25))
        print("auc 50%", np.percentile(auc_list, 50))
        print("auc 75%", np.percentile(auc_list, 75))
        print("auc mean", np.mean(auc_list))
        print("bedroc 25%", np.percentile(bedroc_list, 25))
        print("bedroc 50%", np.percentile(bedroc_list, 50))
        print("bedroc 75%", np.percentile(bedroc_list, 75))
        print("bedroc mean", np.mean(bedroc_list))
        for key in ef_list:
            print("ef", key, "25%", np.percentile(ef_list[key], 25))
            print("ef", key, "50%", np.percentile(ef_list[key], 50))
            print("ef", key, "75%", np.percentile(ef_list[key], 75))
            print("ef", key, "mean", np.mean(ef_list[key]))
        for key in re_list:
            print("re", key, "25%", np.percentile(re_list[key], 25))
            print("re", key, "50%", np.percentile(re_list[key], 50))
            print("re", key, "75%", np.percentile(re_list[key], 75))
            print("re", key, "mean", np.mean(re_list[key]))

        return

    def test_dude_target(self, target, model, **kwargs):
        """Encode one DUD-E target's molecules and pocket, then score them."""
        data_path = "./data/DUD-E/raw/all/" + target + "/mols.lmdb"
        mol_dataset = self.load_mols_dataset(data_path, "atoms", "coordinates")
        num_data = len(mol_dataset)
        bsz = 64
        print(num_data // bsz)
        mol_reps = []
        mol_names = []
        labels = []

        # encode the molecules
        mol_data = torch.utils.data.DataLoader(mol_dataset, batch_size=bsz, collate_fn=mol_dataset.collater)
        for _, sample in enumerate(tqdm(mol_data)):
            sample = unicore.utils.move_to_cuda(sample)
            dist = sample["net_input"]["mol_src_distance"]
            et = sample["net_input"]["mol_src_edge_type"]
            st = sample["net_input"]["mol_src_tokens"]
            mol_padding_mask = st.eq(model.mol_model.padding_idx)
            mol_x = model.mol_model.embed_tokens(st)
            n_node = dist.size(-1)
            gbf_feature = model.mol_model.gbf(dist, et)
            gbf_result = model.mol_model.gbf_proj(gbf_feature)
            graph_attn_bias = gbf_result
            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
            mol_outputs = model.mol_model.encoder(
                mol_x, padding_mask=mol_padding_mask, attn_mask=graph_attn_bias
            )
            mol_encoder_rep = mol_outputs[0][:, 0, :]
            mol_emb = model.mol_project(mol_encoder_rep)
            mol_emb = mol_emb / mol_emb.norm(dim=-1, keepdim=True)
            mol_emb = mol_emb.detach().cpu().numpy()
            mol_reps.append(mol_emb)
            mol_names.extend(sample["smi_name"])
            labels.extend(sample["target"].detach().cpu().numpy())
        mol_reps = np.concatenate(mol_reps, axis=0)
        labels = np.array(labels, dtype=np.int32)

        # encode the pocket
        data_path = "./data/DUD-E/raw/all/" + target + "/pocket.lmdb"
        pocket_dataset = self.load_pockets_dataset(data_path)
        pocket_data = torch.utils.data.DataLoader(pocket_dataset, batch_size=bsz, collate_fn=pocket_dataset.collater)
        pocket_reps = []

        for _, sample in enumerate(tqdm(pocket_data)):
            sample = unicore.utils.move_to_cuda(sample)
            dist = sample["net_input"]["pocket_src_distance"]
            et = sample["net_input"]["pocket_src_edge_type"]
            st = sample["net_input"]["pocket_src_tokens"]
            pocket_padding_mask = st.eq(model.pocket_model.padding_idx)
            pocket_x = model.pocket_model.embed_tokens(st)
            n_node = dist.size(-1)
            gbf_feature = model.pocket_model.gbf(dist, et)
            gbf_result = model.pocket_model.gbf_proj(gbf_feature)
            graph_attn_bias = gbf_result
            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
            pocket_outputs = model.pocket_model.encoder(
                pocket_x, padding_mask=pocket_padding_mask, attn_mask=graph_attn_bias
            )
            pocket_encoder_rep = pocket_outputs[0][:, 0, :]
            pocket_emb = model.pocket_project(pocket_encoder_rep)
            pocket_emb = pocket_emb / pocket_emb.norm(dim=-1, keepdim=True)
            pocket_emb = pocket_emb.detach().cpu().numpy()
            pocket_reps.append(pocket_emb)
        pocket_reps = np.concatenate(pocket_reps, axis=0)
        print(pocket_reps.shape)

        res = pocket_reps @ mol_reps.T
        res_single = res.max(axis=0)
        auc, bedroc, ef_list, re_list = cal_metrics(labels, res_single, 80.5)

        print(target)
        print(np.sum(labels), len(labels) - np.sum(labels))

        return auc, bedroc, ef_list, re_list, res_single, labels

    def test_dude(self, model, **kwargs):
        """Run virtual screening on every DUD-E target and print summary statistics."""
        targets = os.listdir("./data/DUD-E/raw/all/")
        auc_list = []
        bedroc_list = []
        res_list = []
        labels_list = []
        re_list = {
            "0.005": [],
            "0.01": [],
            "0.02": [],
            "0.05": [],
        }
        ef_list = {
            "0.005": [],
            "0.01": [],
            "0.02": [],
            "0.05": [],
        }
        for i, target in enumerate(targets):
            auc, bedroc, ef, re, res_single, labels = self.test_dude_target(target, model)
            auc_list.append(auc)
            bedroc_list.append(bedroc)
            for key in ef:
                ef_list[key].append(ef[key])
            for key in re_list:
                re_list[key].append(re[key])
            res_list.append(res_single)
            labels_list.append(labels)
        res = np.concatenate(res_list, axis=0)
        labels = np.concatenate(labels_list, axis=0)
        print("auc mean", np.mean(auc_list))
        print("bedroc mean", np.mean(bedroc_list))

        for key in ef_list:
            print("ef", key, "mean", np.mean(ef_list[key]))

        for key in re_list:
            print("re", key, "mean", np.mean(re_list[key]))

        return

    def encode_mols_once(self, model, data_path, emb_dir, atoms, coords, **kwargs):
        """Encode a molecule library once and cache the embeddings on disk."""
        # cache path is emb_dir/<lmdb file name>.pkl
        cache_path = os.path.join(emb_dir, data_path.split("/")[-1] + ".pkl")

        if os.path.exists(cache_path):
            with open(cache_path, "rb") as f:
                mol_reps, mol_names = pickle.load(f)
            return mol_reps, mol_names

        mol_dataset = self.load_retrieval_mols_dataset(data_path, atoms, coords)
        mol_reps = []
        mol_names = []
        bsz = 32
        mol_data = torch.utils.data.DataLoader(mol_dataset, batch_size=bsz, collate_fn=mol_dataset.collater)
        for _, sample in enumerate(tqdm(mol_data)):
            sample = unicore.utils.move_to_cuda(sample)
            dist = sample["net_input"]["mol_src_distance"]
            et = sample["net_input"]["mol_src_edge_type"]
            st = sample["net_input"]["mol_src_tokens"]
            mol_padding_mask = st.eq(model.mol_model.padding_idx)
            mol_x = model.mol_model.embed_tokens(st)
            n_node = dist.size(-1)
            gbf_feature = model.mol_model.gbf(dist, et)
            gbf_result = model.mol_model.gbf_proj(gbf_feature)
            graph_attn_bias = gbf_result
            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
            mol_outputs = model.mol_model.encoder(
                mol_x, padding_mask=mol_padding_mask, attn_mask=graph_attn_bias
            )
            mol_encoder_rep = mol_outputs[0][:, 0, :]
            mol_emb = model.mol_project(mol_encoder_rep)
            mol_emb = mol_emb / mol_emb.norm(dim=-1, keepdim=True)
            mol_emb = mol_emb.detach().cpu().numpy()
            mol_reps.append(mol_emb)
            mol_names.extend(sample["smi_name"])

        mol_reps = np.concatenate(mol_reps, axis=0)

        # cache the embeddings and names for subsequent calls
        with open(cache_path, "wb") as f:
            pickle.dump([mol_reps, mol_names], f)

        return mol_reps, mol_names
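
    # Note: the cache key is only the LMDB file name, so two libraries at
    # different paths with the same file name would collide in emb_dir.
    # Usage sketch (hypothetical paths):
    #
    #   reps, names = task.encode_mols_once(
    #       model, "./data/chembl/mols.lmdb", "./embs", "atoms", "coordinates"
    #   )
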
    def retrieve_mols(self, model, mol_path, pocket_path, emb_dir, k, **kwargs):
        """Rank a cached molecule library against query pockets and return the top k."""
        os.makedirs(emb_dir, exist_ok=True)
        mol_reps, mol_names = self.encode_mols_once(model, mol_path, emb_dir, "atoms", "coordinates")

        pocket_dataset = self.load_pockets_dataset(pocket_path)
        pocket_data = torch.utils.data.DataLoader(pocket_dataset, batch_size=16, collate_fn=pocket_dataset.collater)
        pocket_reps = []
        pocket_names = []
        for _, sample in enumerate(tqdm(pocket_data)):
            sample = unicore.utils.move_to_cuda(sample)
            dist = sample["net_input"]["pocket_src_distance"]
            et = sample["net_input"]["pocket_src_edge_type"]
            st = sample["net_input"]["pocket_src_tokens"]
            pocket_padding_mask = st.eq(model.pocket_model.padding_idx)
            pocket_x = model.pocket_model.embed_tokens(st)
            n_node = dist.size(-1)
            gbf_feature = model.pocket_model.gbf(dist, et)
            gbf_result = model.pocket_model.gbf_proj(gbf_feature)
            graph_attn_bias = gbf_result
            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
            pocket_outputs = model.pocket_model.encoder(
                pocket_x, padding_mask=pocket_padding_mask, attn_mask=graph_attn_bias
            )
            pocket_encoder_rep = pocket_outputs[0][:, 0, :]
            pocket_emb = model.pocket_project(pocket_encoder_rep)
            pocket_emb = pocket_emb / pocket_emb.norm(dim=-1, keepdim=True)
            pocket_emb = pocket_emb.detach().cpu().numpy()
            pocket_reps.append(pocket_emb)
            pocket_names.extend(sample["pocket_name"])
        pocket_reps = np.concatenate(pocket_reps, axis=0)

        # score each molecule by its best-matching pocket, then take the top k
        res = pocket_reps @ mol_reps.T
        res = res.max(axis=0)
        top_k = np.argsort(res)[::-1][:k]

        # return names and scores of the top-k molecules
        return [mol_names[i] for i in top_k], res[top_k]
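
    # Usage sketch (hypothetical driver code and paths; `model` is a trained
    # DrugCLIP model already moved to the GPU):
    #
    #   names, scores = task.retrieve_mols(
    #       model, "./data/library/mols.lmdb", "./data/query/pockets.lmdb",
    #       "./embs", k=100,
    #   )
    #   for name, score in zip(names, scores):
    #       print(name, score)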