Diff of /src/run.py [000000] .. [10ee32]

b/src/run.py
import argparse
import torch
import datetime
import os
import numpy as np
from dateutil import tz
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
)

import MLRL.builder  # provides build_data_module / build_lightning_model used in main()

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        metavar="base_config.yaml",
        help="path to base config",
        required=True,
    )
    parser.add_argument(
        "--train", action="store_true", default=False, help="specify to train model"
    )
    parser.add_argument(
        "--test",
        action="store_true",
        default=False,
        help="specify to test model. "
        "By default run.py trains a model based on the config file",
    )
    parser.add_argument(
        "--ckpt_path", type=str, default=None, help="Checkpoint path for the saved model"
    )
    parser.add_argument("--random_seed", type=int, default=23, help="Random seed")
    parser.add_argument(
        "--train_pct", type=float, default=1.0, help="Fraction of training data to use"
    )
    parser.add_argument(
        "--splits",
        type=int,
        default=1,
        help="Number of independent training splits to run. Defaults to 1",
    )
    parser = Trainer.add_argparse_args(parser)

    return parser

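# Example invocations (config path and flag values are illustrative assumptions, not
# taken from the repository):
#   python src/run.py -c configs/base_config.yaml --train
#   python src/run.py -c configs/base_config.yaml --train --train_pct 0.1 --splits 3
#   python src/run.py -c configs/base_config.yaml --test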


def main(cfg, args):

    # get datamodule
    dm = MLRL.builder.build_data_module(cfg)
    # define lightning module
    model = MLRL.builder.build_lightning_model(cfg, dm)

    # callbacks
    callbacks = []
    if "checkpoint_callback" in cfg.lightning:
        checkpoint_callback = ModelCheckpoint(**cfg.lightning.checkpoint_callback)
        callbacks.append(checkpoint_callback)
    if "early_stopping_callback" in cfg.lightning:
        early_stopping_callback = EarlyStopping(**cfg.lightning.early_stopping_callback)
        callbacks.append(early_stopping_callback)
    if cfg.train.scheduler is not None:
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    # logging
    if "logger" in cfg.lightning:
        logger_type = cfg.lightning.logger.pop("logger_type")
        logger_class = getattr(pl_loggers, logger_type)
        cfg.lightning.logger.name = f"{cfg.experiment_name}_{cfg.extension}"
        logger = logger_class(**cfg.lightning.logger)
        cfg.lightning.logger.logger_type = logger_type
    else:
        logger = None

    # setup pytorch-lightning trainer
    cfg.lightning.trainer.val_check_interval = args.val_check_interval
    cfg.lightning.trainer.auto_lr_find = args.auto_lr_find
    trainer_args = argparse.Namespace(**cfg.lightning.trainer)
    if cfg.lightning.ckpt != "None":
        trainer = Trainer.from_argparse_args(
            args=trainer_args,
            deterministic=True,
            callbacks=callbacks,
            logger=logger,
            resume_from_checkpoint=cfg.lightning.ckpt,
        )
    else:
        trainer = Trainer.from_argparse_args(
            args=trainer_args, deterministic=True, callbacks=callbacks, logger=logger
        )

    # learning rate finder
    if trainer_args.auto_lr_find is not False:
        lr_finder = trainer.tuner.lr_find(model, datamodule=dm)
        new_lr = lr_finder.suggestion()
        model.lr = new_lr
        print("=" * 80 + f"\nLearning rate updated to {new_lr}\n" + "=" * 80)

    if args.train:
        trainer.fit(model, dm)
    if args.test:
        trainer.test(model=model, datamodule=dm)

    # save the best checkpoint paths (and their monitored scores) to yaml
    if "checkpoint_callback" in cfg.lightning:
        ckpt_paths = os.path.join(
            cfg.lightning.checkpoint_callback.dirpath, "best_ckpts.yaml"
        )
        checkpoint_callback.to_yaml(filepath=ckpt_paths)

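# Expected structure of the YAML config (a sketch inferred from the keys this script
# reads; the concrete values are illustrative assumptions, not taken from the repository):
#
#   experiment_name: my_experiment
#   data:
#     frac: 1.0
#   train:
#     scheduler: cosine                # or null when no LR scheduler is configured
#   lightning:
#     ckpt: "None"                     # or a checkpoint path to resume from
#     trainer:
#       max_epochs: 100
#     logger:
#       logger_type: TensorBoardLogger
#       save_dir: ./logs
#     checkpoint_callback:
#       dirpath: ./checkpoints
#       monitor: val_loss
#       save_top_k: 3
#     early_stopping_callback:
#       monitor: val_loss
#       patience: 5
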
if __name__ == "__main__":

    # parse arguments
    parser = get_parser()
    args = parser.parse_args()
    cfg = OmegaConf.load(args.config)

    # override config from CLI arguments and edit experiment name
    cfg.data.frac = args.train_pct
    # if cfg.trial_name is not None:
    #     cfg.experiment_name = f"{cfg.experiment_name}_{cfg.trial_name}"
    if args.splits is not None:
        cfg.experiment_name = f"{cfg.experiment_name}_{args.train_pct}"  # indicate % data used in trial name

    # base checkpoint directory; per-split subdirectories are appended to it below
    base_ckpt_dirpath = cfg.lightning.checkpoint_callback.dirpath

    # loop over the number of independent training splits, defaults to 1 split
    for split in np.arange(args.splits):

        # get current time
        now = datetime.datetime.now(tz.tzlocal())
        timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")

        # random seed
        args.random_seed = int(split) + 1
        seed_everything(args.random_seed)

        # set directory names
        cfg.extension = str(args.random_seed) if args.splits != 1 else timestamp
        cfg.output_dir = cfg.lightning.logger.save_dir
        cfg.lightning.checkpoint_callback.dirpath = os.path.join(
            base_ckpt_dirpath,
            f"{cfg.experiment_name}/{cfg.extension}",
        )
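
        # Resulting output layout per split (a sketch; the actual roots come from the
        # config and are assumptions here):
        #   <logger.save_dir>/config.yaml
        #   <checkpoint dirpath>/<experiment_name>_<train_pct>/<seed or timestamp>/*.ckpt
        #   <checkpoint dirpath>/<experiment_name>_<train_pct>/<seed or timestamp>/best_ckpts.yaml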

        # create directories
        if not os.path.exists(cfg.lightning.logger.save_dir):
            os.makedirs(cfg.lightning.logger.save_dir)
        if not os.path.exists(cfg.lightning.checkpoint_callback.dirpath):
            os.makedirs(cfg.lightning.checkpoint_callback.dirpath)
        if not os.path.exists(cfg.output_dir):
            os.makedirs(cfg.output_dir)

        # save config
        config_path = os.path.join(cfg.output_dir, "config.yaml")
        with open(config_path, "w") as fp:
            OmegaConf.save(config=cfg, f=fp)

        main(cfg, args)