Switch to unified view

a b/unimol/utils/docking_utils.py
1
# Copyright (c) DP Technology, Inc. and its affiliates.
2
#
3
# This source code is licensed under the MIT license found in the
4
# LICENSE file in the root directory of this source tree.
5
6
import numpy as np
7
from rdkit import Chem
8
from rdkit.Chem import AllChem
9
from rdkit import RDLogger
10
11
RDLogger.DisableLog("rdApp.*")
12
import warnings
13
14
warnings.filterwarnings(action="ignore")
15
from rdkit.Chem import rdMolTransforms
16
import copy
17
import lmdb
18
import pickle
19
import pandas as pd
20
21
22
def get_torsions(m, removeHs=True):
    """Pick one dihedral atom quadruple for each rotatable bond of ``m``.

    Rotatable bonds are the matches of the SMARTS
    ``[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]``: single, non-ring bonds whose two end
    atoms are neither part of a triple bond nor of degree 1.

    Args:
        m: RDKit molecule to analyse.
        removeHs: When true, hydrogens are stripped with ``Chem.RemoveHs``
            before matching, so the returned indices refer to the stripped
            molecule.

    Returns:
        A list of 4-tuples of atom indices, at most one per matched bond.
        The tuple is ``(idx1, idx2, idx3, idx4)`` along the bond
        ``idx2 - idx3``, or the reversed order ``(idx4, idx3, idx2, idx1)``
        when ``idx4`` is a ring atom.
    """
    if removeHs:
        m = Chem.RemoveHs(m)
    rotatable_query = Chem.MolFromSmarts("[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]")

    quads = []
    for hit in m.GetSubstructMatches(rotatable_query):
        j_idx, k_idx = hit[0], hit[1]
        central_id = m.GetBondBetweenAtoms(j_idx, k_idx).GetIdx()

        # Only the first bond on the j side that is not the central bond is
        # ever tried; if it yields no valid quadruple the match is dropped.
        side_bond = next(
            (
                bond
                for bond in m.GetAtomWithIdx(j_idx).GetBonds()
                if bond.GetIdx() != central_id
            ),
            None,
        )
        if side_bond is None:
            continue
        near_idx = side_bond.GetOtherAtomIdx(j_idx)
        near_is_hydrogen = m.GetAtomWithIdx(near_idx).GetAtomicNum() == 1

        for far_bond in m.GetAtomWithIdx(k_idx).GetBonds():
            if far_bond.GetIdx() in (central_id, side_bond.GetIdx()):
                continue
            far_idx = far_bond.GetOtherAtomIdx(k_idx)
            # Both ends being the same atom means a 3-membered ring.
            if far_idx == near_idx:
                continue
            # Torsions with a hydrogen at either end are not used.
            if near_is_hydrogen or m.GetAtomWithIdx(far_idx).GetAtomicNum() == 1:
                continue
            if m.GetAtomWithIdx(far_idx).IsInRing():
                quads.append((far_idx, k_idx, j_idx, near_idx))
            else:
                quads.append((near_idx, j_idx, k_idx, far_idx))
            # The first acceptable far-side atom wins.
            break
    return quads
59
60
61
def SetDihedral(conf, atom_idx, new_vale):
    """Set one dihedral angle of ``conf`` via ``rdMolTransforms.SetDihedralRad``.

    Args:
        conf: Conformer handed to ``SetDihedralRad``.
        atom_idx: Indexable holding the four atom indices of the dihedral at
            positions 0 to 3.
        new_vale: New dihedral value, forwarded unchanged to
            ``SetDihedralRad``.
    """
    first, second, third, fourth = (atom_idx[pos] for pos in range(4))
    rdMolTransforms.SetDihedralRad(conf, first, second, third, fourth, new_vale)
