import argparse
from argparse import Namespace
from pathlib import Path
import warnings
import torch
import pytorch_lightning as pl
import yaml
import numpy as np
from lightning_modules import LigandPocketDDPM


def merge_args_and_yaml(args, config_dict):
    """Merge a YAML config dict into an argparse namespace.

    Values from the config file take precedence over command-line
    arguments; nested dicts become Namespace objects for attribute access.
    """
    arg_dict = args.__dict__
    for key, value in config_dict.items():
        if key in arg_dict:
            warnings.warn(f"Command line argument '{key}' (value: "
                          f"{arg_dict[key]}) will be overwritten with value "
                          f"{value} provided in the config file.")
        if isinstance(value, dict):
            arg_dict[key] = Namespace(**value)
        else:
            arg_dict[key] = value

    return args
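
# Example (illustrative values): if args contains lr=1e-3 and the config file
# provides {'lr': 1e-4, 'egnn_params': {'n_layers': 6}}, the call warns that
# 'lr' is overwritten and returns args with args.lr == 1e-4 and
# args.egnn_params == Namespace(n_layers=6).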


def merge_configs(config, resume_config):
    """Merge the hyperparameters stored in a checkpoint into the YAML config.

    Checkpoint values take precedence, so a resumed run keeps the exact
    configuration it was originally trained with.
    """
    for key, value in resume_config.items():
        if isinstance(value, Namespace):
            value = value.__dict__
        if key in config and config[key] != value:
            warnings.warn(f"Config parameter '{key}' (value: "
                          f"{config[key]}) will be overwritten with value "
                          f"{value} from the checkpoint.")
        config[key] = value
    return config
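
# Example (illustrative values): if the current config sets batch_size=32 but
# the checkpoint was trained with batch_size=16, merge_configs warns and the
# resumed run proceeds with batch_size=16.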


# ------------------------------------------------------------------------------
# Training
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument('--config', type=str, required=True)
    p.add_argument('--resume', type=str, default=None)
    args = p.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
    # 'resume' is reserved for the command line so it cannot be clobbered by
    # merge_args_and_yaml below.
    assert 'resume' not in config
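
    # Minimal config sketch (illustrative; keys inferred from the usage below,
    # values are placeholders):
    #
    #   run_name: my_run
    #   logdir: logs
    #   datadir: data
    #   batch_size: 16
    #   lr: 1.0e-4
    #   egnn_params: {...}
    #   diffusion_params: {...}
    #   wandb_params: {group: ..., entity: ..., mode: online}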
    # If resuming, load the hyperparameters stored in the checkpoint and merge
    # them into the main config before merging everything into args.
    ckpt_path = None if args.resume is None else Path(args.resume)
    if args.resume is not None:
        resume_config = torch.load(
            ckpt_path, map_location=torch.device('cpu'))['hyper_parameters']
        config = merge_configs(config, resume_config)

    args = merge_args_and_yaml(args, config)
    out_dir = Path(args.logdir, args.run_name)
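
    # size_distribution.npy is expected to hold a histogram of ligand sizes
    # over the training set; it is passed to the model as node_histogram,
    # presumably for sampling the number of nodes per generated ligand.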
    histogram_file = Path(args.datadir, 'size_distribution.npy')
    histogram = np.load(histogram_file).tolist()

    pl_module = LigandPocketDDPM(
        outdir=out_dir,
        dataset=args.dataset,
        datadir=args.datadir,
        batch_size=args.batch_size,
        lr=args.lr,
        egnn_params=args.egnn_params,
        diffusion_params=args.diffusion_params,
        num_workers=args.num_workers,
        augment_noise=args.augment_noise,
        augment_rotation=args.augment_rotation,
        clip_grad=args.clip_grad,
        eval_epochs=args.eval_epochs,
        eval_params=args.eval_params,
        visualize_sample_epoch=args.visualize_sample_epoch,
        visualize_chain_epoch=args.visualize_chain_epoch,
        auxiliary_loss=args.auxiliary_loss,
        loss_params=args.loss_params,
        mode=args.mode,
        node_histogram=histogram,
        pocket_representation=args.pocket_representation,
        virtual_nodes=args.virtual_nodes
    )

    logger = pl.loggers.WandbLogger(
        save_dir=args.logdir,
        project='ligand-pocket-ddpm',
        group=args.wandb_params.group,
        name=args.run_name,
        id=args.run_name,
        resume='must' if args.resume is not None else False,
        entity=args.wandb_params.entity,
        mode=args.wandb_params.mode,
    )
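    # With resume='must', wandb refuses to start unless a run with this id
    # already exists, so a resumed job keeps logging into the original run.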

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=Path(out_dir, 'checkpoints'),
        filename="best-model-epoch={epoch:02d}",
        monitor="loss/val",
        save_top_k=1,
        save_last=True,
        mode="min",
    )
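    # save_last=True additionally writes 'last.ckpt', a convenient target for
    # the --resume flag.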

    trainer = pl.Trainer(
        max_epochs=args.n_epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=args.enable_progress_bar,
        num_sanity_val_steps=args.num_sanity_val_steps,
        accelerator='gpu',
        devices=args.gpus,
        strategy=('ddp' if args.gpus > 1 else None),
    )

    trainer.fit(model=pl_module, ckpt_path=ckpt_path)
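
# Example invocation (file name and paths are placeholders):
#   python train.py --config configs/my_config.yml
#   python train.py --config configs/my_config.yml \
#       --resume logs/my_run/checkpoints/last.ckpt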