--- a +++ b/main.py @@ -0,0 +1,142 @@ +""" +Main training script. +""" + +import argparse + +from src.utils import training +from src.config import config + + +def main(args): + """ + Runs appropriate experiment(s) from passed args. + """ + if args.models == "all": + args.models = ["MLP", "CNN", "LSTM", "Transformer"] + + if args.strategies == "all": + args.strategies = [ + "Naive", + "Cumulative", + "EWC", + "OnlineEWC", + "SI", + "LwF", + "Replay", + "GEM", + "AGEM", + ] + + # Hyperparam optimisation over validation data for first 2 tasks + if args.validate: + training.main( + data=args.data, + domain=args.domain_shift, + outcome=args.outcome, + models=args.models, + strategies=args.strategies, + dropout=args.dropout, + config_generic=config.config_generic, + config_model=config.config_model, + config_cl=config.config_cl, + num_samples=args.num_samples, + validate=True, + ) + + # Train and test over all tasks (using optimised hyperparams) + if args.train: + training.main( + data=args.data, + domain=args.domain_shift, + outcome=args.outcome, + models=args.models, + strategies=args.strategies, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--data", + type=str, + default="mimic3", + choices=["mimic3", "eicu", "random"], + help="Dataset to use.", + ) + + parser.add_argument( + "--outcome", + type=str, + default="mortality_48h", + choices=["ARF_4h", "ARF_12h", "Shock_4h", "Shock_12h", "mortality_48h"], + help="Outcome to predict.", + ) + + parser.add_argument( + "--domain_shift", + type=str, + default="age", + choices=[ + "time_season", + "region", + "hospital", + "ward", + "age", + "sex", + "ethnicity", + "ethnicity_coarse", + ], + help="Domain shift exhibited in tasks.", + ) + + parser.add_argument( + "--strategies", + type=str, + default="all", + choices=[ + "Naive", + "Cumulative", + "Joint", + "EWC", + "OnlineEWC", + "SI", + "LwF", + "Replay", + "GDumb", + "GEM", + "AGEM", + ], + nargs="+", + help="Continual learning strategy(s) to evaluate.", + ) + + parser.add_argument( + "--models", + type=str, + default="all", + choices=["MLP", "CNN", "RNN", "LSTM", "GRU", "Transformer"], + nargs="+", + help="Model(s) to evaluate.", + ) + + parser.add_argument( + "--dropout", action="store_true", help="Add dropout to model(s)." + ) + + parser.add_argument("--validate", action="store_true", help="Tune hyperparameters.") + + parser.add_argument( + "--train", action="store_true", help="Train and test validated models." + ) + + parser.add_argument( + "--num_samples", + type=int, + default=1, + help="Number of samples to draw during hyperparameter search.", + ) + + args = parser.parse_args() + main(args)