65
66
67
def single_conf_gen_bonds(tgt_mol, num_confs=1000, seed=42, removeHs=True):
    """Embed conformers for a copy of ``tgt_mol`` and randomise their torsions.

    The input molecule is deep-copied and hydrogens are added before
    ``AllChem.EmbedMultipleConfs`` is called. Every embedded conformer then
    has each torsion from ``get_torsions`` set to a random angle and is
    canonicalised with ``CanonicalizeConformer``.

    Note:
        The global NumPy RNG is reseeded with the conformer position
        (``np.random.seed(i)``) on every iteration.

    Args:
        tgt_mol: Source molecule; it is not modified.
        num_confs: Number of conformers requested from the embedder.
        seed: Random seed passed to the embedder.
        removeHs: When true, hydrogens are removed again after embedding;
            the flag is also forwarded to ``get_torsions``.

    Returns:
        The working copy of the molecule carrying the generated conformers.
    """
    mol = Chem.AddHs(copy.deepcopy(tgt_mol))
    conf_ids = AllChem.EmbedMultipleConfs(
        mol, numConfs=num_confs, randomSeed=seed, clearConfs=True
    )
    if removeHs:
        mol = Chem.RemoveHs(mol)

    torsions = get_torsions(mol, removeHs=removeHs)
    for conf_pos in range(len(conf_ids)):
        # One deterministic angle set per conformer position.
        np.random.seed(conf_pos)
        angles = 3.1415926 * 2 * np.random.rand(len(torsions))
        for quad, angle in zip(torsions, angles):
            SetDihedral(mol.GetConformers()[conf_pos], quad, angle)
        Chem.rdMolTransforms.CanonicalizeConformer(mol.GetConformers()[conf_pos])
    return mol
83
84
85
def load_lmdb_data(lmdb_path, key):
    """Collect ``record[key]`` from every record of an LMDB file.

    Records are fetched under the ASCII keys ``"0"``, ``"1"``, ... up to the
    number of keys stored in the database, and each value is unpickled.

    Args:
        lmdb_path: Path of the LMDB file (opened read-only with
            ``subdir=False``).
        key: Field to extract from each unpickled record.

    Returns:
        A list with one ``record[key]`` entry per record, in index order.
    """
    # Both the environment and the transaction are used as context managers
    # so the read transaction and the file handle are released on every exit
    # path, including a failure while unpickling.
    with lmdb.open(
        lmdb_path,
        subdir=False,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False,
        max_readers=256,
    ) as env:
        with env.begin() as txn:
            # Count the keys without keeping them; records are addressed by
            # their running index, not by the iterated key values.
            num_records = sum(1 for _ in txn.cursor().iternext(values=False))
            # NOTE(security): pickle.loads can execute arbitrary code; only
            # use this on LMDB files from a trusted source.
            return [
                pickle.loads(txn.get(f"{idx}".encode("ascii")))[key]
                for idx in range(num_records)
            ]
103
104
105
def docking_data_pre(raw_data_path, predict_path):
    """Load molecules and per-sample model predictions for docking.

    Molecules come from the ``"mol_list"`` field of the LMDB file at
    ``raw_data_path`` (flattened, hydrogens removed). Predictions come from
    the pickle at ``predict_path``, which is iterated as a sequence of batch
    dicts. Within each sample only positions whose token id is greater than
    2 are kept (presumably ids <= 2 are padding/special tokens - confirm
    against the dictionary used upstream).

    Args:
        raw_data_path: Path of the LMDB file with the raw molecules.
        predict_path: Path of the pickled prediction batches.

    Returns:
        An 8-tuple of lists: ``(mol_list, smi_list, pocket_list,
        pocket_coords_list, distance_predict_list,
        holo_distance_predict_list, holo_coords_list,
        holo_center_coords_list)``. The coordinate and distance entries are
        float32 NumPy arrays; the holo centre entries are left as sliced
        from the batch (first three values).
    """
    mol_list = [
        Chem.RemoveHs(mol)
        for group in load_lmdb_data(raw_data_path, "mol_list")
        for mol in group
    ]
    predict = pd.read_pickle(predict_path)

    smi_list = []
    pocket_list = []
    pocket_coords_list = []
    distance_predict_list = []
    holo_distance_predict_list = []
    holo_coords_list = []
    holo_center_coords_list = []

    def as_float32(tensor):
        # Batch entries expose ``.numpy()``; results are stored as float32.
        return tensor.numpy().astype(np.float32)

    for batch in predict:
        for row in range(batch["atoms"].size(0)):
            smi_list.append(batch["smi_name"][row])
            pocket_list.append(batch["pocket_name"][row])

            cross_distance = batch["cross_distance_predict"][row]
            ligand_mask = batch["atoms"][row] > 2
            pocket_mask = batch["pocket_atoms"][row] > 2

            # Rows are ligand positions, columns are pocket positions.
            cross_distance = cross_distance[ligand_mask][:, pocket_mask]
            pocket_xyz = batch["pocket_coordinates"][row][pocket_mask, :]
            holo_distance = batch["holo_distance_predict"][row]
            holo_distance = holo_distance[ligand_mask][:, ligand_mask]
            holo_xyz = batch["holo_coordinates"][row][ligand_mask, :]
            holo_center = batch["holo_center_coordinates"][row][:3]

            pocket_coords_list.append(as_float32(pocket_xyz))
            distance_predict_list.append(as_float32(cross_distance))
            holo_distance_predict_list.append(as_float32(holo_distance))
            holo_coords_list.append(as_float32(holo_xyz))
            holo_center_coords_list.append(holo_center)

    return (
        mol_list,
        smi_list,
        pocket_list,
        pocket_coords_list,
        distance_predict_list,
        holo_distance_predict_list,
        holo_coords_list,
        holo_center_coords_list,
    )
