a b/unimol/utils/coordinate_model.py
1
# Copyright (c) DP Technology, Inc. and its affiliates.
2
#
3
# This source code is licensed under the MIT license found in the
4
# LICENSE file in the root directory of this source tree.
5
6
import copy
7
import torch
8
import pandas as pd
9
from rdkit import Chem
10
import pickle
11
import argparse
12
from docking_utils import rmsd_func
13
import warnings
14
15
# Silence every Python warning for the whole process as soon as this module is
# imported (presumably to hide noisy third-party warnings during batch docking;
# note it also hides warnings that might indicate real problems).
warnings.filterwarnings(action="ignore")
16
17
18
def single_SF_loss(
    predict_coords,
    pocket_coords,
    distance_predict,
    holo_distance_predict,
    dist_threshold=4.5,
    cross_weight=1.0,
    self_weight=5.0,
):
    """Differentiable docking loss for a single ligand pose.

    Two mean-squared-error terms are combined:

    * cross term: actual ligand-pocket distances vs. ``distance_predict``,
      restricted to pairs whose *predicted* distance is below
      ``dist_threshold``;
    * self term: actual ligand-ligand distances vs. ``holo_distance_predict``.

    Args:
        predict_coords: ligand coordinates tensor; the broadcasting below
            assumes shape (N, D) with coordinates on the last axis. Gradients
            flow through this tensor.
        pocket_coords: pocket coordinates tensor, assumed (M, D).
        distance_predict: predicted ligand-pocket distances, matching the
            (N, M) distance matrix computed below.
        holo_distance_predict: predicted ligand-ligand distances, (N, N).
        dist_threshold: pairs with predicted distance strictly below this
            value enter the cross term.
        cross_weight: weight of the cross term (default 1.0, as before).
        self_weight: weight of the self term (default 5.0, as before).

    Returns:
        Scalar tensor ``cross_weight * cross + self_weight * self``. When no
        pair passes the threshold the cross term is 0 rather than the NaN
        produced by averaging an empty selection.
    """
    # Actual ligand-pocket distances, (N, M).
    dist = torch.norm(predict_coords.unsqueeze(1) - pocket_coords.unsqueeze(0), dim=-1)
    # Actual ligand-ligand distances, (N, N).
    holo_dist = torch.norm(
        predict_coords.unsqueeze(1) - predict_coords.unsqueeze(0), dim=-1
    )
    distance_mask = distance_predict < dist_threshold
    if distance_mask.any():
        cross_dist_score = (
            (dist[distance_mask] - distance_predict[distance_mask]) ** 2
        ).mean()
    else:
        # mean() over an empty tensor is NaN and would turn every gradient
        # (and therefore the optimised coordinates) into NaN.
        cross_dist_score = dist.new_zeros(())
    dist_score = ((holo_dist - holo_distance_predict) ** 2).mean()
    return cross_dist_score * cross_weight + dist_score * self_weight
36
37
38
def scoring(
    predict_coords,
    pocket_coords,
    distance_predict,
    holo_distance_predict,
    dist_threshold=4.5,
):
    """Report the two unweighted components of the docking loss for a pose.

    Mirrors the terms of ``single_SF_loss`` but detaches ``predict_coords``
    first, so no autograd graph is built and the results can be converted
    to numpy.

    Args:
        predict_coords: ligand coordinates tensor (may require grad).
        pocket_coords: pocket coordinates tensor.
        distance_predict: predicted ligand-pocket distances.
        holo_distance_predict: predicted ligand-ligand distances.
        dist_threshold: pairs with predicted distance strictly below this
            value enter the cross score.

    Returns:
        ``(cross_dist_score, dist_score)`` as 0-d numpy arrays. The cross
        score is 0 when no pair passes the threshold (previously NaN).
    """
    predict_coords = predict_coords.detach()
    dist = torch.norm(predict_coords.unsqueeze(1) - pocket_coords.unsqueeze(0), dim=-1)
    holo_dist = torch.norm(
        predict_coords.unsqueeze(1) - predict_coords.unsqueeze(0), dim=-1
    )
    distance_mask = distance_predict < dist_threshold
    if distance_mask.any():
        cross_dist_score = (
            (dist[distance_mask] - distance_predict[distance_mask]) ** 2
        ).mean()
    else:
        # Averaging an empty selection would report NaN.
        cross_dist_score = dist.new_zeros(())
    dist_score = ((holo_dist - holo_distance_predict) ** 2).mean()
    return cross_dist_score.numpy(), dist_score.numpy()
