|
a |
|
b/analysis/molecule_builder.py |
|
|
1 |
import warnings |
|
|
2 |
import tempfile |
|
|
3 |
|
|
|
4 |
import torch |
|
|
5 |
import numpy as np |
|
|
6 |
from rdkit import Chem |
|
|
7 |
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule, UFFHasAllMoleculeParams |
|
|
8 |
import openbabel |
|
|
9 |
|
|
|
10 |
import utils |
|
|
11 |
from constants import bonds1, bonds2, bonds3, margin1, margin2, margin3, \ |
|
|
12 |
bond_dict |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def get_bond_order(atom1, atom2, distance):
    """
    Classify the bond between two atoms from their distance.

    The distance is multiplied by 100 before it is compared against the
    reference bond lengths (presumably a unit conversion to match the
    tables in `constants` -- confirm against that module). The tables are
    tried from triple to single bonds and the first match is returned.

    Args:
        atom1: key of the first atom in the bond length tables
        atom2: key of the second atom in the bond length tables
        distance: distance between both atoms
    Returns:
        3 (triple), 2 (double), 1 (single) or 0 (no bond)
    """
    scaled = 100 * distance  # We change the metric

    # Highest bond order first; the first threshold that is met wins.
    candidates = (
        (3, bonds3, margin3),  # Triple
        (2, bonds2, margin2),  # Double
        (1, bonds1, margin1),  # Single
    )
    for order, table, margin in candidates:
        # Pairs without an entry in this table cannot form this bond type
        if atom1 not in table or atom2 not in table[atom1]:
            continue
        if scaled < table[atom1][atom2] + margin:
            return order

    return 0  # No bond
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def get_bond_order_batch(atoms1, atoms2, distances, dataset_info):
    """
    Vectorized bond order lookup for many atom pairs at once.

    Numpy inputs are converted to torch tensors. Distances are multiplied
    by 100 before the comparison. The reference bond lengths are read from
    dataset_info (not from the module-level tables) and are indexed with
    [atoms1, atoms2].

    Args:
        atoms1: atom type indices of the first atom of each pair
        atoms2: atom type indices of the second atom of each pair
        distances: distance of each pair
        dataset_info: dict providing the 'bonds1', 'bonds2' and 'bonds3'
            lookup tables
    Returns:
        tensor shaped and typed like atoms1 holding 0 (no bond), 1, 2 or 3
    """
    def _to_tensor(value):
        # Only numpy arrays are converted; anything else passes through
        return torch.from_numpy(value) if isinstance(value, np.ndarray) else value

    atoms1 = _to_tensor(atoms1)
    atoms2 = _to_tensor(atoms2)
    distances = 100 * _to_tensor(distances)  # We change the metric

    device = atoms1.device
    # Ascending bond order: a later (higher) match overwrites an earlier one
    lookups = [
        (order, torch.tensor(dataset_info[key], device=device), margin)
        for order, key, margin in (
            (1, 'bonds1', margin1),  # Single
            (2, 'bonds2', margin2),  # Double
            (3, 'bonds3', margin3),  # Triple
        )
    ]

    bond_types = torch.zeros_like(atoms1)  # 0: No bond
    for order, table, margin in lookups:
        bond_types[distances < table[atoms1, atoms2] + margin] = order

    return bond_types
|
|
56 |
|
|
|
57 |
|
|
|
58 |
def make_mol_openbabel(positions, atom_types, atom_decoder):
    """
    Build an RDKit molecule and let openbabel decide where the bonds are.

    The atoms are written to a temporary xyz file, converted to sdf by
    openbabel (which adds the bonds) and read back with RDKit. The result
    is then copied atom by atom into a fresh molecule.

    Args:
        positions: N x 3
        atom_types: N
        atom_decoder: maps indices to atom types
    Returns:
        rdkit molecule
    """
    symbols = [atom_decoder[idx] for idx in atom_types]

    with tempfile.NamedTemporaryFile() as handle:
        path = handle.name

        # Write xyz file
        utils.write_xyz_file(positions, symbols, path)

        # Convert to sdf file with openbabel; openbabel will add bonds.
        # The sdf output overwrites the xyz input in the same file.
        converter = openbabel.OBConversion()
        converter.SetInAndOutFormats("xyz", "sdf")
        ob_mol = openbabel.OBMol()
        converter.ReadFile(ob_mol, path)
        converter.WriteFile(ob_mol, path)

        # Read sdf file with RDKit
        parsed = Chem.SDMolSupplier(path, sanitize=False)[0]

    # Build new molecule. This is a workaround to remove radicals.
    mol = Chem.RWMol()
    for atom in parsed.GetAtoms():
        mol.AddAtom(Chem.Atom(atom.GetSymbol()))
    mol.AddConformer(parsed.GetConformer(0))

    for bond in parsed.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        mol.AddBond(begin, end, bond.GetBondType())

    return mol
