|
a |
|
b/unimol/utils/coordinate_model.py |
|
|
1 |
# Copyright (c) DP Technology, Inc. and its affiliates. |
|
|
2 |
# |
|
|
3 |
# This source code is licensed under the MIT license found in the |
|
|
4 |
# LICENSE file in the root directory of this source tree. |
|
|
5 |
|
|
|
6 |
import copy |
|
|
7 |
import torch |
|
|
8 |
import pandas as pd |
|
|
9 |
from rdkit import Chem |
|
|
10 |
import pickle |
|
|
11 |
import argparse |
|
|
12 |
from docking_utils import rmsd_func |
|
|
13 |
import warnings |
|
|
14 |
|
|
|
15 |
# Install a blanket "ignore" filter so no warning is shown once this module is loaded.
warnings.filterwarnings("ignore")
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def single_SF_loss(
    predict_coords,
    pocket_coords,
    distance_predict,
    holo_distance_predict,
    dist_threshold=4.5,
):
    """Score a ligand pose against predicted distance matrices.

    The loss is the weighted sum of two mean-squared-error terms:

    * a cross term comparing the distances between ``predict_coords`` and
      ``pocket_coords`` with ``distance_predict``, restricted to the entries
      whose predicted distance is strictly below ``dist_threshold``;
    * a self term comparing the pairwise distances within ``predict_coords``
      with ``holo_distance_predict``.

    Args:
        predict_coords: Tensor of ligand coordinates being optimized
            (presumably ``(n_ligand_atoms, 3)`` -- confirm against caller).
        pocket_coords: Tensor of pocket coordinates
            (presumably ``(n_pocket_atoms, 3)`` -- confirm against caller).
        distance_predict: Predicted ligand-to-pocket distances; it is used to
            mask the computed cross-distance matrix, so it must match that
            matrix's shape.
        holo_distance_predict: Predicted ligand-to-ligand distances.
        dist_threshold: Cut-off applied to ``distance_predict`` when selecting
            the entries that contribute to the cross term.

    Returns:
        Scalar tensor equal to ``cross_term * 1.0 + self_term * 5.0``.
    """
    dist = torch.norm(predict_coords.unsqueeze(1) - pocket_coords.unsqueeze(0), dim=-1)
    holo_dist = torch.norm(
        predict_coords.unsqueeze(1) - predict_coords.unsqueeze(0), dim=-1
    )
    distance_mask = distance_predict < dist_threshold
    if distance_mask.any():
        cross_dist_score = (
            (dist[distance_mask] - distance_predict[distance_mask]) ** 2
        ).mean()
    else:
        # No predicted contact under the threshold: the mean of an empty
        # selection would be NaN and poison the optimizer, so the cross term
        # contributes nothing instead.
        cross_dist_score = dist.new_zeros(())
    dist_score = ((holo_dist - holo_distance_predict) ** 2).mean()
    loss = cross_dist_score * 1.0 + dist_score * 5.0
    return loss
|
|
36 |
|
|
|
37 |
|
|
|
38 |
def scoring(
    predict_coords,
    pocket_coords,
    distance_predict,
    holo_distance_predict,
    dist_threshold=4.5,
):
    """Report the two unweighted error terms for a pose, outside autograd.

    ``predict_coords`` is detached first, so no gradient is tracked through
    this evaluation.

    Args:
        predict_coords: Tensor of ligand coordinates to evaluate.
        pocket_coords: Tensor of pocket coordinates.
        distance_predict: Predicted ligand-to-pocket distances.
        holo_distance_predict: Predicted ligand-to-ligand distances.
        dist_threshold: Only entries of ``distance_predict`` strictly below
            this value enter the cross term.

    Returns:
        ``(cross_dist_score, dist_score)`` as NumPy values: the mean squared
        error of the masked ligand-to-pocket distances and the mean squared
        error of the ligand-to-ligand distances.
    """
    ligand = predict_coords.detach()
    ligand_as_rows = ligand.unsqueeze(1)

    # Pairwise Euclidean distances: ligand vs. pocket and ligand vs. itself.
    cross_dist = torch.norm(ligand_as_rows - pocket_coords.unsqueeze(0), dim=-1)
    self_dist = torch.norm(ligand_as_rows - ligand.unsqueeze(0), dim=-1)

    within_cutoff = distance_predict < dist_threshold
    cross_err = cross_dist[within_cutoff] - distance_predict[within_cutoff]
    self_err = self_dist - holo_distance_predict

    cross_dist_score = (cross_err ** 2).mean()
    dist_score = (self_err ** 2).mean()
    return cross_dist_score.numpy(), dist_score.numpy()
