Switch to unified view

a b/analysis/molecule_builder.py
1
import warnings
2
import tempfile
3
4
import torch
5
import numpy as np
6
from rdkit import Chem
7
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule, UFFHasAllMoleculeParams
8
import openbabel
9
10
import utils
11
from constants import bonds1, bonds2, bonds3, margin1, margin2, margin3, \
12
    bond_dict
13
14
15
def get_bond_order(atom1, atom2, distance):
16
    distance = 100 * distance  # We change the metric
17
18
    if atom1 in bonds3 and atom2 in bonds3[atom1] and distance < bonds3[atom1][atom2] + margin3:
19
        return 3  # Triple
20
21
    if atom1 in bonds2 and atom2 in bonds2[atom1] and distance < bonds2[atom1][atom2] + margin2:
22
        return 2  # Double
23
24
    if atom1 in bonds1 and atom2 in bonds1[atom1] and distance < bonds1[atom1][atom2] + margin1:
25
        return 1  # Single
26
27
    return 0      # No bond
28
29
30
def get_bond_order_batch(atoms1, atoms2, distances, dataset_info):
31
    if isinstance(atoms1, np.ndarray):
32
        atoms1 = torch.from_numpy(atoms1)
33
    if isinstance(atoms2, np.ndarray):
34
        atoms2 = torch.from_numpy(atoms2)
35
    if isinstance(distances, np.ndarray):
36
        distances = torch.from_numpy(distances)
37
38
    distances = 100 * distances  # We change the metric
39
40
    bonds1 = torch.tensor(dataset_info['bonds1'], device=atoms1.device)
41
    bonds2 = torch.tensor(dataset_info['bonds2'], device=atoms1.device)
42
    bonds3 = torch.tensor(dataset_info['bonds3'], device=atoms1.device)
43
44
    bond_types = torch.zeros_like(atoms1)  # 0: No bond
45
46
    # Single
47
    bond_types[distances < bonds1[atoms1, atoms2] + margin1] = 1
48
49
    # Double (note that already assigned single bonds will be overwritten)
50
    bond_types[distances < bonds2[atoms1, atoms2] + margin2] = 2
51
52
    # Triple
53
    bond_types[distances < bonds3[atoms1, atoms2] + margin3] = 3
54
55
    return bond_types
56
57
58
def make_mol_openbabel(positions, atom_types, atom_decoder):
59
    """
60
    Build an RDKit molecule using openbabel for creating bonds
61
    Args:
62
        positions: N x 3
63
        atom_types: N
64
        atom_decoder: maps indices to atom types
65
    Returns:
66
        rdkit molecule
67
    """
68
    atom_types = [atom_decoder[x] for x in atom_types]
69
70
    with tempfile.NamedTemporaryFile() as tmp:
71
        tmp_file = tmp.name
72
73
        # Write xyz file
74
        utils.write_xyz_file(positions, atom_types, tmp_file)
75
76
        # Convert to sdf file with openbabel
77
        # openbabel will add bonds
78
        obConversion = openbabel.OBConversion()
79
        obConversion.SetInAndOutFormats("xyz", "sdf")
80
        ob_mol = openbabel.OBMol()
81
        obConversion.ReadFile(ob_mol, tmp_file)
82
83
        obConversion.WriteFile(ob_mol, tmp_file)
84
85
        # Read sdf file with RDKit
86
        tmp_mol = Chem.SDMolSupplier(tmp_file, sanitize=False)[0]
87
88
    # Build new molecule. This is a workaround to remove radicals.
89
    mol = Chem.RWMol()
90
    for atom in tmp_mol.GetAtoms():
91
        mol.AddAtom(Chem.Atom(atom.GetSymbol()))
92
    mol.AddConformer(tmp_mol.GetConformer(0))
93
94
    for bond in tmp_mol.GetBonds():
95
        mol.AddBond(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(),
96
                    bond.GetBondType())
97
98
    return mol
99
100
101
def make_mol_edm(positions, atom_types, dataset_info, add_coords):
102
    """
103
    Equivalent to EDM's way of building RDKit molecules
104
    """
105
    n = len(positions)
106
107
    # (X, A, E): atom_types, adjacency matrix, edge_types
108
    # X: N (int)
109
    # A: N x N (bool) -> (binary adjacency matrix)
110
    # E: N x N (int) -> (bond type, 0 if no bond)
111
    pos = positions.unsqueeze(0)  # add batch dim
112
    dists = torch.cdist(pos, pos, p=2).squeeze(0).view(-1)  # remove batch dim & flatten
113
    atoms1, atoms2 = torch.cartesian_prod(atom_types, atom_types).T
114
    E_full = get_bond_order_batch(atoms1, atoms2, dists, dataset_info).view(n, n)
