"""Helpers for turning atom types and 3D coordinates into RDKit molecules.

Bonds are either inferred by OpenBabel (via a temporary xyz -> sdf round trip)
or derived from pairwise distances and per-atom-pair length tables (the
EDM-style procedure).  The module also offers post-processing (sanitization,
hydrogen addition, largest-fragment selection, UFF relaxation) and a simple
ring-based filter.
"""
import warnings
import tempfile

import torch
import numpy as np
from rdkit import Chem
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule, UFFHasAllMoleculeParams
import openbabel

import utils
from constants import bonds1, bonds2, bonds3, margin1, margin2, margin3, \
    bond_dict


def get_bond_order(atom1, atom2, distance):
    """
    Return the bond order implied by the distance between two atoms.
    Args:
        atom1: key of the first atom in the bond-length tables
        atom2: key of the second atom in the bond-length tables
        distance: inter-atomic distance; it is multiplied by 100 before being
            compared with the tables (presumably Angstrom -> picometre,
            TODO confirm against `constants`)
    Returns:
        3, 2 or 1 for a triple, double or single bond, 0 if there is no bond
    """
    distance = 100 * distance  # We change the metric

    # Highest order first: the first table whose threshold admits the
    # distance determines the result.
    checks = (
        (3, bonds3, margin3),  # Triple
        (2, bonds2, margin2),  # Double
        (1, bonds1, margin1),  # Single
    )
    for order, table, margin in checks:
        if atom1 in table and atom2 in table[atom1] \
                and distance < table[atom1][atom2] + margin:
            return order

    return 0  # No bond


def get_bond_order_batch(atoms1, atoms2, distances, dataset_info):
    """
    Vectorized bond-order lookup for many atom pairs at once.
    Args:
        atoms1: integer indices of the first atom of each pair
            (torch tensor or numpy array)
        atoms2: integer indices of the second atom of each pair
        distances: distance of each pair; multiplied by 100 before being
            compared with the tables, exactly as in `get_bond_order`
        dataset_info: dict whose 'bonds1', 'bonds2' and 'bonds3' entries can
            be converted to tensors indexable by [atoms1, atoms2]
    Returns:
        tensor shaped and typed like `atoms1` with entries in {0, 1, 2, 3}
    """
    if isinstance(atoms1, np.ndarray):
        atoms1 = torch.from_numpy(atoms1)
    if isinstance(atoms2, np.ndarray):
        atoms2 = torch.from_numpy(atoms2)
    if isinstance(distances, np.ndarray):
        distances = torch.from_numpy(distances)

    distances = 100 * distances  # We change the metric

    bond_types = torch.zeros_like(atoms1)  # 0: No bond

    # Lowest order first, so that a higher order overwrites an already
    # assigned lower one whenever its threshold is satisfied as well.
    checks = (
        (1, 'bonds1', margin1),  # Single
        (2, 'bonds2', margin2),  # Double
        (3, 'bonds3', margin3),  # Triple
    )
    for order, key, margin in checks:
        max_lengths = torch.tensor(dataset_info[key], device=atoms1.device)
        bond_types[distances < max_lengths[atoms1, atoms2] + margin] = order

    return bond_types


def make_mol_openbabel(positions, atom_types, atom_decoder):
    """
    Build an RDKit molecule using openbabel for creating bonds
    Args:
        positions: N x 3
        atom_types: N
        atom_decoder: maps indices to atom types
    Returns:
        rdkit molecule
    """
    atom_types = [atom_decoder[x] for x in atom_types]

    with tempfile.NamedTemporaryFile() as tmp:
        tmp_file = tmp.name

        # Write xyz file
        utils.write_xyz_file(positions, atom_types, tmp_file)

        # Convert to sdf file with openbabel
        # openbabel will add bonds
        obConversion = openbabel.OBConversion()
        obConversion.SetInAndOutFormats("xyz", "sdf")
        ob_mol = openbabel.OBMol()
        obConversion.ReadFile(ob_mol, tmp_file)

        # The same temporary path is reused for the sdf output
        obConversion.WriteFile(ob_mol, tmp_file)

        # Read sdf file with RDKit
        tmp_mol = Chem.SDMolSupplier(tmp_file, sanitize=False)[0]

    # Build new molecule. This is a workaround to remove radicals.
    mol = Chem.RWMol()
    for atom in tmp_mol.GetAtoms():
        mol.AddAtom(Chem.Atom(atom.GetSymbol()))
    mol.AddConformer(tmp_mol.GetConformer(0))

    for bond in tmp_mol.GetBonds():
        mol.AddBond(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(),
                    bond.GetBondType())

    return mol


