--- a
+++ b/unimol/test.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) DP Technology, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import sys
+import pickle
+import torch
+from unicore import checkpoint_utils, distributed_utils, options, utils
+from unicore.logging import progress_bar
+from unicore import tasks
+import numpy as np
+from tqdm import tqdm
+import unicore
+
+# Send log records to stdout; the LOGLEVEL env var (default INFO) sets the threshold.
+logging.basicConfig(
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S", stream=sys.stdout,
+)
+logger = logging.getLogger("unimol.inference")
+
+
+#from skchem.metrics import bedroc_score
+from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment
+from sklearn.metrics import roc_curve
+
+
+
+def main(args):
+    """Load the checkpoint at ``args.path`` and run the ``args.test_task`` benchmark.
+
+    "DUDE" calls ``task.test_dude(model)``; "PCBA" calls ``task.test_pcba(model)``.
+    """
+    use_fp16 = args.fp16
+    use_cuda = torch.cuda.is_available() and not args.cpu
+    if use_cuda:
+        torch.cuda.set_device(args.device_id)
+
+    # Load model. strict=False tolerates key mismatches; log them instead of hiding them.
+    logger.info("loading model(s) from {}".format(args.path))
+    state = checkpoint_utils.load_checkpoint_to_cpu(args.path)
+    task = tasks.setup_task(args)
+    model = task.build_model(args)
+    incompatible = model.load_state_dict(state["model"], strict=False)
+    for kind in ("missing_keys", "unexpected_keys"):
+        keys = getattr(incompatible, kind, None)
+        if keys:
+            logger.warning("%s when loading checkpoint: %s", kind, keys)
+
+    if use_fp16:
+        model.half()
+    if use_cuda:
+        model.cuda()  # moves to the device chosen by set_device above
+    logger.info(args)
+    model.eval()
+    with torch.no_grad():  # inference only; no autograd graphs needed
+        if args.test_task == "DUDE":
+            task.test_dude(model)
+        elif args.test_task == "PCBA":
+            task.test_pcba(model)
+
+
+def cli_main():
+    """Build the CLI parser, parse arguments, and run ``main`` via ``call_main``."""
+    parser = options.get_validation_parser()
+    parser.add_argument(
+        "--test-task", type=str, default="DUDE", choices=["DUDE", "PCBA"],
+        help="test task",
+    )
+    options.add_model_args(parser)
+    args = options.parse_args_and_arch(parser)
+    distributed_utils.call_main(args, main)
+
+
+if __name__ == "__main__":
+    cli_main()  # script entry point: parse CLI arguments, then run main()