|
|
56 |
|
|
|
57 |
|
|
|
58 |
def dock_with_gradient(
    coords,
    pocket_coords,
    distance_predict_tta,
    holo_distance_predict_tta,
    loss_func=single_SF_loss,
    holo_coords=None,
    iterations=20000,
    early_stoping=5,
):
    """Optimize one starting pose against every prediction pair and keep the best.

    Each ``(distance_predict, holo_distance_predict)`` pair drawn in lockstep
    from the two ``*_tta`` sequences is passed to
    ``single_dock_with_gradient`` together with a fresh copy of ``coords``.
    The result with the lowest loss wins.

    Args:
        coords: Starting ligand coordinates (a NumPy array; it is converted
            with ``torch.from_numpy`` downstream).
        pocket_coords: Pocket coordinates, forwarded unchanged.
        distance_predict_tta: Sequence of predicted ligand-to-pocket distances.
        holo_distance_predict_tta: Sequence of predicted ligand-to-ligand
            distances, paired element-wise with ``distance_predict_tta``.
        loss_func: Loss callable forwarded to the single-run optimizer.
        holo_coords: Forwarded unchanged to the single-run optimizer.
        iterations: Maximum optimizer steps per run.
        early_stoping: Patience forwarded to the single-run optimizer.

    Returns:
        ``(best_coords, best_loss, best_meta_info)``. When no run produced a
        comparable loss (empty input, or only NaN losses) the original
        ``coords``, ``float("inf")`` and ``None`` are returned.
    """
    # Start from +inf rather than a finite magic number so that any finite
    # loss, however large, is preferred over "no result at all".
    bst_loss, bst_coords, bst_meta_info = float("inf"), coords, None
    for distance_predict, holo_distance_predict in zip(
        distance_predict_tta, holo_distance_predict_tta
    ):
        # The optimizer works on a tensor created with torch.from_numpy, which
        # can alias the array it is given, so every run gets its own copy.
        new_coords = copy.deepcopy(coords)
        _coords, _loss, _meta_info = single_dock_with_gradient(
            new_coords,
            pocket_coords,
            distance_predict,
            holo_distance_predict,
            loss_func=loss_func,
            holo_coords=holo_coords,
            iterations=iterations,
            early_stoping=early_stoping,
        )
        if _loss < bst_loss:
            bst_coords = _coords
            bst_loss = _loss
            bst_meta_info = _meta_info
    return bst_coords, bst_loss, bst_meta_info
|
|
88 |
|
|
|
89 |
|
|
|
90 |
def single_dock_with_gradient(
    coords,
    pocket_coords,
    distance_predict,
    holo_distance_predict,
    loss_func=single_SF_loss,
    holo_coords=None,
    iterations=20000,
    early_stoping=5,
):
    """Refine ligand coordinates with L-BFGS against one pair of predictions.

    All array inputs are converted to float tensors, ``coords`` is made
    trainable, and ``torch.optim.LBFGS`` (``lr=1.0``) is stepped up to
    ``iterations`` times. The loop stops early once the step loss has failed
    to improve on the best value seen for more than ``early_stoping``
    consecutive steps.

    Args:
        coords: Starting ligand coordinates as a NumPy array.
        pocket_coords: Pocket coordinates as a NumPy array.
        distance_predict: Predicted ligand-to-pocket distances (NumPy array).
        holo_distance_predict: Predicted ligand-to-ligand distances (NumPy array).
        loss_func: Callable ``(coords, pocket_coords, distance_predict,
            holo_distance_predict) -> scalar tensor`` to minimize.
        holo_coords: Optional NumPy array; converted to a tensor when given
            but not otherwise used inside this function.
        iterations: Maximum number of optimizer steps.
        early_stoping: Number of non-improving steps tolerated before stopping.

    Returns:
        ``(coords, loss, meta_info)``: the final coordinates as a NumPy array,
        the loss returned by the last optimizer step as a NumPy value, and the
        tuple produced by ``scoring`` for the final coordinates.
    """
    coords, pocket_coords, distance_predict, holo_distance_predict = (
        torch.from_numpy(array).float()
        for array in (coords, pocket_coords, distance_predict, holo_distance_predict)
    )
    if holo_coords is not None:
        holo_coords = torch.from_numpy(holo_coords).float()

    coords.requires_grad_(True)
    optimizer = torch.optim.LBFGS([coords], lr=1.0)

    def evaluate():
        # LBFGS may call this several times per step to re-evaluate the loss.
        optimizer.zero_grad()
        value = loss_func(
            coords, pocket_coords, distance_predict, holo_distance_predict
        )
        value.backward()
        return value

    lowest_loss = 10000.0
    stalled_steps = 0
    for _ in range(iterations):
        loss = optimizer.step(evaluate)
        step_loss = loss.item()
        if step_loss < lowest_loss:
            lowest_loss = step_loss
            stalled_steps = 0
            continue
        stalled_steps += 1
        if stalled_steps > early_stoping:
            break

    meta_info = scoring(coords, pocket_coords, distance_predict, holo_distance_predict)
    return coords.detach().numpy(), loss.detach().numpy(), meta_info