56
57
58
def dock_with_gradient(
    coords,
    pocket_coords,
    distance_predict_tta,
    holo_distance_predict_tta,
    loss_func=single_SF_loss,
    holo_coords=None,
    iterations=20000,
    early_stoping=5,
):
    """Optimise one starting pose against each paired distance prediction.

    ``distance_predict_tta`` and ``holo_distance_predict_tta`` are iterated
    in lockstep; for every pair ``single_dock_with_gradient`` refines a fresh
    copy of ``coords`` and the lowest-loss result is kept.

    Args:
        coords: starting ligand coordinates.
        pocket_coords: pocket coordinates.
        distance_predict_tta: iterable of ligand-pocket distance predictions.
        holo_distance_predict_tta: iterable of ligand-ligand distance
            predictions, paired with ``distance_predict_tta``.
        loss_func, holo_coords, iterations, early_stoping: forwarded to
            ``single_dock_with_gradient``.

    Returns:
        ``(coords, loss, meta_info)`` of the best run. If no run produced a
        comparable loss (e.g. all NaN, or the iterables are empty), the
        original ``coords``, ``inf`` and ``None`` are returned.
    """
    # inf rather than a large finite constant: a finite sentinel silently
    # discards every real result whose loss happens to exceed it.
    bst_loss, bst_coords, bst_meta_info = float("inf"), coords, None
    for distance_predict, holo_distance_predict in zip(
        distance_predict_tta, holo_distance_predict_tta
    ):
        # Defensive copy: the optimiser may write into the array it receives,
        # and every prediction must start from the same pose.
        new_coords = copy.deepcopy(coords)
        _coords, _loss, _meta_info = single_dock_with_gradient(
            new_coords,
            pocket_coords,
            distance_predict,
            holo_distance_predict,
            loss_func=loss_func,
            holo_coords=holo_coords,
            iterations=iterations,
            early_stoping=early_stoping,
        )
        # Strict "<" also rejects NaN losses.
        if _loss < bst_loss:
            bst_coords = _coords
            bst_loss = _loss
            bst_meta_info = _meta_info
    return bst_coords, bst_loss, bst_meta_info
88
89
90
def single_dock_with_gradient(
    coords,
    pocket_coords,
    distance_predict,
    holo_distance_predict,
    loss_func=single_SF_loss,
    holo_coords=None,
    iterations=20000,
    early_stoping=5,
):
    """Refine one ligand pose with L-BFGS against one set of predicted distances.

    Args:
        coords: starting ligand coordinates (numpy array or other array-like).
            A private float32 copy is optimised; the caller's array is never
            modified.
        pocket_coords: pocket coordinates (array-like).
        distance_predict: predicted ligand-pocket distances (array-like).
        holo_distance_predict: predicted ligand-ligand distances (array-like).
        loss_func: callable ``(coords, pocket_coords, distance_predict,
            holo_distance_predict) -> scalar tensor`` that is minimised.
        holo_coords: accepted for interface compatibility; not used by the
            optimisation.
        iterations: maximum number of ``optimizer.step`` calls.
        early_stoping: stop once more than this many consecutive steps fail
            to improve on the best loss seen so far.

    Returns:
        ``(coords, loss, meta_info)``: the optimised coordinates as a float32
        numpy array, ``loss_func`` evaluated at exactly those coordinates as a
        0-d numpy array, and the ``scoring`` result for the same coordinates.
    """
    # torch.from_numpy(x).float() aliases a float32 input and L-BFGS updates
    # its parameters in place, so always optimise a private copy.
    coords = torch.as_tensor(coords, dtype=torch.float32).detach().clone()
    pocket_coords = torch.as_tensor(pocket_coords, dtype=torch.float32)
    distance_predict = torch.as_tensor(distance_predict, dtype=torch.float32)
    holo_distance_predict = torch.as_tensor(
        holo_distance_predict, dtype=torch.float32
    )

    coords.requires_grad_(True)
    optimizer = torch.optim.LBFGS([coords], lr=1.0)

    def closure():
        optimizer.zero_grad()
        loss = loss_func(
            coords, pocket_coords, distance_predict, holo_distance_predict
        )
        loss.backward()
        return loss

    # inf (not a large finite constant) so the first step always counts as an
    # improvement however large the starting loss is; NaN never improves.
    bst_loss, times = float("inf"), 0
    for _ in range(iterations):
        step_loss = optimizer.step(closure).item()
        if step_loss < bst_loss:
            bst_loss, times = step_loss, 0
        else:
            times += 1
            if times > early_stoping:
                break

    # LBFGS.step returns the loss measured *before* its update, so re-evaluate
    # at the final coordinates: loss, coords and meta_info then describe the
    # same pose, and the loss is defined even when iterations == 0.
    final_loss = loss_func(
        coords, pocket_coords, distance_predict, holo_distance_predict
    ).detach()
    meta_info = scoring(coords, pocket_coords, distance_predict, holo_distance_predict)
    return coords.detach().numpy(), final_loss.numpy(), meta_info
