--- a/train.py
+++ b/train.py
@@ -0,0 +1,122 @@
+import argparse
+from argparse import Namespace
+from pathlib import Path
+import warnings
+
+import torch
+import pytorch_lightning as pl
+import yaml
+import numpy as np
+
+from lightning_modules import LigandPocketDDPM
+
+
+def merge_args_and_yaml(args, config_dict):
+    arg_dict = args.__dict__
+    for key, value in config_dict.items():
+        if key in arg_dict:
+            warnings.warn(f"Command line argument '{key}' (value: "
+                          f"{arg_dict[key]}) will be overwritten with value "
+                          f"{value} provided in the config file.")
+        if isinstance(value, dict):
+            arg_dict[key] = Namespace(**value)
+        else:
+            arg_dict[key] = value
+
+    return args
+
+
+def merge_configs(config, resume_config):
+    for key, value in resume_config.items():
+        if isinstance(value, Namespace):
+            value = value.__dict__
+        if key in config and config[key] != value:
+            warnings.warn(f"Config parameter '{key}' (value: "
+                          f"{config[key]}) will be overwritten with value "
+                          f"{value} from the checkpoint.")
+        config[key] = value
+    return config
+
+
+# ------------------------------------------------------------------------------
+# Training
+# ______________________________________________________________________________
+if __name__ == "__main__":
+    p = argparse.ArgumentParser()
+    p.add_argument('--config', type=str, required=True)
+    p.add_argument('--resume', type=str, default=None)
+    args = p.parse_args()
+
+    with open(args.config, 'r') as f:
+        config = yaml.safe_load(f)
+
+    assert 'resume' not in config
+
+    # Get main config
+    ckpt_path = None if args.resume is None else Path(args.resume)
+    if args.resume is not None:
+        resume_config = torch.load(
+            ckpt_path, map_location=torch.device('cpu'))['hyper_parameters']
+
+        config = merge_configs(config, resume_config)
+
+    args = merge_args_and_yaml(args, config)
+
+    out_dir = Path(args.logdir, args.run_name)
+    histogram_file = Path(args.datadir, 'size_distribution.npy')
+    histogram = np.load(histogram_file).tolist()
+    pl_module = LigandPocketDDPM(
+        outdir=out_dir,
+        dataset=args.dataset,
+        datadir=args.datadir,
+        batch_size=args.batch_size,
+        lr=args.lr,
+        egnn_params=args.egnn_params,
+        diffusion_params=args.diffusion_params,
+        num_workers=args.num_workers,
+        augment_noise=args.augment_noise,
+        augment_rotation=args.augment_rotation,
+        clip_grad=args.clip_grad,
+        eval_epochs=args.eval_epochs,
+        eval_params=args.eval_params,
+        visualize_sample_epoch=args.visualize_sample_epoch,
+        visualize_chain_epoch=args.visualize_chain_epoch,
+        auxiliary_loss=args.auxiliary_loss,
+        loss_params=args.loss_params,
+        mode=args.mode,
+        node_histogram=histogram,
+        pocket_representation=args.pocket_representation,
+        virtual_nodes=args.virtual_nodes
+    )
+
+    logger = pl.loggers.WandbLogger(
+        save_dir=args.logdir,
+        project='ligand-pocket-ddpm',
+        group=args.wandb_params.group,
+        name=args.run_name,
+        id=args.run_name,
+        resume='must' if args.resume is not None else False,
+        entity=args.wandb_params.entity,
+        mode=args.wandb_params.mode,
+    )
+
+    checkpoint_callback = pl.callbacks.ModelCheckpoint(
+        dirpath=Path(out_dir, 'checkpoints'),
+        filename="best-model-epoch={epoch:02d}",
+        monitor="loss/val",
+        save_top_k=1,
+        save_last=True,
+        mode="min",
+    )
+
+    trainer = pl.Trainer(
+        max_epochs=args.n_epochs,
+        logger=logger,
+        callbacks=[checkpoint_callback],
+        enable_progress_bar=args.enable_progress_bar,
+        num_sanity_val_steps=args.num_sanity_val_steps,
+        accelerator='gpu', devices=args.gpus,
+        strategy=('ddp' if args.gpus > 1 else None)
+    )
+
+    trainer.fit(model=pl_module, ckpt_path=ckpt_path)
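
Note (not part of the patch): the script is invoked as `python train.py --config <config.yml>`, optionally with `--resume <checkpoint.ckpt>` to continue from a saved run. Below is a minimal sketch of how merge_args_and_yaml resolves config entries, assuming the repository's dependencies are installed so the train.py added above can be imported; the file name and values shown are placeholders, not real configuration.

# Nested config sections (e.g. 'egnn_params') are turned into argparse.Namespace
# objects, so the training code can use attribute access such as
# args.egnn_params.n_layers. All values here are illustrative placeholders.
from argparse import Namespace

from train import merge_args_and_yaml  # file added by this patch

args = Namespace(config='example.yml', resume=None)   # as produced by argparse
config = {
    'run_name': 'example_run',       # placeholder
    'lr': 1.0e-4,                    # placeholder
    'egnn_params': {'n_layers': 6},  # nested dict -> Namespace
}

args = merge_args_and_yaml(args, config)
assert isinstance(args.egnn_params, Namespace)
print(args.run_name, args.lr, args.egnn_params.n_layers)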