|
a |
|
b/unimol/utils/docking_utils.py |
|
|
1 |
# Copyright (c) DP Technology, Inc. and its affiliates. |
|
|
2 |
# |
|
|
3 |
# This source code is licensed under the MIT license found in the |
|
|
4 |
# LICENSE file in the root directory of this source tree. |
|
|
5 |
|
|
|
6 |
import numpy as np |
|
|
7 |
from rdkit import Chem |
|
|
8 |
from rdkit.Chem import AllChem |
|
|
9 |
from rdkit import RDLogger |
|
|
10 |
|
|
|
11 |
RDLogger.DisableLog("rdApp.*") |
|
|
12 |
import warnings |
|
|
13 |
|
|
|
14 |
warnings.filterwarnings(action="ignore") |
|
|
15 |
from rdkit.Chem import rdMolTransforms |
|
|
16 |
import copy |
|
|
17 |
import lmdb |
|
|
18 |
import pickle |
|
|
19 |
import pandas as pd |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def get_torsions(m, removeHs=True):
    """Pick dihedral atom quadruples around the rotatable bonds of ``m``.

    A bond qualifies when it matches the SMARTS below: a single, non-ring
    bond whose two end atoms are neither degree-1 nor part of a triple bond.
    For every such bond only the first neighbouring bond found on the first
    matched atom is tried; if no acceptable partner exists on the other
    side, that bond contributes no torsion.

    Args:
        m: RDKit molecule to analyse.
        removeHs: When true, hydrogens are stripped (``Chem.RemoveHs``)
            before matching, so returned indices refer to the stripped
            molecule.

    Returns:
        A list of 4-tuples of atom indices. If the chosen outer atom on the
        second matched atom's side is in a ring, the tuple is stored in
        reversed order ``(idx4, idx3, idx2, idx1)``.
    """
    if removeHs:
        m = Chem.RemoveHs(m)

    rotatable_query = Chem.MolFromSmarts("[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]")

    torsions = []
    for begin, end in m.GetSubstructMatches(rotatable_query):
        central_idx = m.GetBondBetweenAtoms(begin, end).GetIdx()
        begin_atom = m.GetAtomWithIdx(begin)
        end_atom = m.GetAtomWithIdx(end)

        for begin_bond in begin_atom.GetBonds():
            if begin_bond.GetIdx() == central_idx:
                continue
            outer_begin = begin_bond.GetOtherAtomIdx(begin)

            for end_bond in end_atom.GetBonds():
                if end_bond.GetIdx() in (central_idx, begin_bond.GetIdx()):
                    continue
                outer_end = end_bond.GetOtherAtomIdx(end)

                # Both outer atoms being the same atom means a 3-membered ring.
                if outer_end == outer_begin:
                    continue

                # Torsions that would end on a hydrogen are not used.
                outer_atomic_nums = (
                    m.GetAtomWithIdx(outer_begin).GetAtomicNum(),
                    m.GetAtomWithIdx(outer_end).GetAtomicNum(),
                )
                if 1 in outer_atomic_nums:
                    continue

                if m.GetAtomWithIdx(outer_end).IsInRing():
                    torsions.append((outer_end, end, begin, outer_begin))
                else:
                    torsions.append((outer_begin, begin, end, outer_end))
                # One torsion per rotatable bond is enough.
                break

            # Only the first non-central bond on the ``begin`` side is ever
            # examined, whether or not it produced a torsion.
            break

    return torsions
|
|
59 |
|
|
|
60 |
|
|
|
61 |
def SetDihedral(conf, atom_idx, new_vale):
    """Set one dihedral angle of a conformer, in radians.

    Args:
        conf: Conformer handed straight to ``rdMolTransforms.SetDihedralRad``.
        atom_idx: Indexable holding the four atom indices of the dihedral
            at positions 0..3.
        new_vale: New angle value passed to ``SetDihedralRad``.
    """
    first, second, third, fourth = (atom_idx[pos] for pos in range(4))
    rdMolTransforms.SetDihedralRad(conf, first, second, third, fourth, new_vale)
