"""
Main training script.
"""
import argparse
from src.utils import training
from src.config import config
def main(args):
"""
Runs appropriate experiment(s) from passed args.
"""
if args.models == "all":
args.models = ["MLP", "CNN", "LSTM", "Transformer"]
if args.strategies == "all":
args.strategies = [
"Naive",
"Cumulative",
"EWC",
"OnlineEWC",
"SI",
"LwF",
"Replay",
"GEM",
"AGEM",
]
# Hyperparam optimisation over validation data for first 2 tasks
if args.validate:
training.main(
data=args.data,
domain=args.domain_shift,
outcome=args.outcome,
models=args.models,
strategies=args.strategies,
dropout=args.dropout,
config_generic=config.config_generic,
config_model=config.config_model,
config_cl=config.config_cl,
num_samples=args.num_samples,
validate=True,
)
# Train and test over all tasks (using optimised hyperparams)
if args.train:
training.main(
data=args.data,
domain=args.domain_shift,
outcome=args.outcome,
models=args.models,
strategies=args.strategies,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data",
type=str,
default="mimic3",
choices=["mimic3", "eicu", "random"],
help="Dataset to use.",
)
parser.add_argument(
"--outcome",
type=str,
default="mortality_48h",
choices=["ARF_4h", "ARF_12h", "Shock_4h", "Shock_12h", "mortality_48h"],
help="Outcome to predict.",
)
parser.add_argument(
"--domain_shift",
type=str,
default="age",
choices=[
"time_season",
"region",
"hospital",
"ward",
"age",
"sex",
"ethnicity",
"ethnicity_coarse",
],
help="Domain shift exhibited in tasks.",
)
parser.add_argument(
"--strategies",
type=str,
default="all",
choices=[
"Naive",
"Cumulative",
"Joint",
"EWC",
"OnlineEWC",
"SI",
"LwF",
"Replay",
"GDumb",
"GEM",
"AGEM",
],
nargs="+",
help="Continual learning strategy(s) to evaluate.",
)
parser.add_argument(
"--models",
type=str,
default="all",
choices=["MLP", "CNN", "RNN", "LSTM", "GRU", "Transformer"],
nargs="+",
help="Model(s) to evaluate.",
)
parser.add_argument(
"--dropout", action="store_true", help="Add dropout to model(s)."
)
parser.add_argument("--validate", action="store_true", help="Tune hyperparameters.")
parser.add_argument(
"--train", action="store_true", help="Train and test validated models."
)
parser.add_argument(
"--num_samples",
type=int,
default=1,
help="Number of samples to draw during hyperparameter search.",
)
args = parser.parse_args()
main(args)