Diff of /unimol/utils/docking.py [000000] .. [b40915]

Switch to unified view

a b/unimol/utils/docking.py
1
# Copyright (c) DP Techonology, Inc. and its affiliates.
2
#
3
# This source code is licensed under the MIT license found in the
4
# LICENSE file in the root directory of this source tree.
5
6
import os
7
import numpy as np
8
import pandas as pd
9
from multiprocessing import Pool
10
from tqdm import tqdm
11
import glob
12
import argparse
13
from docking_utils import (
14
    docking_data_pre,
15
    ensemble_iterations,
16
    print_results,
17
    rmsd_func,
18
)
19
import warnings
20
21
warnings.filterwarnings(action="ignore")
22
23
24
def result_log(dir_path):
25
    ### result logging ###
26
    output_dir = os.path.join(dir_path, "cache")
27
    rmsd_results = []
28
    for path in glob.glob(os.path.join(output_dir, "*.docking.pkl")):
29
        (
30
            bst_predict_coords,
31
            holo_coords,
32
            bst_loss,
33
            smi,
34
            pocket,
35
            pocket_coords,
36
        ) = pd.read_pickle(path)
37
        rmsd = rmsd_func(holo_coords, bst_predict_coords)
38
        rmsd_results.append(rmsd)
39
    rmsd_results = np.array(rmsd_results)
40
    print_results(rmsd_results)
41
42
43
if __name__ == "__main__":
44
    parser = argparse.ArgumentParser(description="docking")
45
    parser.add_argument(
46
        "--reference-file",
47
        type=str,
48
        default="./protein_ligand_binding_pose_prediction/test.lmdb",
49
        help="Location of the reference set",
50
    )
51
    parser.add_argument("--nthreads", type=int, default=40, help="num of threads")
52
    parser.add_argument(
53
        "--predict-file",
54
        type=str,
55
        default="./infer_pose/save_pose_test.out.pkl",
56
        help="Location of the prediction file",
57
    )
58
    parser.add_argument(
59
        "--output-path",
60
        type=str,
61
        default="./protein_ligand_binding_pose_prediction",
62
        help="Location of the docking output path",
63
    )
64
    args = parser.parse_args()
65
66
    raw_data_path, predict_path, dir_path, nthreads = (
67
        args.reference_file,
68
        args.predict_file,
69
        args.output_path,
70
        args.nthreads,
71
    )
72
    tta_times = 10
73
    (
74
        mol_list,
75
        smi_list,
76
        pocket_list,
77
        pocket_coords_list,
78
        distance_predict_list,
79
        holo_distance_predict_list,
80
        holo_coords_list,
81
        holo_center_coords_list,
82
    ) = docking_data_pre(raw_data_path, predict_path)
83
    iterations = ensemble_iterations(
84
        mol_list,
85
        smi_list,
86
        pocket_list,
87
        pocket_coords_list,
88
        distance_predict_list,
89
        holo_distance_predict_list,
90
        holo_coords_list,
91
        holo_center_coords_list,
92
        tta_times=tta_times,
93
    )
94
    sz = len(mol_list) // tta_times
95
    new_pocket_list = pocket_list[::tta_times]
96
    output_dir = os.path.join(dir_path, "cache")
97
    os.makedirs(output_dir, exist_ok=True)
98
99
    def dump(content):
100
        pocket = content[3]
101
        output_name = os.path.join(output_dir, "{}.pkl".format(pocket))
102
        try:
103
            os.remove(output_name)
104
        except:
105
            pass
106
        pd.to_pickle(content, output_name)
107
        return True
108
109
    with Pool(nthreads) as pool:
110
        for inner_output in tqdm(pool.imap(dump, iterations), total=sz):
111
            if not inner_output:
112
                print("fail to dump")
113
114
    def single_docking(pocket_name):
115
        input_name = os.path.join(output_dir, "{}.pkl".format(pocket_name))
116
        output_name = os.path.join(output_dir, "{}.docking.pkl".format(pocket_name))
117
        output_ligand_name = os.path.join(
118
            output_dir, "{}.ligand.sdf".format(pocket_name)
119
        )
120
        try:
121
            os.remove(output_name)
122
        except:
123
            pass
124
        try:
125
            os.remove(output_ligand_name)
126
        except:
127
            pass
128
        cmd = "python ./unimol/utils/coordinate_model.py --input {} --output {} --output-ligand {}".format(
129
            input_name, output_name, output_ligand_name
130
        )
131
        os.system(cmd)
132
        return True
133
134
    with Pool(nthreads) as pool:
135
        for inner_output in tqdm(
136
            pool.imap(single_docking, new_pocket_list), total=len(new_pocket_list)
137
        ):
138
            if not inner_output:
139
                print("fail to docking")
140
141
    result_log(args.output_path)