|
|
65 |
|
|
|
66 |
|
|
|
67 |
def single_conf_gen_bonds(tgt_mol, num_confs=1000, seed=42, removeHs=True):
    """Embed conformers for a copy of ``tgt_mol`` and randomise their torsions.

    The input molecule is deep-copied, hydrogens are added, and
    ``AllChem.EmbedMultipleConfs`` is asked for ``num_confs`` conformers.
    Every torsion reported by ``get_torsions`` is then set to a random angle
    and each conformer is canonicalised.

    NOTE: the global NumPy RNG is reseeded with the conformer position on
    every iteration, so the angles are reproducible but the caller's
    ``np.random`` state is overwritten as a side effect.

    Args:
        tgt_mol: Source molecule; it is not modified.
        num_confs: Number of conformers requested from the embedder.
        seed: Random seed forwarded to ``EmbedMultipleConfs``.
        removeHs: Strip hydrogens again after embedding when true.

    Returns:
        The working copy of the molecule carrying the generated conformers.
    """
    work_mol = Chem.AddHs(copy.deepcopy(tgt_mol))
    conf_ids = AllChem.EmbedMultipleConfs(
        work_mol, numConfs=num_confs, randomSeed=seed, clearConfs=True
    )
    if removeHs:
        work_mol = Chem.RemoveHs(work_mol)

    torsions = get_torsions(work_mol, removeHs=removeHs)
    torsion_count = len(torsions)

    for conf_pos in range(len(conf_ids)):
        np.random.seed(conf_pos)
        # One angle per torsion, scaled from [0, 1) by the literal 2*pi
        # approximation used throughout this function.
        angles = 3.1415926 * 2 * np.random.rand(torsion_count)
        conformer = work_mol.GetConformers()[conf_pos]
        for torsion, angle in zip(torsions, angles):
            SetDihedral(conformer, torsion, angle)
        Chem.rdMolTransforms.CanonicalizeConformer(conformer)

    return work_mol
|
|
83 |
|
|
|
84 |
|
|
|
85 |
def load_lmdb_data(lmdb_path, key):
    """Read one field from every record of an LMDB file.

    Records are looked up by the ASCII decimal keys ``"0"``, ``"1"``, ...
    up to the number of keys stored in the database; each value is
    unpickled and ``data[key]`` is collected.

    Args:
        lmdb_path: Path of the LMDB file (opened with ``subdir=False``).
        key: Field to extract from each unpickled record.

    Returns:
        A list with one ``data[key]`` entry per record, in index order.
    """
    env = lmdb.open(
        lmdb_path,
        subdir=False,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False,
        max_readers=256,
    )
    collected = []
    try:
        # The context manager ends the read transaction even on errors.
        with env.begin() as txn:
            record_count = sum(1 for _ in txn.cursor().iternext(values=False))
            for idx in range(record_count):
                # ``txn.get`` yields None when the index key is absent, in
                # which case ``pickle.loads`` raises as it did before.
                payload = txn.get(f"{idx}".encode("ascii"))
                # NOTE(security): unpickling executes arbitrary code; only
                # open LMDB files that come from a trusted source.
                record = pickle.loads(payload)
                collected.append(record[key])
    finally:
        # Release the environment handle instead of leaking it per call.
        env.close()
    return collected
|
|
103 |
|
|
|
104 |
|
|
|
105 |
def docking_data_pre(raw_data_path, predict_path):
    """Load molecules and model predictions and flatten them per sample.

    Molecules come from the ``"mol_list"`` field of the LMDB file at
    ``raw_data_path`` (flattened, hydrogens removed). Predictions come from
    the pickle at ``predict_path``, iterated as a sequence of batch dicts.
    For each sample, padded positions are dropped using the masks
    ``atoms > 2`` and ``pocket_atoms > 2`` before conversion to float32
    NumPy arrays.

    NOTE(review): the batch entries are used with ``.size(0)``, boolean-mask
    indexing and ``.numpy()``, so they are presumably CPU torch tensors, and
    ids <= 2 are presumably special/padding tokens - confirm against the
    code that writes ``predict_path``.

    Args:
        raw_data_path: Path of the LMDB file holding ``"mol_list"``.
        predict_path: Path of the pickled prediction batches.

    Returns:
        A tuple ``(mol_list, smi_list, pocket_list, pocket_coords_list,
        distance_predict_list, holo_distance_predict_list, holo_coords_list,
        holo_center_coords_list)``. The holo centre entries are left as
        sliced from the batch (first three components), not converted to
        NumPy.
    """
    grouped_mols = load_lmdb_data(raw_data_path, "mol_list")
    mol_list = [Chem.RemoveHs(mol) for group in grouped_mols for mol in group]

    batches = pd.read_pickle(predict_path)

    smi_list = []
    pocket_list = []
    pocket_coords_list = []
    distance_predict_list = []
    holo_distance_predict_list = []
    holo_coords_list = []
    holo_center_coords_list = []

    for batch in batches:
        batch_size = batch["atoms"].size(0)
        for row in range(batch_size):
            smi_list.append(batch["smi_name"][row])
            pocket_list.append(batch["pocket_name"][row])

            # Masks selecting the non-special ligand / pocket tokens.
            ligand_mask = batch["atoms"][row] > 2
            pocket_mask = batch["pocket_atoms"][row] > 2

            cross_dist = batch["cross_distance_predict"][row]
            cross_dist = cross_dist[ligand_mask][:, pocket_mask]

            pocket_xyz = batch["pocket_coordinates"][row][pocket_mask, :]

            holo_dist = batch["holo_distance_predict"][row]
            holo_dist = holo_dist[ligand_mask][:, ligand_mask]

            holo_xyz = batch["holo_coordinates"][row][ligand_mask, :]
            holo_center = batch["holo_center_coordinates"][row][:3]

            pocket_coords_list.append(pocket_xyz.numpy().astype(np.float32))
            distance_predict_list.append(cross_dist.numpy().astype(np.float32))
            holo_distance_predict_list.append(
                holo_dist.numpy().astype(np.float32)
            )
            holo_coords_list.append(holo_xyz.numpy().astype(np.float32))
            holo_center_coords_list.append(holo_center)

    return (
        mol_list,
        smi_list,
        pocket_list,
        pocket_coords_list,
        distance_predict_list,
        holo_distance_predict_list,
        holo_coords_list,
        holo_center_coords_list,
    )