def make_mol_edm(positions, atom_types, dataset_info, add_coords):
    """
    Equivalent to EDM's way of building RDKit molecules
    Args:
        positions: N x 3 tensor of coordinates
        atom_types: tensor of N integer atom-type indices
        dataset_info: dict providing "atom_decoder" and the bond-length
            tables used by `get_bond_order_batch`
        add_coords: if True, attach the coordinates as a conformer
    Returns:
        RDKit molecule (RWMol)
    """
    n = len(positions)

    # (X, A, E): atom_types, adjacency matrix, edge_types
    # X: N (int)
    # A: N x N (bool)                  -> (binary adjacency matrix)
    # E: N x N (int)                   -> (bond type, 0 if no bond)
    pos = positions.unsqueeze(0)  # add batch dim
    dists = torch.cdist(pos, pos, p=2).squeeze(0).view(-1)  # remove batch dim & flatten
    atoms1, atoms2 = torch.cartesian_prod(atom_types, atom_types).T
    E_full = get_bond_order_batch(atoms1, atoms2, dists, dataset_info).view(n, n)
    # Strict lower triangle: every unordered pair is kept exactly once and
    # the diagonal (self pairs) is dropped.
    E = torch.tril(E_full, diagonal=-1)  # Warning: the graph should be DIRECTED
    A = E.bool()

    atom_decoder = dataset_info["atom_decoder"]
    mol = Chem.RWMol()
    for type_idx in atom_types.tolist():
        mol.AddAtom(Chem.Atom(atom_decoder[type_idx]))

    for begin, end in torch.nonzero(A).tolist():
        mol.AddBond(begin, end, bond_dict[E[begin, end].item()])

    if add_coords:
        conf = Chem.Conformer(mol.GetNumAtoms())
        for i in range(mol.GetNumAtoms()):
            conf.SetAtomPosition(i, tuple(positions[i, :3].tolist()))
        mol.AddConformer(conf)

    return mol


def build_molecule(positions, atom_types, dataset_info, add_coords=False,
                   use_openbabel=True):
    """
    Build RDKit molecule
    Args:
        positions: N x 3
        atom_types: N
        dataset_info: dict
        add_coords: Add conformer to mol (always added if use_openbabel=True)
        use_openbabel: use OpenBabel to create bonds
    Returns:
        RDKit molecule
    """
    if use_openbabel:
        return make_mol_openbabel(positions, atom_types,
                                  dataset_info["atom_decoder"])

    return make_mol_edm(positions, atom_types, dataset_info, add_coords)


def process_molecule(rdmol, add_hydrogens=False, sanitize=False, relax_iter=0,
                     largest_frag=False):
    """
    Apply filters to an RDKit molecule. Makes a copy first.
    Args:
        rdmol: rdkit molecule
        add_hydrogens: add explicit hydrogens (with coordinates if the
            molecule has a conformer)
        sanitize: sanitize the molecule after every modifying step
        relax_iter: maximum number of UFF optimization iterations
        largest_frag: keep only the largest fragment in a set of disjoint
            molecules
    Returns:
        RDKit molecule or None if it does not pass the filters
    """

    # Create a copy
    mol = Chem.Mol(rdmol)

    if sanitize:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            warnings.warn('Sanitization failed. Returning None.')
            return None

    if add_hydrogens:
        mol = Chem.AddHs(mol, addCoords=(len(mol.GetConformers()) > 0))

    if largest_frag:
        mol_frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
        mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
        if sanitize:
            # sanitize the updated molecule
            try:
                Chem.SanitizeMol(mol)
            except ValueError:
                # Report the failure like the first sanitization step does
                warnings.warn('Sanitization of the largest fragment failed. '
                              'Returning None.')
                return None

    if relax_iter > 0:
        if not UFFHasAllMoleculeParams(mol):
            warnings.warn('UFF parameters not available for all atoms. '
                          'Returning None.')
            return None

        try:
            uff_relax(mol, relax_iter)
            if sanitize:
                # sanitize the updated molecule
                Chem.SanitizeMol(mol)
        except (RuntimeError, ValueError) as e:
            warnings.warn(f'UFF relaxation or subsequent sanitization failed '
                          f'({e}). Returning None.')
            return None

    return mol


def uff_relax(mol, max_iter=200):
    """
    Uses RDKit's universal force field (UFF) implementation to optimize a
    molecule.
    Args:
        mol: RDKit molecule, optimized in place
        max_iter: maximum number of optimization iterations
    Returns:
        the status returned by UFFOptimizeMolecule (per the RDKit
        documentation: 0 if converged, 1 if more iterations are required,
        -1 if the force field could not be set up)
    """
    more_iterations_required = UFFOptimizeMolecule(mol, maxIters=max_iter)
    if more_iterations_required == -1:
        # Not an iteration-limit problem: nothing was optimized at all.
        warnings.warn('UFF force field could not be set up. '
                      'Returning molecule without relaxation.')
    elif more_iterations_required:
        warnings.warn(f'Maximum number of FF iterations reached. '
                      f'Returning molecule after {max_iter} relaxation steps.')
    return more_iterations_required


def filter_rd_mol(rdmol):
    """
    Filter out RDMols if they have a 3-3 ring intersection
    adapted from:
    https://github.com/luost26/3D-Generative-SBDD/blob/main/utils/chem.py
    Returns:
        False if two distinct 3-membered rings share at least one atom,
        True otherwise
    """
    # Only 3-membered rings can take part in a 3-3 intersection
    three_rings = [set(ring) for ring in rdmol.GetRingInfo().AtomRings()
                   if len(ring) == 3]

    # Compare every unordered pair of distinct 3-membered rings once
    for i, ring_a in enumerate(three_rings):
        for ring_b in three_rings[:i]:
            if not ring_a.isdisjoint(ring_b):
                return False

    return True