[8956d4]: / unimol / retrieval.py

Download this file

81 lines (57 with data), 2.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3 -u
# Copyright (c) DP Technology, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import sys
import pickle
import torch
from unicore import checkpoint_utils, distributed_utils, options, utils
from unicore.logging import progress_bar
from unicore import tasks
import numpy as np
from tqdm import tqdm
import unicore
# Send log records to stdout. The threshold is read from the LOGLEVEL
# environment variable (upper-cased) and falls back to INFO when it is unset.
logging.basicConfig(
    stream=sys.stdout,
    level=os.getenv("LOGLEVEL", "INFO").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)

# Module-wide logger used by main().
logger = logging.getLogger("unimol.inference")
#from skchem.metrics import bedroc_score
from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment
from sklearn.metrics import roc_curve
def main(args):
    """Load a checkpoint into the task's model and run molecule retrieval.

    Args:
        args: Parsed command-line namespace. This function reads ``fp16``,
            ``cpu``, ``device_id``, ``path``, ``mol_path``, ``pocket_path``
            and ``emb_dir`` from it, and also hands the whole namespace to
            ``tasks.setup_task`` and ``task.build_model``.

    Returns:
        None. The pair returned by ``task.retrieve_mols`` is unpacked but not
        used any further inside this function.
    """
    half_precision = args.fp16
    on_gpu = torch.cuda.is_available() and not args.cpu
    if on_gpu:
        torch.cuda.set_device(args.device_id)

    # Read the checkpoint, then let the configured task build the model.
    checkpoint_path = args.path
    logger.info("loading model(s) from {}".format(checkpoint_path))
    checkpoint = checkpoint_utils.load_checkpoint_to_cpu(checkpoint_path)
    task = tasks.setup_task(args)
    model = task.build_model(args)
    # NOTE(review): strict=False presumably follows torch.nn.Module semantics
    # (missing/unexpected keys are tolerated) -- confirm against the model class.
    model.load_state_dict(checkpoint["model"], strict=False)

    # Optional half-precision cast first, then the move to the GPU.
    if half_precision:
        model.half()
    if on_gpu:
        model.cuda()

    # Record the full configuration in the log before running.
    logger.info(args)
    model.eval()

    # NOTE(review): the trailing literal 10000 is forwarded positionally; its
    # meaning is defined by task.retrieve_mols (possibly a top-k) -- TODO confirm.
    retrieval_inputs = (args.mol_path, args.pocket_path, args.emb_dir)
    mol_names, mol_scores = task.retrieve_mols(model, *retrieval_inputs, 10000)
def cli_main():
    """Build the command-line parser, parse the arguments and launch ``main``.

    Starts from ``options.get_validation_parser()``, registers the three
    retrieval-specific string options (each defaulting to ``""``), adds the
    model arguments, and finally dispatches through
    ``distributed_utils.call_main``.
    """
    parser = options.get_validation_parser()

    # Retrieval-specific options, registered in this order: (flag, help text).
    extra_options = (
        ("--mol-path", "path for mol data"),
        ("--pocket-path", "path for pocket data"),
        ("--emb-dir", "path for saved embedding data"),
    )
    for flag, description in extra_options:
        parser.add_argument(flag, type=str, default="", help=description)

    options.add_model_args(parser)
    parsed_args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(parsed_args, main)
# Script entry point: only run the CLI when this file is executed directly.
if __name__ == "__main__":
    cli_main()