Diff of /inpaint.py [000000] .. [607087]

Switch to unified view

a b/inpaint.py
1
import argparse
2
from pathlib import Path
3
4
import numpy as np
5
import torch
6
import torch.nn.functional as F
7
from Bio.PDB import PDBParser
8
from rdkit import Chem
9
from torch_scatter import scatter_mean
10
from openbabel import openbabel
11
openbabel.obErrorLog.StopLogging()  # suppress OpenBabel messages
12
13
import utils
14
from lightning_modules import LigandPocketDDPM
15
from constants import FLOAT_TYPE, INT_TYPE
16
from analysis.molecule_builder import build_molecule, process_molecule
17
18
19
def prepare_from_sdf_files(sdf_files, atom_encoder):
    """
    Read the first molecule of every SDF file and stack all of their atoms.
    Args:
        sdf_files: iterable of SDF file paths; only the first record of each
                   file is used
        atom_encoder: mapping from element symbol (as returned by RDKit's
                      GetSymbol) to integer atom type index
    Returns:
        tuple (coordinates, one_hot): atom positions and one-hot encoded atom
        types with len(atom_encoder) classes, both concatenated over all
        files along dim 0
    Raises:
        ValueError: if the first record of a file cannot be parsed by RDKit
    """
    n_atom_types = len(atom_encoder)

    ligand_coords = []
    atom_one_hot = []
    for file in sdf_files:
        rdmol = Chem.SDMolSupplier(str(file), sanitize=False)[0]
        if rdmol is None:
            # The supplier yields None for records it cannot parse; fail
            # here with a clear message instead of an AttributeError below.
            raise ValueError(f"Could not read a molecule from '{file}'.")

        ligand_coords.append(
            torch.from_numpy(rdmol.GetConformer().GetPositions()).float()
        )
        types = torch.tensor(
            [atom_encoder[a.GetSymbol()] for a in rdmol.GetAtoms()])
        atom_one_hot.append(F.one_hot(types, num_classes=n_atom_types))

    return torch.cat(ligand_coords, dim=0), torch.cat(atom_one_hot, dim=0)
34
35
36
def prepare_ligand_from_pdb(biopython_atoms, atom_encoder):
    """
    Convert Biopython atoms into coordinate and one-hot atom type tensors.
    Args:
        biopython_atoms: sequence of atoms providing get_coord() and an
                         `element` attribute; it is iterated twice
        atom_encoder: mapping from capitalized element symbol to integer
                      atom type index
    Returns:
        tuple (coord, one_hot): coordinates with dtype FLOAT_TYPE and one-hot
        encoded atom types with len(atom_encoder) classes
    """
    # First pass: collect the positions of all atoms
    positions = np.array([atom.get_coord() for atom in biopython_atoms])
    coord = torch.tensor(positions, dtype=FLOAT_TYPE)

    # Second pass: look up the type index of every element symbol
    # (capitalize() is applied before the lookup, e.g. 'CL' -> 'Cl')
    type_indices = [
        atom_encoder[atom.element.capitalize()] for atom in biopython_atoms
    ]
    one_hot = F.one_hot(torch.tensor(type_indices),
                        num_classes=len(atom_encoder))

    return coord, one_hot
45
46
47
def prepare_substructure(ref_ligand, fix_atoms, pdb_model, atom_encoder=None):
    """
    Get coordinates and one-hot atom types of the atoms that are kept fixed.
    Args:
        ref_ligand: reference ligand in <chain>:<resi> format; only read when
                    the fixed atoms are given as atom names
        fix_atoms: either a list of SDF file paths (detected by the '.sdf'
                   suffix of the first entry) or a list of atom names of the
                   reference ligand in the PDB file
        pdb_model: Biopython model that contains the reference ligand (only
                   read when the fixed atoms are given as atom names)
        atom_encoder: mapping from element symbol to atom type index; if
                      None, `lig_type_encoder` of a module-level `model` is
                      used (legacy behaviour that only works when this file
                      is executed as a script)
    Returns:
        tuple (coord, one_hot) of the fixed atoms
    Raises:
        ValueError: if none of the given atom names occurs in the ligand
    """
    if atom_encoder is None:
        # Legacy fallback: `model` only exists as a global when this file is
        # run as a script. Callers should pass the encoder explicitly.
        atom_encoder = model.lig_type_encoder

    if fix_atoms[0].endswith(".sdf"):
        # ligand as sdf file
        coord, one_hot = prepare_from_sdf_files(fix_atoms, atom_encoder)

    else:
        # ligand contained in PDB; given in <chain>:<resi> format
        chain, resi = ref_ligand.split(':')
        ligand = utils.get_residue_with_resi(pdb_model[chain], int(resi))
        fix_names = set(fix_atoms)  # build the lookup set only once
        fixed_atoms = [a for a in ligand.get_atoms()
                       if a.get_name() in fix_names]
        if not fixed_atoms:
            raise ValueError(
                f"None of the atoms {sorted(fix_names)} were found in "
                f"ligand '{ref_ligand}'.")
        coord, one_hot = prepare_ligand_from_pdb(fixed_atoms, atom_encoder)

    return coord, one_hot


