Diff of /main.py [000000] .. [66326d]

Switch to unified view

a b/main.py
1
"""
2
Main training script.
3
"""
4
5
import argparse
6
7
from src.utils import training
8
from src.config import config
9
10
11
def main(args):
12
    """
13
    Runs appropriate experiment(s) from passed args.
14
    """
15
    if args.models == "all":
16
        args.models = ["MLP", "CNN", "LSTM", "Transformer"]
17
18
    if args.strategies == "all":
19
        args.strategies = [
20
            "Naive",
21
            "Cumulative",
22
            "EWC",
23
            "OnlineEWC",
24
            "SI",
25
            "LwF",
26
            "Replay",
27
            "GEM",
28
            "AGEM",
29
        ]
30
31
    # Hyperparam optimisation over validation data for first 2 tasks
32
    if args.validate:
33
        training.main(
34
            data=args.data,
35
            domain=args.domain_shift,
36
            outcome=args.outcome,
37
            models=args.models,
38
            strategies=args.strategies,
39
            dropout=args.dropout,
40
            config_generic=config.config_generic,
41
            config_model=config.config_model,
42
            config_cl=config.config_cl,
43
            num_samples=args.num_samples,
44
            validate=True,
45
        )
46
47
    # Train and test over all tasks (using optimised hyperparams)
48
    if args.train:
49
        training.main(
50
            data=args.data,
51
            domain=args.domain_shift,
52
            outcome=args.outcome,
53
            models=args.models,
54
            strategies=args.strategies,
55
        )
56
57
58
if __name__ == "__main__":
59
    parser = argparse.ArgumentParser()
60
61
    parser.add_argument(
62
        "--data",
63
        type=str,
64
        default="mimic3",
65
        choices=["mimic3", "eicu", "random"],
66
        help="Dataset to use.",
67
    )
68
69
    parser.add_argument(
70
        "--outcome",
71
        type=str,
72
        default="mortality_48h",
73
        choices=["ARF_4h", "ARF_12h", "Shock_4h", "Shock_12h", "mortality_48h"],
74
        help="Outcome to predict.",
75
    )
76
77
    parser.add_argument(
78
        "--domain_shift",
79
        type=str,
80
        default="age",
81
        choices=[
82
            "time_season",
83
            "region",
84
            "hospital",
85
            "ward",
86
            "age",
87
            "sex",
88
            "ethnicity",
89
            "ethnicity_coarse",
90
        ],
91
        help="Domain shift exhibited in tasks.",
92
    )
93
94
    parser.add_argument(
95
        "--strategies",
96
        type=str,
97
        default="all",
98
        choices=[
99
            "Naive",
100
            "Cumulative",
101
            "Joint",
102
            "EWC",
103
            "OnlineEWC",
104
            "SI",
105
            "LwF",
106
            "Replay",
107
            "GDumb",
108
            "GEM",
109
            "AGEM",
110
        ],
111
        nargs="+",
112
        help="Continual learning strategy(s) to evaluate.",
113
    )
114
115
    parser.add_argument(
116
        "--models",
117
        type=str,
118
        default="all",
119
        choices=["MLP", "CNN", "RNN", "LSTM", "GRU", "Transformer"],
120
        nargs="+",
121
        help="Model(s) to evaluate.",
122
    )
123
124
    parser.add_argument(
125
        "--dropout", action="store_true", help="Add dropout to model(s)."
126
    )
127
128
    parser.add_argument("--validate", action="store_true", help="Tune hyperparameters.")
129
130
    parser.add_argument(
131
        "--train", action="store_true", help="Train and test validated models."
132
    )
133
134
    parser.add_argument(
135
        "--num_samples",
136
        type=int,
137
        default=1,
138
        help="Number of samples to draw during hyperparameter search.",
139
    )
140
141
    args = parser.parse_args()
142
    main(args)