|
a |
|
b/inpaint.py |
|
|
1 |
import argparse |
|
|
2 |
from pathlib import Path |
|
|
3 |
|
|
|
4 |
import numpy as np |
|
|
5 |
import torch |
|
|
6 |
import torch.nn.functional as F |
|
|
7 |
from Bio.PDB import PDBParser |
|
|
8 |
from rdkit import Chem |
|
|
9 |
from torch_scatter import scatter_mean |
|
|
10 |
from openbabel import openbabel |
|
|
11 |
openbabel.obErrorLog.StopLogging() # suppress OpenBabel messages |
|
|
12 |
|
|
|
13 |
import utils |
|
|
14 |
from lightning_modules import LigandPocketDDPM |
|
|
15 |
from constants import FLOAT_TYPE, INT_TYPE |
|
|
16 |
from analysis.molecule_builder import build_molecule, process_molecule |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
def prepare_from_sdf_files(sdf_files, atom_encoder):
    """
    Load ligand atoms from SDF files as coordinate and one-hot tensors.
    Only the first record of each file is read (without sanitization); the
    atoms of all files are concatenated in the order the files are given.
    Args:
        sdf_files: iterable of paths to SDF files
        atom_encoder: mapping from element symbol (as returned by RDKit's
            GetSymbol) to the integer atom type index
    Returns:
        tuple (coords, one_hot): atom positions of the first conformer of
        every molecule (converted with .float()) and the one-hot encoded
        atom types with len(atom_encoder) classes
    Raises:
        ValueError: if the first record of a file cannot be parsed
        KeyError: if an element symbol is missing from atom_encoder
    """
    ligand_coords = []
    atom_one_hot = []
    for file in sdf_files:
        # RDKit's supplier yields None for records it fails to parse; fail
        # with the offending file name instead of an AttributeError below.
        rdmol = Chem.SDMolSupplier(str(file), sanitize=False)[0]
        if rdmol is None:
            raise ValueError(f"Could not read a molecule from '{file}'.")

        ligand_coords.append(
            torch.from_numpy(rdmol.GetConformer().GetPositions()).float()
        )
        types = torch.tensor(
            [atom_encoder[a.GetSymbol()] for a in rdmol.GetAtoms()])
        atom_one_hot.append(
            F.one_hot(types, num_classes=len(atom_encoder))
        )

    return torch.cat(ligand_coords, dim=0), torch.cat(atom_one_hot, dim=0)
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def prepare_ligand_from_pdb(biopython_atoms, atom_encoder):
    """
    Convert atoms into a coordinate tensor and one-hot atom types.
    Args:
        biopython_atoms: iterable of atom objects providing get_coord() and
            an 'element' attribute (e.g. Biopython atoms)
        atom_encoder: mapping from capitalized element symbol (e.g. 'Cl')
            to the integer atom type index
    Returns:
        tuple (coord, one_hot): coordinates with dtype FLOAT_TYPE and the
        one-hot encoded atom types with len(atom_encoder) classes
    Raises:
        KeyError: if an element is missing from atom_encoder
    """
    # The atoms are traversed twice below; materialize them once so that
    # one-shot iterables (generators) work as well as lists.
    atoms = list(biopython_atoms)

    coord = torch.tensor(np.array([a.get_coord() for a in atoms]),
                         dtype=FLOAT_TYPE)
    # Element symbols are capitalized ('CL' -> 'Cl') before the lookup
    types = torch.tensor([atom_encoder[a.element.capitalize()]
                          for a in atoms])
    one_hot = F.one_hot(types, num_classes=len(atom_encoder))

    return coord, one_hot
|
|
45 |
|
|
|
46 |
|
|
|
47 |
def prepare_substructure(ref_ligand, fix_atoms, pdb_model, atom_encoder=None):
    """
    Get coordinates and one-hot atom types of the fixed ligand substructure.
    Args:
        ref_ligand: reference ligand in <chain>:<resi> format; only used if
            fix_atoms contains atom names
        fix_atoms: non-empty sequence of either SDF file paths (detected by
            the '.sdf' suffix of the first entry) or names of atoms of the
            reference ligand
        pdb_model: structure model that can be indexed by chain id
        atom_encoder: mapping from element symbol to atom type index; if
            None, the encoder of a module-level 'model' object is used,
            which only exists when this file is executed as a script
    Returns:
        tuple (coord, one_hot) of the fixed atoms
    Raises:
        ValueError: if fix_atoms is empty or no atom of the reference ligand
            matches the given atom names
    """
    if not fix_atoms:
        raise ValueError(
            "fix_atoms must contain at least one SDF file or atom name.")

    if atom_encoder is None:
        # Backward compatibility with callers that rely on the global
        # 'model' defined in the script entry point.
        atom_encoder = model.lig_type_encoder

    # str() so that pathlib.Path entries are accepted as well
    if str(fix_atoms[0]).endswith(".sdf"):
        # ligand as sdf file
        coord, one_hot = prepare_from_sdf_files(fix_atoms, atom_encoder)

    else:
        # ligand contained in PDB; given in <chain>:<resi> format
        chain, resi = ref_ligand.split(':')
        ligand = utils.get_residue_with_resi(pdb_model[chain], int(resi))
        # Build the lookup set once instead of once per atom
        atom_names = set(fix_atoms)
        fixed_atoms = [a for a in ligand.get_atoms()
                       if a.get_name() in atom_names]
        if not fixed_atoms:
            raise ValueError(
                f"None of the atoms {sorted(atom_names)} were found in "
                f"ligand '{ref_ligand}'.")
        coord, one_hot = prepare_ligand_from_pdb(fixed_atoms, atom_encoder)

    return coord, one_hot