132
133
134
def set_coord(mol, coords):
    """Copy per-atom coordinates into the first conformer of ``mol``.

    Row ``i`` of ``coords`` (converted with ``.tolist()``) becomes the
    position of atom ``i`` in ``mol.GetConformer(0)``. The molecule is
    modified in place and also returned for call chaining.
    """
    for atom_idx, row in enumerate(coords):
        conformer = mol.GetConformer(0)
        conformer.SetAtomPosition(atom_idx, row.tolist())
    return mol
138
139
140
def add_coord(mol, xyz):
    """Translate every atom in conformer 0 of ``mol`` by the offset ``xyz``.

    ``xyz`` must unpack into exactly three components (dx, dy, dz). The
    conformer is rewritten in place and ``mol`` is returned.
    """
    dx, dy, dz = xyz
    conformer = mol.GetConformer(0)
    shifted = conformer.GetPositions()
    for axis, delta in enumerate((dx, dy, dz)):
        shifted[:, axis] += delta
    for atom_idx, (px, py, pz) in enumerate(shifted):
        conformer.SetAtomPosition(atom_idx, Chem.rdGeometry.Point3D(px, py, pz))
    return mol
152
153
154
def single_docking(input_path, output_path, output_ligand_path):
    """Dock one ligand from a precomputed pickle and optionally save the pose.

    The pickle at ``input_path`` must unpack into nine items::

        (init_coords_tta, mol, smi, pocket, pocket_coords,
         distance_predict_tta, holo_distance_predict_tta,
         holo_coords, holo_center_coords)

    Every initial pose in ``init_coords_tta`` is refined with
    ``dock_with_gradient`` and the lowest-loss pose is kept. A summary line
    with the RMSD to ``holo_coords`` and both scoring terms is printed.

    Args:
        input_path: input pickle, read with ``pd.read_pickle``. Unpickling can
            execute arbitrary code, so only pass trusted files.
        output_path: if not None, ``[coords, holo_coords, loss, smi, pocket,
            pocket_coords]`` is pickled to this path.
        output_ligand_path: if not None, the hydrogen-free ligand, translated
            by ``holo_center_coords``, is written here with
            ``Chem.MolToMolFile``.

    Returns:
        True on success.

    Raises:
        ValueError: if no initial pose yields a usable docking result (empty
            ``init_coords_tta``, or every run failed / produced NaN).
    """
    content = pd.read_pickle(input_path)
    (
        init_coords_tta,
        mol,
        smi,
        pocket,
        pocket_coords,
        distance_predict_tta,
        holo_distance_predict_tta,
        holo_coords,
        holo_center_coords,
    ) = content

    # inf rather than a finite sentinel: with a finite one, a ligand whose
    # every loss exceeds it left bst_predict_coords as None and crashed below.
    bst_predict_coords, bst_loss, bst_meta_info = None, float("inf"), None
    for init_coords in init_coords_tta:
        predict_coords, loss, meta_info = dock_with_gradient(
            init_coords,
            pocket_coords,
            distance_predict_tta,
            holo_distance_predict_tta,
            holo_coords=holo_coords,
            loss_func=single_SF_loss,
        )
        # meta_info is None when dock_with_gradient found no valid run; such a
        # result cannot be scored, so never select it.
        if meta_info is not None and loss < bst_loss:
            bst_loss = loss
            bst_predict_coords = predict_coords
            bst_meta_info = meta_info

    if bst_predict_coords is None:
        raise ValueError(f"no usable docking result for {pocket}-{smi}")

    _rmsd = round(rmsd_func(holo_coords, bst_predict_coords), 4)
    _cross_score = round(float(bst_meta_info[0]), 4)
    _self_score = round(float(bst_meta_info[1]), 4)
    print(f"{pocket}-{smi}-RMSD:{_rmsd}-{_cross_score}-{_self_score}")
    # The optimised coordinates are written onto the hydrogen-free molecule.
    mol = Chem.RemoveHs(mol)
    mol = set_coord(mol, bst_predict_coords)

    if output_path is not None:
        with open(output_path, "wb") as f:
            pickle.dump(
                [bst_predict_coords, holo_coords, bst_loss, smi, pocket, pocket_coords],
                f,
            )
    if output_ligand_path is not None:
        mol = add_coord(mol, holo_center_coords.numpy())
        Chem.MolToMolFile(mol, output_ligand_path)

    return True
202
203
204
if __name__ == "__main__":
    # One CPU thread and a fixed seed for reproducible, predictable runs.
    torch.set_num_threads(1)
    torch.manual_seed(0)
    parser = argparse.ArgumentParser(description="Docking with gradient")
    # --input is mandatory: without it single_docking would hand None to
    # pd.read_pickle; argparse now reports a clear usage error instead.
    parser.add_argument("--input", type=str, required=True, help="input file.")
    parser.add_argument("--output", type=str, default=None, help="output path.")
    parser.add_argument(
        "--output-ligand", type=str, default=None, help="output ligand sdf path."
    )
    args = parser.parse_args()

    single_docking(args.input, args.output, args.output_ligand)