|
|
160 |
|
|
|
161 |
|
|
|
162 |
def ensemble_iterations(
    mol_list,
    smi_list,
    pocket_list,
    pocket_coords_list,
    distance_predict_list,
    holo_distance_predict_list,
    holo_coords_list,
    holo_center_coords_list,
    tta_times=10,
):
    """Yield one work item per consecutive group of ``tta_times`` samples.

    The inputs are parallel lists. They are consumed in chunks of
    ``tta_times`` entries; trailing entries that do not fill a whole chunk
    are ignored. For every chunk, conformers are generated from a copy of
    the chunk's first molecule via ``single_conf_gen_bonds`` and their
    coordinates are collected as float32 arrays.

    Yields:
        A list ``[initial_coords_list, mol, smi, pocket, pocket_coords,
        distance_predict_tta, holo_distance_predict_tta, holo_coords,
        holo_center_coords]``. The two ``*_tta`` entries are slices covering
        the whole chunk; every other per-sample entry is taken from the
        chunk's first position.
    """
    chunk_count = len(mol_list) // tta_times
    for chunk in range(chunk_count):
        first = chunk * tta_times
        stop = (chunk + 1) * tta_times

        distance_predict_tta = distance_predict_list[first:stop]
        holo_distance_predict_tta = holo_distance_predict_list[first:stop]

        mol = copy.deepcopy(mol_list[first])
        conf_mol = single_conf_gen_bonds(
            mol, num_confs=tta_times, seed=42, removeHs=True
        )
        initial_coords_list = [
            conformer.GetPositions().astype(np.float32)
            for conformer in conf_mol.GetConformers()
        ]

        yield [
            initial_coords_list,
            mol,
            smi_list[first],
            pocket_list[first],
            pocket_coords_list[first],
            distance_predict_tta,
            holo_distance_predict_tta,
            holo_coords_list[first],
            holo_center_coords_list[first],
        ]
|
|
200 |
|
|
|
201 |
|
|
|
202 |
def rmsd_func(holo_coords, predict_coords):
    """Return the RMSD between reference and predicted coordinates.

    The squared differences are summed over all elements and divided by
    ``holo_coords.shape[0]`` before the square root is taken.

    Args:
        holo_coords: Reference coordinate array.
        predict_coords: Predicted coordinate array broadcastable against
            ``holo_coords``, or a NaN scalar marking a missing prediction.

    Returns:
        The RMSD, or the sentinel ``1000.0`` when ``predict_coords`` is NaN.
    """
    # An identity test against ``np.nan`` only recognises that one object;
    # check the value so every NaN float counts as "no prediction".
    if isinstance(predict_coords, (float, np.floating)) and np.isnan(
        predict_coords
    ):
        return 1000.0

    row_count = holo_coords.shape[0]
    squared_error = np.sum((predict_coords - holo_coords) ** 2)
    return np.sqrt(squared_error / row_count)
|
|
208 |
|
|
|
209 |
|
|
|
210 |
def print_results(rmsd_results):
    """Print success rates at fixed RMSD cut-offs and the mean RMSD.

    Args:
        rmsd_results: RMSD values, as a NumPy array or any sequence that
            ``np.asarray`` can convert (a plain list works too).
    """
    # Coerce once so the element-wise comparisons also work for lists.
    values = np.asarray(rmsd_results)
    for cutoff in (1.0, 1.5, 2.0, 3.0, 5.0):
        print(f"RMSD < {cutoff:.1f} : ", np.mean(values < cutoff))
    print("avg RMSD : ", np.mean(values))