def inpaint_ligand(model, pdb_file, n_samples, ligand, fix_atoms,
                   add_n_nodes=None, center='ligand', sanitize=False,
                   largest_frag=False, relax_iter=0, timesteps=None,
                   resamplings=1, save_traj=False):
    """
    Generate ligands for a pocket while keeping a ligand substructure fixed
    Args:
        model: Lightning model
        pdb_file: PDB filename
        n_samples: number of samples
        ligand: reference ligand given in <chain>:<resi> format if the ligand is
                contained in the PDB file, or path to an SDF file that
                contains the ligand; used to define the pocket
        fix_atoms: ligand atoms that should be fixed, given as a sequence of
                atom names, e.g. ["C1", "N6", "C5", "C12"], or as a sequence
                of SDF files containing the fixed atoms
        center: 'ligand' or 'pocket'
        add_n_nodes: number of ligand nodes to add, sampled randomly if 'None'
        sanitize: whether to sanitize molecules or not
        largest_frag: only return the largest fragment
        relax_iter: number of force field optimization steps
        timesteps: number of denoising steps, use training value if None
        resamplings: number of resampling iterations
        save_traj: save intermediate states to visualize a denoising trajectory
    Returns:
        list of molecules
    Raises:
        NotImplementedError: if save_traj is set and n_samples > 1
    """
    if save_traj and n_samples > 1:
        raise NotImplementedError("Can only visualize trajectory with "
                                  "n_samples=1.")

    # In trajectory mode one frame per denoising step is requested and all
    # post-processing is disabled so that intermediate states are kept as is.
    # NOTE(review): with save_traj=True and timesteps=None, 'frames' is None;
    # confirm that model.ddpm.inpaint handles return_frames=None.
    frames = timesteps if save_traj else 1
    sanitize = False if save_traj else sanitize
    relax_iter = 0 if save_traj else relax_iter
    largest_frag = False if save_traj else largest_frag

    # Load PDB
    pdb_model = PDBParser(QUIET=True).get_structure('', pdb_file)[0]

    # Define pocket based on reference ligand
    residues = utils.get_pocket_from_ligand(pdb_model, ligand)
    pocket = model.prepare_pocket(residues, repeats=n_samples)

    # Get fixed ligand substructure; the encoder is passed explicitly so that
    # this function does not depend on a module-level 'model' variable
    x_fixed, one_hot_fixed = prepare_substructure(
        ligand, fix_atoms, pdb_model, atom_encoder=model.lig_type_encoder)
    n_fixed = len(x_fixed)

    if add_n_nodes is None:
        num_nodes_lig = model.ddpm.size_distribution.sample_conditional(
            n1=None, n2=pocket['size'])
        # Every sample needs at least enough nodes for the fixed atoms
        num_nodes_lig = torch.clamp(num_nodes_lig, min=n_fixed)
    else:
        num_nodes_lig = torch.ones(n_samples, dtype=int) * n_fixed + add_n_nodes

    ligand_mask = utils.num_nodes_to_batch_mask(
        len(num_nodes_lig), num_nodes_lig, model.device)

    # Named 'ligand_batch' to avoid shadowing the 'ligand' argument
    ligand_batch = {
        'x': torch.zeros((len(ligand_mask), model.x_dims),
                         device=model.device, dtype=FLOAT_TYPE),
        'one_hot': torch.zeros((len(ligand_mask), model.atom_nf),
                               device=model.device, dtype=FLOAT_TYPE),
        'size': num_nodes_lig,
        'mask': ligand_mask
    }

    # fill in fixed atoms: the first n_fixed nodes of every sample hold the
    # fixed substructure and are flagged with 1 in lig_fixed
    lig_fixed = torch.zeros_like(ligand_mask)
    for i in range(n_samples):
        sele = (ligand_mask == i)

        # Boolean indexing returns a copy, hence the explicit write-back
        x_new = ligand_batch['x'][sele]
        x_new[:n_fixed] = x_fixed
        ligand_batch['x'][sele] = x_new

        h_new = ligand_batch['one_hot'][sele]
        h_new[:n_fixed] = one_hot_fixed
        ligand_batch['one_hot'][sele] = h_new

        fixed_new = lig_fixed[sele]
        fixed_new[:n_fixed] = 1
        lig_fixed[sele] = fixed_new

    # Pocket's center of mass
    pocket_com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0)

    # Run sampling
    xh_lig, xh_pocket, lig_mask, pocket_mask = model.ddpm.inpaint(
        ligand_batch, pocket, lig_fixed, center=center,
        resamplings=resamplings, timesteps=timesteps, return_frames=frames)

    # Treat intermediate states as molecules for downstream processing
    if save_traj:
        xh_lig = utils.reverse_tensor(xh_lig)
        xh_pocket = utils.reverse_tensor(xh_pocket)

        # Every entry along the first dimension gets its own mask index
        # (presumably one entry per saved frame -- TODO confirm)
        lig_mask = torch.arange(xh_lig.size(0), device=model.device
                                ).repeat_interleave(len(lig_mask))
        pocket_mask = torch.arange(xh_pocket.size(0), device=model.device
                                   ).repeat_interleave(len(pocket_mask))

        # Flatten the first two dimensions into one list of nodes
        xh_lig = xh_lig.view(-1, xh_lig.size(2))
        xh_pocket = xh_pocket.view(-1, xh_pocket.size(2))

    # Move generated molecule back to the original pocket position
    pocket_com_after = scatter_mean(
        xh_pocket[:, :model.x_dims], pocket_mask, dim=0)

    xh_pocket[:, :model.x_dims] += \
        (pocket_com_before - pocket_com_after)[pocket_mask]
    xh_lig[:, :model.x_dims] += \
        (pocket_com_before - pocket_com_after)[lig_mask]

    # Build mol objects
    x = xh_lig[:, :model.x_dims].detach().cpu()
    atom_type = xh_lig[:, model.x_dims:].argmax(1).detach().cpu()

    molecules = []
    for mol_pc in zip(utils.batch_to_list(x, lig_mask),
                      utils.batch_to_list(atom_type, lig_mask)):

        mol = build_molecule(*mol_pc, model.dataset_info, add_coords=True)
        mol = process_molecule(mol,
                               add_hydrogens=False,
                               sanitize=sanitize,
                               relax_iter=relax_iter,
                               largest_frag=largest_frag)
        # Molecules for which process_molecule returns None are dropped
        if mol is not None:
            molecules.append(mol)

    return molecules
