|
a |
|
b/generate_ligands.py |
|
|
1 |
import argparse |
|
|
2 |
from pathlib import Path |
|
|
3 |
|
|
|
4 |
import torch |
|
|
5 |
from openbabel import openbabel |
|
|
6 |
openbabel.obErrorLog.StopLogging() # suppress OpenBabel messages |
|
|
7 |
|
|
|
8 |
import utils |
|
|
9 |
from lightning_modules import LigandPocketDDPM |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
if __name__ == "__main__": |
|
|
13 |
parser = argparse.ArgumentParser() |
|
|
14 |
parser.add_argument('checkpoint', type=Path) |
|
|
15 |
parser.add_argument('--pdbfile', type=str) |
|
|
16 |
parser.add_argument('--resi_list', type=str, nargs='+', default=None) |
|
|
17 |
parser.add_argument('--ref_ligand', type=str, default=None) |
|
|
18 |
parser.add_argument('--outfile', type=Path) |
|
|
19 |
parser.add_argument('--n_samples', type=int, default=20) |
|
|
20 |
parser.add_argument('--batch_size', type=int, default=None) |
|
|
21 |
parser.add_argument('--num_nodes_lig', type=int, default=None) |
|
|
22 |
parser.add_argument('--all_frags', action='store_true') |
|
|
23 |
parser.add_argument('--sanitize', action='store_true') |
|
|
24 |
parser.add_argument('--relax', action='store_true') |
|
|
25 |
parser.add_argument('--resamplings', type=int, default=10) |
|
|
26 |
parser.add_argument('--jump_length', type=int, default=1) |
|
|
27 |
parser.add_argument('--timesteps', type=int, default=None) |
|
|
28 |
args = parser.parse_args() |
|
|
29 |
|
|
|
30 |
pdb_id = Path(args.pdbfile).stem |
|
|
31 |
|
|
|
32 |
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
33 |
|
|
|
34 |
if args.batch_size is None: |
|
|
35 |
args.batch_size = args.n_samples |
|
|
36 |
assert args.n_samples % args.batch_size == 0 |
|
|
37 |
|
|
|
38 |
# Load model |
|
|
39 |
model = LigandPocketDDPM.load_from_checkpoint( |
|
|
40 |
args.checkpoint, map_location=device) |
|
|
41 |
model = model.to(device) |
|
|
42 |
|
|
|
43 |
if args.num_nodes_lig is not None: |
|
|
44 |
num_nodes_lig = torch.ones(args.n_samples, dtype=int) * \ |
|
|
45 |
args.num_nodes_lig |
|
|
46 |
else: |
|
|
47 |
num_nodes_lig = None |
|
|
48 |
|
|
|
49 |
molecules = [] |
|
|
50 |
for i in range(args.n_samples // args.batch_size): |
|
|
51 |
molecules_batch = model.generate_ligands( |
|
|
52 |
args.pdbfile, args.batch_size, args.resi_list, args.ref_ligand, |
|
|
53 |
num_nodes_lig, args.sanitize, largest_frag=not args.all_frags, |
|
|
54 |
relax_iter=(200 if args.relax else 0), |
|
|
55 |
resamplings=args.resamplings, jump_length=args.jump_length, |
|
|
56 |
timesteps=args.timesteps) |
|
|
57 |
molecules.extend(molecules_batch) |
|
|
58 |
|
|
|
59 |
# Make SDF files |
|
|
60 |
utils.write_sdf_file(args.outfile, molecules) |