Diff of /train.py [000000] .. [607087]

b/train.py
import argparse
from argparse import Namespace
from pathlib import Path
import warnings

import torch
import pytorch_lightning as pl
import yaml
import numpy as np

from lightning_modules import LigandPocketDDPM


def merge_args_and_yaml(args, config_dict):
    """Merge a YAML config dict into the argparse namespace.

    Config-file values take precedence over command line arguments (a warning
    is emitted on overwrite); nested dictionaries are wrapped in
    argparse.Namespace objects for attribute-style access.
    """
    arg_dict = args.__dict__
    for key, value in config_dict.items():
        if key in arg_dict:
            warnings.warn(f"Command line argument '{key}' (value: "
                          f"{arg_dict[key]}) will be overwritten with value "
                          f"{value} provided in the config file.")
        if isinstance(value, dict):
            arg_dict[key] = Namespace(**value)
        else:
            arg_dict[key] = value

    return args
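

# Illustrative note (the 'n_layers' value below is assumed, not defined in this
# file): a YAML entry such as
#     egnn_params:
#         n_layers: 6
# becomes args.egnn_params, a Namespace with args.egnn_params.n_layers == 6,
# once merge_args_and_yaml is applied in the main block below.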


def merge_configs(config, resume_config):
    """Overwrite config entries with the values stored in a checkpoint,
    warning whenever a value changes."""
    for key, value in resume_config.items():
        if isinstance(value, Namespace):
            value = value.__dict__
        if key in config and config[key] != value:
            warnings.warn(f"Config parameter '{key}' (value: "
                          f"{config[key]}) will be overwritten with value "
                          f"{value} from the checkpoint.")
        config[key] = value
    return config


# ------------------------------------------------------------------------------
# Training
# ______________________________________________________________________________
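# Example invocations (file and directory names are illustrative, not taken
# from this repository):
#   python train.py --config my_config.yml
#   python train.py --config my_config.yml --resume logs/my_run/checkpoints/last.ckpt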
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument('--config', type=str, required=True)
    p.add_argument('--resume', type=str, default=None)
    args = p.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
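    # The YAML file is expected to define the keys read further below, e.g.
    # run_name, logdir, datadir, dataset, batch_size, lr, egnn_params,
    # diffusion_params, wandb_params, n_epochs, gpus, and the remaining
    # LigandPocketDDPM arguments (values are experiment-specific).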

    assert 'resume' not in config

    # If resuming, merge the hyperparameters saved in the checkpoint
    ckpt_path = None if args.resume is None else Path(args.resume)
    if args.resume is not None:
        resume_config = torch.load(
            ckpt_path, map_location=torch.device('cpu'))['hyper_parameters']

        config = merge_configs(config, resume_config)

    args = merge_args_and_yaml(args, config)

    out_dir = Path(args.logdir, args.run_name)
    histogram_file = Path(args.datadir, 'size_distribution.npy')
    histogram = np.load(histogram_file).tolist()
    pl_module = LigandPocketDDPM(
        outdir=out_dir,
        dataset=args.dataset,
        datadir=args.datadir,
        batch_size=args.batch_size,
        lr=args.lr,
        egnn_params=args.egnn_params,
        diffusion_params=args.diffusion_params,
        num_workers=args.num_workers,
        augment_noise=args.augment_noise,
        augment_rotation=args.augment_rotation,
        clip_grad=args.clip_grad,
        eval_epochs=args.eval_epochs,
        eval_params=args.eval_params,
        visualize_sample_epoch=args.visualize_sample_epoch,
        visualize_chain_epoch=args.visualize_chain_epoch,
        auxiliary_loss=args.auxiliary_loss,
        loss_params=args.loss_params,
        mode=args.mode,
        node_histogram=histogram,
        pocket_representation=args.pocket_representation,
        virtual_nodes=args.virtual_nodes
    )

    logger = pl.loggers.WandbLogger(
        save_dir=args.logdir,
        project='ligand-pocket-ddpm',
        group=args.wandb_params.group,
        name=args.run_name,
        id=args.run_name,
        resume='must' if args.resume is not None else False,
        entity=args.wandb_params.entity,
        mode=args.wandb_params.mode,
    )
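    # Note: with resume='must', wandb insists that a run with this id already
    # exists; leave --resume unset when starting a fresh run.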

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=Path(out_dir, 'checkpoints'),
        filename="best-model-epoch={epoch:02d}",
        monitor="loss/val",
        save_top_k=1,
        save_last=True,
        mode="min",
    )

    trainer = pl.Trainer(
        max_epochs=args.n_epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=args.enable_progress_bar,
        num_sanity_val_steps=args.num_sanity_val_steps,
        accelerator='gpu', devices=args.gpus,
        strategy=('ddp' if args.gpus > 1 else None)
    )

    trainer.fit(model=pl_module, ckpt_path=ckpt_path)