|
|
190 |
|
|
|
191 |
|
|
|
192 |
if __name__ == "__main__":
    # Command line entry point: load a checkpoint, inpaint ligands around a
    # fixed substructure and write the resulting molecules to an SDF file.

    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', type=Path)
    # Required: the value is passed to Path() and to the PDB parser below
    parser.add_argument('--pdbfile', type=str, required=True)
    parser.add_argument('--ref_ligand', type=str, default=None)
    # Required: the first entry is inspected to tell SDF files from atom names
    parser.add_argument('--fix_atoms', type=str, nargs='+', required=True)
    # A list (instead of a set) keeps the order in the help text stable
    parser.add_argument('--center', type=str, default='ligand',
                        choices=['ligand', 'pocket'])
    # NOTE(review): --outfile has no default and is not required; confirm how
    # utils.write_sdf_file behaves when it receives None.
    parser.add_argument('--outfile', type=Path)
    parser.add_argument('--n_samples', type=int, default=20)
    parser.add_argument('--add_n_nodes', type=int, default=None)
    parser.add_argument('--relax', action='store_true')
    parser.add_argument('--sanitize', action='store_true')
    parser.add_argument('--resamplings', type=int, default=20)
    parser.add_argument('--timesteps', type=int, default=50)
    parser.add_argument('--save_traj', action='store_true')
    args = parser.parse_args()

    # Not used further in the code visible here
    pdb_id = Path(args.pdbfile).stem

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model
    model = LigandPocketDDPM.load_from_checkpoint(
        args.checkpoint, map_location=device)
    model = model.to(device)

    molecules = inpaint_ligand(model, args.pdbfile, args.n_samples,
                               args.ref_ligand, args.fix_atoms,
                               args.add_n_nodes, center=args.center,
                               sanitize=args.sanitize,
                               largest_frag=False,
                               # --relax enables 200 force field steps
                               relax_iter=(200 if args.relax else 0),
                               timesteps=args.timesteps,
                               resamplings=args.resamplings,
                               save_traj=args.save_traj)

    # Make SDF files
    utils.write_sdf_file(args.outfile, molecules)