|
|
132 |
|
|
|
133 |
|
|
|
134 |
def set_coord(mol, coords):
    """Copy ``coords`` row by row into conformer 0 of ``mol`` and return ``mol``.

    Row ``i`` of ``coords`` becomes the position of atom ``i``; each row is
    converted to a plain list before it is handed to ``SetAtomPosition``.
    """
    for atom_idx, position in zip(range(coords.shape[0]), coords):
        mol.GetConformer(0).SetAtomPosition(atom_idx, position.tolist())
    return mol
|
|
138 |
|
|
|
139 |
|
|
|
140 |
def add_coord(mol, xyz):
    """Translate every atom of conformer 0 of ``mol`` by ``xyz`` and return ``mol``.

    ``xyz`` must unpack into exactly three components, which are added to the
    x, y and z coordinates of each atom position respectively.
    """
    dx, dy, dz = xyz
    conformer = mol.GetConformer(0)
    # Positions are read once up front; the conformer is then rewritten atom by atom.
    for atom_idx, row in enumerate(conformer.GetPositions()):
        shifted = Chem.rdGeometry.Point3D(row[0] + dx, row[1] + dy, row[2] + dz)
        conformer.SetAtomPosition(atom_idx, shifted)
    return mol
|
|
152 |
|
|
|
153 |
|
|
|
154 |
def single_docking(input_path, output_path, output_ligand_path):
    """Dock one ligand from a pickled prediction file and write the results.

    The input file must unpickle to a 9-item sequence, in this order:
    initial coordinate samples, molecule, SMILES, pocket name, pocket
    coordinates, predicted ligand-to-pocket distances, predicted
    ligand-to-ligand distances, reference (holo) coordinates, and the holo
    center offset. Every initial sample is optimized with
    ``dock_with_gradient`` and the pose with the lowest loss is kept.

    Args:
        input_path: Path of the pickled input.
        output_path: If not ``None``, the best coordinates, holo coordinates,
            best loss, SMILES, pocket name and pocket coordinates are pickled
            to this path as a list.
        output_ligand_path: If not ``None``, the posed molecule is shifted by
            the holo center offset and written with ``Chem.MolToMolFile``.

    Returns:
        ``True`` on completion.

    Raises:
        RuntimeError: If no initial sample produced a usable pose.
    """
    # NOTE(security): read_pickle unpickles arbitrary objects -- only feed
    # this function files from a trusted source.
    content = pd.read_pickle(input_path)
    (
        init_coords_tta,
        mol,
        smi,
        pocket,
        pocket_coords,
        distance_predict_tta,
        holo_distance_predict_tta,
        holo_coords,
        holo_cener_coords,
    ) = content

    # +inf instead of a finite magic number: any finite loss must be able to
    # win, otherwise a poorly converged ligand ends up with no pose at all.
    bst_predict_coords, bst_loss, bst_meta_info = None, float("inf"), None
    for init_coords in init_coords_tta:
        predict_coords, loss, meta_info = dock_with_gradient(
            init_coords,
            pocket_coords,
            distance_predict_tta,
            holo_distance_predict_tta,
            holo_coords=holo_coords,
            loss_func=single_SF_loss,
        )
        # A missing meta_info means the optimizer returned its fallback
        # rather than a scored pose, so it cannot be reported below.
        if meta_info is not None and loss < bst_loss:
            bst_loss = loss
            bst_predict_coords = predict_coords
            bst_meta_info = meta_info

    if bst_predict_coords is None:
        raise RuntimeError(f"no valid docking pose found for {pocket}-{smi}")

    _rmsd = round(rmsd_func(holo_coords, bst_predict_coords), 4)
    _cross_score = round(float(bst_meta_info[0]), 4)
    _self_score = round(float(bst_meta_info[1]), 4)
    print(f"{pocket}-{smi}-RMSD:{_rmsd}-{_cross_score}-{_self_score}")
    mol = Chem.RemoveHs(mol)
    mol = set_coord(mol, bst_predict_coords)

    if output_path is not None:
        with open(output_path, "wb") as f:
            pickle.dump(
                [bst_predict_coords, holo_coords, bst_loss, smi, pocket, pocket_coords],
                f,
            )
    if output_ligand_path is not None:
        # holo_cener_coords must expose .numpy(); the offset is added back
        # onto the pose before the molecule is written out.
        mol = add_coord(mol, holo_cener_coords.numpy())
        Chem.MolToMolFile(mol, output_ligand_path)

    return True
|
|
202 |
|
|
|
203 |
|
|
|
204 |
if __name__ == "__main__":
    # Run torch single-threaded with a fixed seed before doing any work.
    torch.set_num_threads(1)
    torch.manual_seed(0)

    cli = argparse.ArgumentParser(description="Docking with gradient")
    cli.add_argument("--input", type=str, help="input file.")
    # Both outputs are optional; passing None downstream skips that output.
    for flag, help_text in (
        ("--output", "output path."),
        ("--output-ligand", "output ligand sdf path."),
    ):
        cli.add_argument(flag, type=str, default=None, help=help_text)
    options = cli.parse_args()

    single_docking(options.input, options.output, options.output_ligand)