Diff of /generate_ligands.py [000000] .. [607087]

Switch to unified view

a b/generate_ligands.py
1
import argparse
2
from pathlib import Path
3
4
import torch
5
from openbabel import openbabel
6
openbabel.obErrorLog.StopLogging()  # suppress OpenBabel messages
7
8
import utils
9
from lightning_modules import LigandPocketDDPM
10
11
12
if __name__ == "__main__":
13
    parser = argparse.ArgumentParser()
14
    parser.add_argument('checkpoint', type=Path)
15
    parser.add_argument('--pdbfile', type=str)
16
    parser.add_argument('--resi_list', type=str, nargs='+', default=None)
17
    parser.add_argument('--ref_ligand', type=str, default=None)
18
    parser.add_argument('--outfile', type=Path)
19
    parser.add_argument('--n_samples', type=int, default=20)
20
    parser.add_argument('--batch_size', type=int, default=None)
21
    parser.add_argument('--num_nodes_lig', type=int, default=None)
22
    parser.add_argument('--all_frags', action='store_true')
23
    parser.add_argument('--sanitize', action='store_true')
24
    parser.add_argument('--relax', action='store_true')
25
    parser.add_argument('--resamplings', type=int, default=10)
26
    parser.add_argument('--jump_length', type=int, default=1)
27
    parser.add_argument('--timesteps', type=int, default=None)
28
    args = parser.parse_args()
29
30
    pdb_id = Path(args.pdbfile).stem
31
32
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
33
34
    if args.batch_size is None:
35
        args.batch_size = args.n_samples
36
    assert args.n_samples % args.batch_size == 0
37
38
    # Load model
39
    model = LigandPocketDDPM.load_from_checkpoint(
40
        args.checkpoint, map_location=device)
41
    model = model.to(device)
42
43
    if args.num_nodes_lig is not None:
44
        num_nodes_lig = torch.ones(args.n_samples, dtype=int) * \
45
                        args.num_nodes_lig
46
    else:
47
        num_nodes_lig = None
48
49
    molecules = []
50
    for i in range(args.n_samples // args.batch_size):
51
        molecules_batch = model.generate_ligands(
52
            args.pdbfile, args.batch_size, args.resi_list, args.ref_ligand,
53
            num_nodes_lig, args.sanitize, largest_frag=not args.all_frags,
54
            relax_iter=(200 if args.relax else 0),
55
            resamplings=args.resamplings, jump_length=args.jump_length,
56
            timesteps=args.timesteps)
57
        molecules.extend(molecules_batch)
58
59
    # Make SDF files
60
    utils.write_sdf_file(args.outfile, molecules)