160
161
162
def ensemble_iterations(
    mol_list,
    smi_list,
    pocket_list,
    pocket_coords_list,
    distance_predict_list,
    holo_distance_predict_list,
    holo_coords_list,
    holo_center_coords_list,
    tta_times=10,
):
    """Yield one docking work item per group of ``tta_times`` entries.

    The input lists are consumed in consecutive windows of ``tta_times``
    entries; a trailing partial window is ignored. For each window the
    molecule, names, pocket coordinates and holo data are taken from the
    window's first entry, while both distance predictions are passed on for
    the whole window.

    Args:
        mol_list: Molecules; its length determines the number of windows.
        smi_list: Ligand names, parallel to ``mol_list``.
        pocket_list: Pocket names, parallel to ``mol_list``.
        pocket_coords_list: Pocket coordinates, parallel to ``mol_list``.
        distance_predict_list: Predicted ligand-pocket distances.
        holo_distance_predict_list: Predicted intra-ligand distances.
        holo_coords_list: Holo ligand coordinates.
        holo_center_coords_list: Holo centre coordinates.
        tta_times: Window size, also the number of conformers requested
            from ``single_conf_gen_bonds``.

    Yields:
        A list ``[initial_coords_list, mol, smi, pocket, pocket_coords,
        distance_predict_tta, holo_distance_predict_tta, holo_coords,
        holo_center_coords]``, where ``initial_coords_list`` holds the
        float32 positions of each generated conformer and ``mol`` is a deep
        copy of the window's first molecule.
    """
    num_groups = len(mol_list) // tta_times
    for group in range(num_groups):
        first = group * tta_times
        window = slice(first, first + tta_times)

        mol = copy.deepcopy(mol_list[first])
        rdkit_mol = single_conf_gen_bonds(
            mol, num_confs=tta_times, seed=42, removeHs=True
        )
        initial_coords_list = [
            conformer.GetPositions().astype(np.float32)
            for conformer in rdkit_mol.GetConformers()
        ]

        yield [
            initial_coords_list,
            mol,
            smi_list[first],
            pocket_list[first],
            pocket_coords_list[first],
            distance_predict_list[window],
            holo_distance_predict_list[window],
            holo_coords_list[first],
            holo_center_coords_list[first],
        ]
200
201
202
def rmsd_func(holo_coords, predict_coords):
    """Return the RMSD between reference and predicted coordinates.

    Args:
        holo_coords: Reference coordinates; its first dimension is used as
            the number of atoms.
        predict_coords: Predicted coordinates with a shape compatible with
            ``holo_coords``, or a scalar NaN marking a failed prediction.

    Returns:
        ``sqrt(sum((predict_coords - holo_coords) ** 2) / n_atoms)``, or the
        penalty value ``1000.0`` when ``predict_coords`` is a scalar NaN.
    """
    # Compare by value rather than identity: a NaN that was pickled,
    # re-created with float("nan"), or stored as a NumPy scalar is not the
    # same object as ``np.nan`` and used to slip through as a NaN RMSD.
    if isinstance(predict_coords, (float, np.floating)) and np.isnan(
        predict_coords
    ):
        return 1000.0
    num_atoms = holo_coords.shape[0]
    return np.sqrt(np.sum((predict_coords - holo_coords) ** 2) / num_atoms)
208
209
210
def print_results(rmsd_results):
    """Print the fraction of RMSDs under fixed cutoffs and the mean RMSD.

    Args:
        rmsd_results: RMSD values supporting elementwise ``<`` against a
            float (e.g. a NumPy array).
    """
    for cutoff in (1.0, 1.5, 2.0, 3.0, 5.0):
        # The mean of the boolean mask is the fraction below the cutoff.
        print(f"RMSD < {cutoff} : ", np.mean(rmsd_results < cutoff))
    print("avg RMSD : ", np.mean(rmsd_results))