115
    E = torch.tril(E_full, diagonal=-1)  # Warning: the graph should be DIRECTED
116
    A = E.bool()
117
    X = atom_types
118
119
    mol = Chem.RWMol()
120
    for atom in X:
121
        a = Chem.Atom(dataset_info["atom_decoder"][atom.item()])
122
        mol.AddAtom(a)
123
124
    all_bonds = torch.nonzero(A)
125
    for bond in all_bonds:
126
        mol.AddBond(bond[0].item(), bond[1].item(),
127
                    bond_dict[E[bond[0], bond[1]].item()])
128
129
    if add_coords:
130
        conf = Chem.Conformer(mol.GetNumAtoms())
131
        for i in range(mol.GetNumAtoms()):
132
            conf.SetAtomPosition(i, (positions[i, 0].item(),
133
                                     positions[i, 1].item(),
134
                                     positions[i, 2].item()))
135
        mol.AddConformer(conf)
136
137
    return mol
138
139
140
def build_molecule(positions, atom_types, dataset_info, add_coords=False,
141
                   use_openbabel=True):
142
    """
143
    Build RDKit molecule
144
    Args:
145
        positions: N x 3
146
        atom_types: N
147
        dataset_info: dict
148
        add_coords: Add conformer to mol (always added if use_openbabel=True)
149
        use_openbabel: use OpenBabel to create bonds
150
    Returns:
151
        RDKit molecule
152
    """
153
    if use_openbabel:
154
        mol = make_mol_openbabel(positions, atom_types,
155
                                 dataset_info["atom_decoder"])
156
    else:
157
        mol = make_mol_edm(positions, atom_types, dataset_info, add_coords)
158
159
    return mol
160
161
162
def process_molecule(rdmol, add_hydrogens=False, sanitize=False, relax_iter=0,
163
                     largest_frag=False):
164
    """
165
    Apply filters to an RDKit molecule. Makes a copy first.
166
    Args:
167
        rdmol: rdkit molecule
168
        add_hydrogens
169
        sanitize
170
        relax_iter: maximum number of UFF optimization iterations
171
        largest_frag: filter out the largest fragment in a set of disjoint
172
            molecules
173
    Returns:
174
        RDKit molecule or None if it does not pass the filters
175
    """
176
177
    # Create a copy
178
    mol = Chem.Mol(rdmol)
179
180
    if sanitize:
181
        try:
182
            Chem.SanitizeMol(mol)
183
        except ValueError:
184
            warnings.warn('Sanitization failed. Returning None.')
185
            return None
186
187
    if add_hydrogens:
188
        mol = Chem.AddHs(mol, addCoords=(len(mol.GetConformers()) > 0))
189
190
    if largest_frag:
191
        mol_frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
192
        mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
193
        if sanitize:
194
            # sanitize the updated molecule
195
            try:
196
                Chem.SanitizeMol(mol)
197
            except ValueError:
198
                return None
199
200
    if relax_iter > 0:
201
        if not UFFHasAllMoleculeParams(mol):
202
            warnings.warn('UFF parameters not available for all atoms. '
203
                          'Returning None.')
204
            return None
205
206
        try:
207
            uff_relax(mol, relax_iter)
208
            if sanitize:
209
                # sanitize the updated molecule
210
                Chem.SanitizeMol(mol)
211
        except (RuntimeError, ValueError) as e:
212
            return None
213
214
    return mol
215
216
217
def uff_relax(mol, max_iter=200):
218
    """
219
    Uses RDKit's universal force field (UFF) implementation to optimize a
220
    molecule.
221
    """
222
    more_iterations_required = UFFOptimizeMolecule(mol, maxIters=max_iter)
223
    if more_iterations_required:
224
        warnings.warn(f'Maximum number of FF iterations reached. '
225
                      f'Returning molecule after {max_iter} relaxation steps.')
226
    return more_iterations_required
227
228
229
def filter_rd_mol(rdmol):
230
    """
231
    Filter out RDMols if they have a 3-3 ring intersection
232
    adapted from:
233
    https://github.com/luost26/3D-Generative-SBDD/blob/main/utils/chem.py
234
    """
235
    ring_info = rdmol.GetRingInfo()
236
    ring_info.AtomRings()
237
    rings = [set(r) for r in ring_info.AtomRings()]
238
239
    # 3-3 ring intersection
240
    for i, ring_a in enumerate(rings):
241
        if len(ring_a) != 3:
242
            continue
243
        for j, ring_b in enumerate(rings):
244
            if i <= j:
245
                continue
246
            inter = ring_a.intersection(ring_b)
247
            if (len(ring_b) == 3) and (len(inter) > 0): 
248
                return False
249
250
    return True