def inpaint_ligand(model, pdb_file, n_samples, ligand, fix_atoms,
                   add_n_nodes=None, center='ligand', sanitize=False,
                   largest_frag=False, relax_iter=0, timesteps=None,
                   resamplings=1, save_traj=False):
    """
    Generate ligands given a pocket and a fixed ligand substructure
    Args:
        model: Lightning model
        pdb_file: PDB filename
        n_samples: number of samples
        ligand: reference ligand given in <chain>:<resi> format if the ligand is
                contained in the PDB file, or path to an SDF file that
                contains the ligand; used to define the pocket
        fix_atoms: ligand atoms that should be fixed, e.g. "C1 N6 C5 C12",
                   or a list of SDF files containing the fixed atoms
        center: 'ligand' or 'pocket'
        add_n_nodes: number of ligand nodes to add, sampled randomly if 'None'
        sanitize: whether to sanitize molecules or not
        largest_frag: only return the largest fragment
        relax_iter: number of force field optimization steps
        timesteps: number of denoising steps, use training value if None
        resamplings: number of resampling iterations
        save_traj: save intermediate states to visualize a denoising trajectory
    Returns:
        list of molecules
    Raises:
        NotImplementedError: if save_traj is set and n_samples > 1
    """
    if save_traj and n_samples > 1:
        raise NotImplementedError("Can only visualize trajectory with "
                                  "n_samples=1.")
    # NOTE(review): with save_traj=True and timesteps=None, return_frames is
    # passed on as None - confirm that model.ddpm.inpaint handles this.
    frames = timesteps if save_traj else 1
    # Post-processing is switched off when saving a trajectory
    sanitize = False if save_traj else sanitize
    relax_iter = 0 if save_traj else relax_iter
    largest_frag = False if save_traj else largest_frag

    # Load PDB
    pdb_model = PDBParser(QUIET=True).get_structure('', pdb_file)[0]

    # Define pocket based on reference ligand
    residues = utils.get_pocket_from_ligand(pdb_model, ligand)
    pocket = model.prepare_pocket(residues, repeats=n_samples)

    # Get fixed ligand substructure; the encoder is passed explicitly so that
    # this function does not depend on a global `model` variable
    x_fixed, one_hot_fixed = prepare_substructure(
        ligand, fix_atoms, pdb_model, atom_encoder=model.lig_type_encoder)
    n_fixed = len(x_fixed)

    if add_n_nodes is None:
        num_nodes_lig = model.ddpm.size_distribution.sample_conditional(
            n1=None, n2=pocket['size'])
        # every sample needs at least enough nodes for the fixed atoms
        num_nodes_lig = torch.clamp(num_nodes_lig, min=n_fixed)
    else:
        num_nodes_lig = torch.ones(n_samples, dtype=int) * n_fixed + add_n_nodes

    ligand_mask = utils.num_nodes_to_batch_mask(
        len(num_nodes_lig), num_nodes_lig, model.device)

    # Named differently from the `ligand` argument to avoid rebinding it
    ligand_data = {
        'x': torch.zeros((len(ligand_mask), model.x_dims),
                         device=model.device, dtype=FLOAT_TYPE),
        'one_hot': torch.zeros((len(ligand_mask), model.atom_nf),
                               device=model.device, dtype=FLOAT_TYPE),
        'size': num_nodes_lig,
        'mask': ligand_mask
    }

    # fill in fixed atoms: the first n_fixed nodes of every sample hold the
    # fixed substructure and are flagged in lig_fixed
    lig_fixed = torch.zeros_like(ligand_mask)
    for i in range(n_samples):
        sele = (ligand_mask == i)

        x_new = ligand_data['x'][sele]
        x_new[:n_fixed] = x_fixed
        ligand_data['x'][sele] = x_new

        h_new = ligand_data['one_hot'][sele]
        h_new[:n_fixed] = one_hot_fixed
        ligand_data['one_hot'][sele] = h_new

        fixed_new = lig_fixed[sele]
        fixed_new[:n_fixed] = 1
        lig_fixed[sele] = fixed_new

    # Pocket's center of mass
    pocket_com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0)

    # Run sampling
    xh_lig, xh_pocket, lig_mask, pocket_mask = model.ddpm.inpaint(
        ligand_data, pocket, lig_fixed, center=center,
        resamplings=resamplings, timesteps=timesteps, return_frames=frames)

    # Treat intermediate states as molecules for downstream processing
    if save_traj:
        xh_lig = utils.reverse_tensor(xh_lig)
        xh_pocket = utils.reverse_tensor(xh_pocket)

        # one mask index per entry along dim 0, repeated for all of its nodes
        lig_mask = torch.arange(xh_lig.size(0), device=model.device
                                ).repeat_interleave(len(lig_mask))
        pocket_mask = torch.arange(xh_pocket.size(0), device=model.device
                                   ).repeat_interleave(len(pocket_mask))

        # merge the first two dimensions so that the tensors are 2D again
        xh_lig = xh_lig.view(-1, xh_lig.size(2))
        xh_pocket = xh_pocket.view(-1, xh_pocket.size(2))

    # Move generated molecule back to the original pocket position
    pocket_com_after = scatter_mean(xh_pocket[:, :model.x_dims], pocket_mask, dim=0)

    xh_pocket[:, :model.x_dims] += \
        (pocket_com_before - pocket_com_after)[pocket_mask]
    xh_lig[:, :model.x_dims] += \
        (pocket_com_before - pocket_com_after)[lig_mask]

    # Build mol objects
    x = xh_lig[:, :model.x_dims].detach().cpu()
    atom_type = xh_lig[:, model.x_dims:].argmax(1).detach().cpu()

    molecules = []
    for mol_pc in zip(utils.batch_to_list(x, lig_mask),
                      utils.batch_to_list(atom_type, lig_mask)):

        mol = build_molecule(*mol_pc, model.dataset_info, add_coords=True)
        mol = process_molecule(mol,
                               add_hydrogens=False,
                               sanitize=sanitize,
                               relax_iter=relax_iter,
                               largest_frag=largest_frag)
        # molecules for which processing returned None are dropped
        if mol is not None:
            molecules.append(mol)

    return molecules
