Diff of /unimol/test.py [000000] .. [b40915]

Switch to unified view

a b/unimol/test.py
1
#!/usr/bin/env python3 -u
2
# Copyright (c) DP Techonology, Inc. and its affiliates.
3
#
4
# This source code is licensed under the MIT license found in the
5
# LICENSE file in the root directory of this source tree.
6
7
import logging
8
import os
9
import sys
10
import pickle
11
import torch
12
from unicore import checkpoint_utils, distributed_utils, options, utils
13
from unicore.logging import progress_bar
14
from unicore import tasks
15
import numpy as np
16
from tqdm import tqdm
17
import unicore
18
19
logging.basicConfig(
20
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
21
    datefmt="%Y-%m-%d %H:%M:%S",
22
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
23
    stream=sys.stdout,
24
)
25
logger = logging.getLogger("unimol.inference")
26
27
28
#from skchem.metrics import bedroc_score
29
from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment
30
from sklearn.metrics import roc_curve
31
32
33
34
def main(args):
35
36
    use_fp16 = args.fp16
37
    use_cuda = torch.cuda.is_available() and not args.cpu
38
39
    if use_cuda:
40
        torch.cuda.set_device(args.device_id)
41
42
43
    # Load model
44
    logger.info("loading model(s) from {}".format(args.path))
45
    state = checkpoint_utils.load_checkpoint_to_cpu(args.path)
46
    task = tasks.setup_task(args)
47
    model = task.build_model(args)
48
    model.load_state_dict(state["model"], strict=False)
49
50
    # Move models to GPU
51
    if use_fp16:
52
        model.half()
53
    if use_cuda:
54
        model.cuda()
55
56
    # Print args
57
    logger.info(args)
58
59
60
    model.eval()
61
    if args.test_task=="DUDE":
62
        task.test_dude(model)
63
64
    elif args.test_task=="PCBA":
65
        task.test_pcba(model)
66
67
68
def cli_main():
69
    # add args
70
    
71
72
    parser = options.get_validation_parser()
73
    parser.add_argument("--test-task", type=str, default="DUDE", help="test task", choices=["DUDE", "PCBA"])
74
    options.add_model_args(parser)
75
    args = options.parse_args_and_arch(parser)
76
77
    distributed_utils.call_main(args, main)
78
79
80
if __name__ == "__main__":
81
    cli_main()