Diff of /generate_ligands.py [000000] .. [607087]

Switch to side-by-side view

--- a
+++ b/generate_ligands.py
@@ -0,0 +1,60 @@
+import argparse
+from pathlib import Path
+
+import torch
+from openbabel import openbabel
+openbabel.obErrorLog.StopLogging()  # suppress OpenBabel messages
+
+import utils
+from lightning_modules import LigandPocketDDPM
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('checkpoint', type=Path)
+    parser.add_argument('--pdbfile', type=str)
+    parser.add_argument('--resi_list', type=str, nargs='+', default=None)
+    parser.add_argument('--ref_ligand', type=str, default=None)
+    parser.add_argument('--outfile', type=Path)
+    parser.add_argument('--n_samples', type=int, default=20)
+    parser.add_argument('--batch_size', type=int, default=None)
+    parser.add_argument('--num_nodes_lig', type=int, default=None)
+    parser.add_argument('--all_frags', action='store_true')
+    parser.add_argument('--sanitize', action='store_true')
+    parser.add_argument('--relax', action='store_true')
+    parser.add_argument('--resamplings', type=int, default=10)
+    parser.add_argument('--jump_length', type=int, default=1)
+    parser.add_argument('--timesteps', type=int, default=None)
+    args = parser.parse_args()
+
+    pdb_id = Path(args.pdbfile).stem
+
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    if args.batch_size is None:
+        args.batch_size = args.n_samples
+    assert args.n_samples % args.batch_size == 0
+
+    # Load model
+    model = LigandPocketDDPM.load_from_checkpoint(
+        args.checkpoint, map_location=device)
+    model = model.to(device)
+
+    if args.num_nodes_lig is not None:
+        num_nodes_lig = torch.ones(args.n_samples, dtype=int) * \
+                        args.num_nodes_lig
+    else:
+        num_nodes_lig = None
+
+    molecules = []
+    for i in range(args.n_samples // args.batch_size):
+        molecules_batch = model.generate_ligands(
+            args.pdbfile, args.batch_size, args.resi_list, args.ref_ligand,
+            num_nodes_lig, args.sanitize, largest_frag=not args.all_frags,
+            relax_iter=(200 if args.relax else 0),
+            resamplings=args.resamplings, jump_length=args.jump_length,
+            timesteps=args.timesteps)
+        molecules.extend(molecules_batch)
+
+    # Make SDF files
+    utils.write_sdf_file(args.outfile, molecules)