190
191
192
if __name__ == "__main__":

    # Command line interface
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', type=Path)
    # The PDB file is always parsed, so it must be provided
    parser.add_argument('--pdbfile', type=str, required=True)
    parser.add_argument('--ref_ligand', type=str, default=None)
    # Either SDF file(s) or atom names; the first entry is always inspected,
    # so at least one value must be provided
    parser.add_argument('--fix_atoms', type=str, nargs='+', required=True)
    # A list (instead of a set) keeps the order in help/error messages stable
    parser.add_argument('--center', type=str, default='ligand',
                        choices=['ligand', 'pocket'])
    parser.add_argument('--outfile', type=Path)
    parser.add_argument('--n_samples', type=int, default=20)
    parser.add_argument('--add_n_nodes', type=int, default=None)
    parser.add_argument('--relax', action='store_true')
    parser.add_argument('--sanitize', action='store_true')
    parser.add_argument('--resamplings', type=int, default=20)
    parser.add_argument('--timesteps', type=int, default=50)
    parser.add_argument('--save_traj', action='store_true')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model
    model = LigandPocketDDPM.load_from_checkpoint(
        args.checkpoint, map_location=device)
    model = model.to(device)

    # --relax switches on 200 force field optimization steps
    molecules = inpaint_ligand(model, args.pdbfile, args.n_samples,
                               args.ref_ligand, args.fix_atoms,
                               args.add_n_nodes, center=args.center,
                               sanitize=args.sanitize,
                               largest_frag=False,
                               relax_iter=(200 if args.relax else 0),
                               timesteps=args.timesteps,
                               resamplings=args.resamplings,
                               save_traj=args.save_traj)

    # Make SDF files
    utils.write_sdf_file(args.outfile, molecules)