|
|
99 |
|
|
|
100 |
|
|
|
101 |
def make_mol_edm(positions, atom_types, dataset_info, add_coords):
    """
    Equivalent to EDM's way of building RDKit molecules

    Bonds are derived from pairwise distances via get_bond_order_batch.

    Args:
        positions: N x 3 tensor
        atom_types: tensor with N atom type indices
        dataset_info: dict with the "atom_decoder" and the bond tables
            consumed by get_bond_order_batch
        add_coords: attach the positions as a conformer
    Returns:
        RDKit molecule
    """
    num_atoms = len(positions)

    # Pairwise distances, flattened so that they line up with the
    # cartesian product of the atom types
    batched = positions.unsqueeze(0)  # add batch dim
    flat_dists = torch.cdist(batched, batched, p=2).squeeze(0).view(-1)
    first, second = torch.cartesian_prod(atom_types, atom_types).T
    orders = get_bond_order_batch(first, second, flat_dists, dataset_info)

    # N x N bond types (0 if no bond). Only the strict lower triangle is
    # kept. Warning: the graph should be DIRECTED
    edge_types = torch.tril(orders.view(num_atoms, num_atoms), diagonal=-1)

    mol = Chem.RWMol()
    for type_idx in atom_types.tolist():
        mol.AddAtom(Chem.Atom(dataset_info["atom_decoder"][type_idx]))

    # Every non-zero entry of the edge matrix becomes one bond
    for row, col in torch.nonzero(edge_types.bool()).tolist():
        mol.AddBond(row, col, bond_dict[edge_types[row, col].item()])

    if add_coords:
        conformer = Chem.Conformer(mol.GetNumAtoms())
        for idx in range(mol.GetNumAtoms()):
            x, y, z = (positions[idx, axis].item() for axis in range(3))
            conformer.SetAtomPosition(idx, (x, y, z))
        mol.AddConformer(conformer)

    return mol
|
|
138 |
|
|
|
139 |
|
|
|
140 |
def build_molecule(positions, atom_types, dataset_info, add_coords=False,
                   use_openbabel=True):
    """
    Build RDKit molecule

    Dispatches to make_mol_openbabel or make_mol_edm.

    Args:
        positions: N x 3
        atom_types: N
        dataset_info: dict
        add_coords: Add conformer to mol (always added if use_openbabel=True)
        use_openbabel: use OpenBabel to create bonds
    Returns:
        RDKit molecule
    """
    if not use_openbabel:
        return make_mol_edm(positions, atom_types, dataset_info, add_coords)

    # The openbabel route only needs the decoder, not the bond tables
    return make_mol_openbabel(positions, atom_types,
                              dataset_info["atom_decoder"])
|
|
160 |
|
|
|
161 |
|
|
|
162 |
def process_molecule(rdmol, add_hydrogens=False, sanitize=False, relax_iter=0,
                     largest_frag=False):
    """
    Run the optional post-processing steps on a copy of an RDKit molecule.

    Steps, in order: sanitize, add hydrogens, keep the largest fragment,
    UFF relaxation. Each step is only executed if requested.

    Args:
        rdmol: rdkit molecule
        add_hydrogens: add hydrogens (with coordinates if the molecule has
            a conformer)
        sanitize: sanitize the molecule, also after it was modified by the
            fragment selection or the relaxation
        relax_iter: maximum number of UFF optimization iterations
        largest_frag: filter out the largest fragment in a set of disjoint
            molecules
    Returns:
        RDKit molecule or None if it does not pass the filters
    """
    # Create a copy
    mol = Chem.Mol(rdmol)

    if sanitize:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            warnings.warn('Sanitization failed. Returning None.')
            return None

    if add_hydrogens:
        has_conformer = len(mol.GetConformers()) > 0
        mol = Chem.AddHs(mol, addCoords=has_conformer)

    if largest_frag:
        fragments = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
        # Falls back to the whole molecule if there are no fragments
        mol = max(fragments, default=mol, key=lambda frag: frag.GetNumAtoms())
        if sanitize:
            # sanitize the updated molecule
            try:
                Chem.SanitizeMol(mol)
            except ValueError:
                return None

    if not relax_iter > 0:
        return mol

    if not UFFHasAllMoleculeParams(mol):
        warnings.warn('UFF parameters not available for all atoms. '
                      'Returning None.')
        return None

    try:
        uff_relax(mol, relax_iter)
        if sanitize:
            # sanitize the updated molecule
            Chem.SanitizeMol(mol)
    except (RuntimeError, ValueError):
        return None

    return mol
|
|
215 |
|
|
|
216 |
|
|
|
217 |
def uff_relax(mol, max_iter=200):
    """
    Optimize a molecule in place with RDKit's universal force field (UFF).

    Args:
        mol: rdkit molecule
        max_iter: maximum number of optimization iterations
    Returns:
        the value returned by UFFOptimizeMolecule; a truthy value is
        treated as "more iterations required" and triggers a warning
    """
    status = UFFOptimizeMolecule(mol, maxIters=max_iter)

    if status:
        message = (f'Maximum number of FF iterations reached. '
                   f'Returning molecule after {max_iter} relaxation steps.')
        warnings.warn(message)

    return status
|
|
227 |
|
|
|
228 |
|
|
|
229 |
def filter_rd_mol(rdmol):
    """
    Filter out RDMols if they have a 3-3 ring intersection
    adapted from:
    https://github.com/luost26/3D-Generative-SBDD/blob/main/utils/chem.py

    Args:
        rdmol: rdkit molecule
    Returns:
        False if two different 3-membered rings share at least one atom,
        True otherwise
    """
    # Only 3-membered rings can trigger the filter, so drop all others
    # before comparing pairs.
    atom_rings = (set(ring) for ring in rdmol.GetRingInfo().AtomRings())
    three_rings = [ring for ring in atom_rings if len(ring) == 3]

    # 3-3 ring intersection: look at every unordered pair exactly once
    for idx, ring_a in enumerate(three_rings):
        for ring_b in three_rings[idx + 1:]:
            if not ring_a.isdisjoint(ring_b):
                return False

    return True