--- a/src/run.py
+++ b/src/run.py
@@ -0,0 +1,164 @@
+import argparse
+import datetime
+import os
+
+import numpy as np
+import torch
+from dateutil import tz
+from omegaconf import OmegaConf
+from pytorch_lightning import seed_everything
+from pytorch_lightning import loggers as pl_loggers
+from pytorch_lightning.trainer import Trainer
+from pytorch_lightning.callbacks import (
+    ModelCheckpoint,
+    EarlyStopping,
+    LearningRateMonitor,
+)
+
+import MLRL  # project package providing the datamodule / model builders
+
+# cuDNN benchmarking is disabled so runs stay reproducible together with the
+# deterministic flags and the per-split seeding below.
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c",
+        "--config",
+        metavar="base_config.yaml",
+        help="path to base config",
+        required=True,
+    )
+    parser.add_argument(
+        "--train", action="store_true", default=False, help="specify to train model"
+    )
+    parser.add_argument(
+        "--test",
+        action="store_true",
+        default=False,
+        help="specify to test model. "
+        "By default run.py trains a model based on the config file",
+    )
+    parser.add_argument(
+        "--ckpt_path", type=str, default=None, help="Checkpoint path of the saved model"
+    )
+    parser.add_argument("--random_seed", type=int, default=23, help="Random seed")
+    parser.add_argument(
+        "--train_pct", type=float, default=1.0, help="Fraction of training data to use"
+    )
+    parser.add_argument(
+        "--splits",
+        type=int,
+        default=1,
+        help="Number of independent training splits to run. Defaults to 1",
+    )
+    parser = Trainer.add_argparse_args(parser)
+
+    return parser
+
+
+def main(cfg, args):
+
+    # build the datamodule and the lightning module from the config
+    dm = MLRL.builder.build_data_module(cfg)
+    model = MLRL.builder.build_lightning_model(cfg, dm)
+
+    # callbacks: the learning rate is always monitored; checkpointing and
+    # early stopping are registered only when configured
+    callbacks = [LearningRateMonitor(logging_interval="step")]
+    if "checkpoint_callback" in cfg.lightning:
+        checkpoint_callback = ModelCheckpoint(**cfg.lightning.checkpoint_callback)
+        callbacks.append(checkpoint_callback)
+    if "early_stopping_callback" in cfg.lightning:
+        early_stopping_callback = EarlyStopping(**cfg.lightning.early_stopping_callback)
+        callbacks.append(early_stopping_callback)
+
+    # logging
+    if "logger" in cfg.lightning:
+        logger_type = cfg.lightning.logger.pop("logger_type")
+        logger_class = getattr(pl_loggers, logger_type)
+        cfg.lightning.logger.name = f"{cfg.experiment_name}_{cfg.extension}"
+        logger = logger_class(**cfg.lightning.logger)
+        cfg.lightning.logger.logger_type = logger_type
+    else:
+        logger = None
+
+    # set up the pytorch-lightning trainer
+    cfg.lightning.trainer.val_check_interval = args.val_check_interval
+    cfg.lightning.trainer.auto_lr_find = args.auto_lr_find
+    trainer_args = argparse.Namespace(**cfg.lightning.trainer)
+    if cfg.lightning.get("ckpt") not in (None, "None"):
+        trainer = Trainer.from_argparse_args(
+            args=trainer_args,
+            deterministic=True,
+            callbacks=callbacks,
+            logger=logger,
+            resume_from_checkpoint=cfg.lightning.ckpt,
+        )
+    else:
+        trainer = Trainer.from_argparse_args(
+            args=trainer_args, deterministic=True, callbacks=callbacks, logger=logger
+        )
+
+    # learning rate finder
+    if trainer_args.auto_lr_find is not False:
+        lr_finder = trainer.tuner.lr_find(model, datamodule=dm)
+        new_lr = lr_finder.suggestion()
+        model.lr = new_lr
+        print("=" * 80 + f"\nLearning rate updated to {new_lr}\n" + "=" * 80)
+
+    if args.train:
+        trainer.fit(model, dm)
+    if args.test:
+        trainer.test(model=model, datamodule=dm)
+
+    # save the paths of the top checkpoints to yaml
+    if "checkpoint_callback" in cfg.lightning:
+        ckpt_paths = os.path.join(
cfg.lightning.checkpoint_callback.dirpath, "best_ckpts.yaml" + ) + checkpoint_callback.to_yaml(filepath=ckpt_paths) + + +if __name__ == "__main__": + + # parse arguments + parser = get_parser() + args = parser.parse_args() + cfg = OmegaConf.load(args.config) + + # edit experiment name + cfg.data.frac = args.train_pct + # if cfg.trial_name is not None: + # cfg.experiment_name = f"{cfg.experiment_name}_{cfg.trial_name}" + if args.splits is not None: + cfg.experiment_name = f"{cfg.experiment_name}_{args.train_pct}" # indicate % data used in trial name + + # loop over the number of independent training splits, defaults to 1 split + for split in np.arange(args.splits): + + # get current time + now = datetime.datetime.now(tz.tzlocal()) + timestamp = now.strftime("%Y_%m_%d_%H_%M_%S") + + # random seed + args.random_seed = split + 1 + seed_everything(args.random_seed) + + # set directory names + cfg.extension = str(args.random_seed) if args.splits != 1 else timestamp + cfg.output_dir = cfg.lightning.logger.save_dir + cfg.lightning.checkpoint_callback.dirpath = os.path.join( + cfg.lightning.checkpoint_callback.dirpath, + f"{cfg.experiment_name}/{cfg.extension}", + ) + + # create directories + if not os.path.exists(cfg.lightning.logger.save_dir): + os.makedirs(cfg.lightning.logger.save_dir) + if not os.path.exists(cfg.lightning.checkpoint_callback.dirpath): + os.makedirs(cfg.lightning.checkpoint_callback.dirpath) + if not os.path.exists(cfg.output_dir): + os.makedirs(cfg.output_dir) + + # save config + config_path = os.path.join(cfg.output_dir, "config.yaml") + with open(config_path, "w") as fp: + OmegaConf.save(config=cfg, f=fp.name) + + main(cfg, args)