src/run.py
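"""Training/evaluation entry point driven by an OmegaConf YAML config.

Builds a datamodule and LightningModule via MLRL.builder and runs
pytorch-lightning training and/or testing, optionally over several seeded
splits. Example invocations (illustrative only; the config path and option
values below are placeholders, not files shipped with this listing):

    python src/run.py -c configs/base_config.yaml --train
    python src/run.py -c configs/base_config.yaml --test
    python src/run.py -c configs/base_config.yaml --train --splits 3 --train_pct 0.5
"""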
|
|
import argparse
import datetime
import os

import numpy as np
import torch
from dateutil import tz
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
)

# assumed import for the MLRL.builder calls in main(); adjust to the
# project's actual package layout
import MLRL

# cudnn.benchmark=True lets cuDNN pick the fastest algorithms, which can break
# bit-exact reproducibility even though deterministic=True is requested here
# and on the Trainer
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
|
|
|
def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        metavar="base_config.yaml",
        help="path to base config",
        required=True,
    )
    parser.add_argument(
        "--train", action="store_true", default=False, help="specify to train model"
    )
    parser.add_argument(
        "--test",
        action="store_true",
        default=False,
        help="specify to test model. "
        "By default run.py trains a model based on the config file",
    )
    parser.add_argument(
        "--ckpt_path", type=str, default=None, help="Checkpoint path for the saved model"
    )
    parser.add_argument("--random_seed", type=int, default=23, help="Random seed")
    parser.add_argument(
        "--train_pct", type=float, default=1.0, help="Fraction of training data to use"
    )
    parser.add_argument(
        "--splits",
        type=int,
        default=1,
        help="Number of independent training splits to train on. Defaults to 1",
    )
    parser = Trainer.add_argparse_args(parser)

    return parser
|
|
|
def main(cfg, args):

    # get datamodule
    dm = MLRL.builder.build_data_module(cfg)
    # define lightning module
    model = MLRL.builder.build_lightning_model(cfg, dm)

    # callbacks
    callbacks = [LearningRateMonitor(logging_interval="step")]
    if "checkpoint_callback" in cfg.lightning:
        checkpoint_callback = ModelCheckpoint(**cfg.lightning.checkpoint_callback)
        callbacks.append(checkpoint_callback)
    if "early_stopping_callback" in cfg.lightning:
        early_stopping_callback = EarlyStopping(**cfg.lightning.early_stopping_callback)
        callbacks.append(early_stopping_callback)
    if cfg.train.scheduler is not None:
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    # logging
    if "logger" in cfg.lightning:
        logger_type = cfg.lightning.logger.pop("logger_type")
        logger_class = getattr(pl_loggers, logger_type)
        cfg.lightning.logger.name = f"{cfg.experiment_name}_{cfg.extension}"
        logger = logger_class(**cfg.lightning.logger)
        cfg.lightning.logger.logger_type = logger_type
    else:
        logger = None

    # set up pytorch-lightning trainer
    cfg.lightning.trainer.val_check_interval = args.val_check_interval
    cfg.lightning.trainer.auto_lr_find = args.auto_lr_find
    trainer_args = argparse.Namespace(**cfg.lightning.trainer)
    if cfg.lightning.ckpt != "None":
        trainer = Trainer.from_argparse_args(
            args=trainer_args,
            deterministic=True,
            callbacks=callbacks,
            logger=logger,
            resume_from_checkpoint=cfg.lightning.ckpt,
        )
    else:
        trainer = Trainer.from_argparse_args(
            args=trainer_args, deterministic=True, callbacks=callbacks, logger=logger
        )

    # learning rate finder
    if trainer_args.auto_lr_find is not False:
        lr_finder = trainer.tuner.lr_find(model, datamodule=dm)
        new_lr = lr_finder.suggestion()
        model.lr = new_lr
        print("=" * 80 + f"\nLearning rate updated to {new_lr}\n" + "=" * 80)

    if args.train:
        trainer.fit(model, dm)
    if args.test:
        trainer.test(model=model, datamodule=dm)

    # save top weights paths to yaml
    if "checkpoint_callback" in cfg.lightning:
        ckpt_paths = os.path.join(
            cfg.lightning.checkpoint_callback.dirpath, "best_ckpts.yaml"
        )
        checkpoint_callback.to_yaml(filepath=ckpt_paths)
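
# A minimal sketch of the config structure this script reads. Field names are
# taken from the accesses above; the example values and callback/logger options
# are illustrative placeholders, not the project's actual defaults.
#
#   experiment_name: baseline
#   data:
#     frac: 1.0
#   train:
#     scheduler: null
#   lightning:
#     ckpt: "None"
#     trainer:
#       max_epochs: 10
#     logger:
#       logger_type: TensorBoardLogger
#       save_dir: ./logs
#     checkpoint_callback:
#       dirpath: ./ckpts
#       monitor: val_loss
#     early_stopping_callback:
#       monitor: val_loss
#       patience: 5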
|
|
|
if __name__ == "__main__":

    # parse arguments
    parser = get_parser()
    args = parser.parse_args()
    cfg = OmegaConf.load(args.config)

    # edit experiment name
    cfg.data.frac = args.train_pct
    # if cfg.trial_name is not None:
    #     cfg.experiment_name = f"{cfg.experiment_name}_{cfg.trial_name}"
    if args.splits is not None:
        # indicate the fraction of data used in the experiment name
        cfg.experiment_name = f"{cfg.experiment_name}_{args.train_pct}"

    # keep the base checkpoint directory so each split writes to its own subdirectory
    base_ckpt_dirpath = cfg.lightning.checkpoint_callback.dirpath

    # loop over the number of independent training splits, defaults to 1 split
    for split in np.arange(args.splits):

        # get current time
        now = datetime.datetime.now(tz.tzlocal())
        timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")

        # random seed
        args.random_seed = int(split) + 1
        seed_everything(args.random_seed)

        # set directory names
        cfg.extension = str(args.random_seed) if args.splits != 1 else timestamp
        cfg.output_dir = cfg.lightning.logger.save_dir
        cfg.lightning.checkpoint_callback.dirpath = os.path.join(
            base_ckpt_dirpath,
            f"{cfg.experiment_name}/{cfg.extension}",
        )

        # create directories
        if not os.path.exists(cfg.lightning.logger.save_dir):
            os.makedirs(cfg.lightning.logger.save_dir)
        if not os.path.exists(cfg.lightning.checkpoint_callback.dirpath):
            os.makedirs(cfg.lightning.checkpoint_callback.dirpath)
        if not os.path.exists(cfg.output_dir):
            os.makedirs(cfg.output_dir)

        # save config
        config_path = os.path.join(cfg.output_dir, "config.yaml")
        with open(config_path, "w") as fp:
            OmegaConf.save(config=cfg, f=fp.name)

        main(cfg, args)