# train.py
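# Training entry point for the LigandPocketDDPM model, built on PyTorch
# Lightning with Weights & Biases logging.
#
# Example invocations (config and checkpoint paths are illustrative):
#   python train.py --config config.yml
#   python train.py --config config.yml --resume last.ckpt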
import argparse
from argparse import Namespace
from pathlib import Path
import warnings

import torch
import pytorch_lightning as pl
import yaml
import numpy as np

from lightning_modules import LigandPocketDDPM


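# Merge a YAML config dict into parsed CLI arguments. Config values overwrite
# existing CLI values (with a warning); nested dicts become Namespaces so they
# can be accessed with dot notation.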
def merge_args_and_yaml(args, config_dict):
    arg_dict = args.__dict__
    for key, value in config_dict.items():
        if key in arg_dict:
            warnings.warn(f"Command line argument '{key}' (value: "
                          f"{arg_dict[key]}) will be overwritten with value "
                          f"{value} provided in the config file.")
        if isinstance(value, dict):
            arg_dict[key] = Namespace(**value)
        else:
            arg_dict[key] = value

    return args


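# Merge hyperparameters stored in a checkpoint into the current config.
# Checkpoint values take precedence over the YAML config (with a warning).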
def merge_configs(config, resume_config):
    for key, value in resume_config.items():
        if isinstance(value, Namespace):
            value = value.__dict__
        if key in config and config[key] != value:
            warnings.warn(f"Config parameter '{key}' (value: "
                          f"{config[key]}) will be overwritten with value "
                          f"{value} from the checkpoint.")
        config[key] = value
    return config


# ------------------------------------------------------------------------------
# Training
# ______________________________________________________________________________
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument('--config', type=str, required=True)
    p.add_argument('--resume', type=str, default=None)
    args = p.parse_args()
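
    # The YAML file must supply every attribute accessed below; a minimal
    # sketch (key names taken from this script, values illustrative):
    #   run_name: my_run
    #   logdir: logs
    #   datadir: data
    #   batch_size: 16
    #   lr: 1.0e-4
    #   n_epochs: 1000
    #   gpus: 1
    #   wandb_params: {group: default, entity: my-entity, mode: online}
    #   ...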

    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

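    # The YAML must not define 'resume'; that key is reserved for the CLI flag.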
    assert 'resume' not in config

    # Resolve the effective config: when resuming, hyperparameters saved in
    # the checkpoint take precedence over the YAML file.
    ckpt_path = None if args.resume is None else Path(args.resume)
    if args.resume is not None:
        resume_config = torch.load(
            ckpt_path, map_location=torch.device('cpu'))['hyper_parameters']
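        # NOTE: with PyTorch >= 2.6, torch.load defaults to weights_only=True,
        # which may fail on the pickled Namespace objects stored in these
        # hyperparameters; passing weights_only=False restores the old behavior.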

        config = merge_configs(config, resume_config)

    args = merge_args_and_yaml(args, config)

    out_dir = Path(args.logdir, args.run_name)
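    # Empirical distribution of molecule sizes, used by the model when
    # sampling the number of nodes to generate.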
    histogram_file = Path(args.datadir, 'size_distribution.npy')
    histogram = np.load(histogram_file).tolist()
    pl_module = LigandPocketDDPM(
        outdir=out_dir,
        dataset=args.dataset,
        datadir=args.datadir,
        batch_size=args.batch_size,
        lr=args.lr,
        egnn_params=args.egnn_params,
        diffusion_params=args.diffusion_params,
        num_workers=args.num_workers,
        augment_noise=args.augment_noise,
        augment_rotation=args.augment_rotation,
        clip_grad=args.clip_grad,
        eval_epochs=args.eval_epochs,
        eval_params=args.eval_params,
        visualize_sample_epoch=args.visualize_sample_epoch,
        visualize_chain_epoch=args.visualize_chain_epoch,
        auxiliary_loss=args.auxiliary_loss,
        loss_params=args.loss_params,
        mode=args.mode,
        node_histogram=histogram,
        pocket_representation=args.pocket_representation,
        virtual_nodes=args.virtual_nodes
    )

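    # The W&B run id is tied to the run name; resume='must' enforces that a
    # resumed job continues the original run instead of starting a new one.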
    logger = pl.loggers.WandbLogger(
        save_dir=args.logdir,
        project='ligand-pocket-ddpm',
        group=args.wandb_params.group,
        name=args.run_name,
        id=args.run_name,
        resume='must' if args.resume is not None else False,
        entity=args.wandb_params.entity,
        mode=args.wandb_params.mode,
    )

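    # Keep the single best checkpoint by validation loss, plus the most recent
    # one ('last.ckpt') for resuming.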
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=Path(out_dir, 'checkpoints'),
        filename="best-model-epoch={epoch:02d}",
        monitor="loss/val",
        save_top_k=1,
        save_last=True,
        mode="min",
    )

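    # DDP is only enabled for multi-GPU runs. strategy=None assumes
    # pytorch_lightning 1.x, where it selects the default single-device
    # strategy; 2.x releases expect strategy='auto' instead.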
    trainer = pl.Trainer(
        max_epochs=args.n_epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=args.enable_progress_bar,
        num_sanity_val_steps=args.num_sanity_val_steps,
        accelerator='gpu', devices=args.gpus,
        strategy=('ddp' if args.gpus > 1 else None)
    )

    trainer.fit(model=pl_module, ckpt_path=ckpt_path)