Diff of /src/utils/training.py [000000] .. [66326d]


"""
Functions for running hyperparameter sweeps and for training and
evaluating Continual Learning models.
"""

import json
import warnings
from pathlib import Path
from functools import partial

# import random
# import numpy as np

import torch
from ray import tune
from torch import nn, optim

from avalanche.logging import InteractiveLogger, TensorboardLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.plugins.early_stopping import EarlyStoppingPlugin
from avalanche.evaluation.metrics import (
    accuracy_metrics,
    loss_metrics,
    StreamConfusionMatrix,
)

# Local imports
from utils import models, plotting, data_processing, cl_strategies
from utils.metrics import (
    balancedaccuracy_metrics,
    sensitivity_metrics,
    specificity_metrics,
    precision_metrics,
    rocauc_metrics,
    auprc_metrics,
)

# Show the spurious MaxPool1d named-tensors UserWarning only once
# (note: "once" de-duplicates the warning rather than suppressing it)
warnings.filterwarnings("once", category=UserWarning)

# GLOBALS
RESULTS_DIR = Path(__file__).parents[1] / "results"
CONFIG_DIR = Path(__file__).parents[1] / "config"
CUDA = torch.cuda.is_available()
DEVICE = "cuda" if CUDA else "cpu"

# Reproducibility
SEED = 12345
# random.seed(SEED)
# np.random.seed(SEED)
torch.manual_seed(SEED)


def save_params(data, domain, outcome, model, strategy, best_params):
    """Save hyper-parameter config to JSON."""

    file_loc = CONFIG_DIR / data / outcome / domain
    file_loc.mkdir(parents=True, exist_ok=True)

    with open(
        file_loc / f"config_{model}_{strategy}.json", "w", encoding="utf-8"
    ) as json_file:
        json.dump(best_params, json_file)


def load_params(data, domain, outcome, model, strategy):
    """Load hyper-parameter config from JSON."""

    file_loc = CONFIG_DIR / data / outcome / domain

    with open(
        file_loc / f"config_{model}_{strategy}.json", encoding="utf-8"
    ) as json_file:
        best_params = json.load(json_file)
    return best_params
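
# Illustrative round-trip of the config layout these helpers persist. The
# "generic" keys match how `load_strategy` reads the config below; the
# "model"/"strategy" entries here are hypothetical examples only:
#
#     best_params = {
#         "generic": {"optimizer": "Adam", "lr": 1e-3, "train_mb_size": 32},
#         "model": {"hidden_dim": 64, "dropout": 0.2},
#         "strategy": {"ewc_lambda": 1.0},
#     }
#     save_params(data, domain, outcome, "CNN", "Naive", best_params)
#     assert load_params(data, domain, outcome, "CNN", "Naive") == best_params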


def save_results(data, outcome, domain, res):
    """Saves results to .json (excluding tensor confusion matrices)."""
    with open(
        RESULTS_DIR / f"results_{data}_{outcome}_{domain}.json", "w", encoding="utf-8"
    ) as handle:
        res_no_tensors = {
            m: {
                s: [
                    {
                        metric: value
                        for metric, value in run.items()
                        if "Confusion" not in metric
                    }
                    for run in runs
                ]
                for s, runs in strats.items()
            }
            for m, strats in res.items()
        }
        json.dump(res_no_tensors, handle)
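
# For reference: `res`, as built in `main` below, is nested as
# {model_name: {strategy_name: [run_metrics, ...]}}, where each `run_metrics`
# is the metrics dict returned by `train_cl_method`; only the tensor-valued
# "...Confusion..." entries are dropped before serialisation.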


def load_strategy(
    model,
    model_name,
    strategy_name,
    data="",
    domain="",
    n_tasks=0,
    weight=None,
    validate=False,
    config=None,
    benchmark=None,
    early_stopping=False,
):
    """
    Instantiates the given CL strategy, equipped with an evaluation plugin
    and loggers. Metric granularity:

    - `stream`     Avg accuracy over all experiences (may rely on tasks being roughly the same size?)
    - `experience` Accuracy for each experience
    """

    strategy = cl_strategies.STRATEGIES[strategy_name]
    criterion = nn.CrossEntropyLoss(weight=weight)

    if config["generic"]["optimizer"] == "SGD":
        optimizer = optim.SGD(
            model.parameters(), lr=config["generic"]["lr"], momentum=0.9
        )
    elif config["generic"]["optimizer"] == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=config["generic"]["lr"])
    else:
        raise ValueError(f"Unknown optimizer: {config['generic']['optimizer']}")

    if validate:
        loggers = []
    else:
        timestamp = plotting.get_timestamp()
        log_dir = (
            RESULTS_DIR
            / "log"
            / "tensorboard"
            / f"{data}_{domain}_{timestamp}"
            / model_name
            / strategy_name
        )
        interactive_logger = InteractiveLogger()
        tb_logger = TensorboardLogger(tb_log_dir=log_dir)
        loggers = [interactive_logger, tb_logger]

    eval_plugin = EvaluationPlugin(
        StreamConfusionMatrix(save_image=False),
        loss_metrics(stream=True, experience=not validate),
        accuracy_metrics(trained_experience=True, stream=True, experience=not validate),
        balancedaccuracy_metrics(
            trained_experience=True, stream=True, experience=not validate
        ),
        specificity_metrics(
            trained_experience=True, stream=True, experience=not validate
        ),
        sensitivity_metrics(
            trained_experience=True, stream=True, experience=not validate
        ),
        precision_metrics(
            trained_experience=True, stream=True, experience=not validate
        ),
        # rocauc_metrics(trained_experience=True, stream=True, experience=not validate),
        # auprc_metrics(trained_experience=True, stream=True, experience=not validate),
        loggers=loggers,
        benchmark=benchmark,
    )

    if early_stopping:
        early_stopper = EarlyStoppingPlugin(
            patience=5,
            val_stream_name="train_stream/Task000",
            metric_name="BalancedAccuracy_On_Trained_Experiences",
        )
        plugins = [early_stopper]
    else:
        plugins = None

    # Joint training runs as a single experience, so periodic eval is disabled
    # (alternative considered: 0 if validate or n_tasks > 5 else 1)
    eval_every = None if strategy_name == "Joint" else 0

    cl_strategy = strategy(
        model,
        optimizer=optimizer,
        device=DEVICE,
        criterion=criterion,
        eval_mb_size=1024,
        eval_every=eval_every,
        evaluator=eval_plugin,
        train_epochs=15,
        train_mb_size=config["generic"]["train_mb_size"],
        plugins=plugins,
        **config["strategy"],
    )

    return cl_strategy
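
# The evaluation plugin emits metric keys of the form
# "<Metric>/<eval_phase>/<stream>/Task000", e.g. the keys read back in
# `training_loop` below:
#     "Loss_Stream/eval_phase/test_stream/Task000"
#     "BalancedAccuracy_On_Trained_Experiences/eval_phase/test_stream/Task000"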


def train_cl_method(cl_strategy, scenario, strategy_name, validate=False):
    """
    Avalanche CL training loop. For each 'experience' in the scenario's train_stream:

        - trains the method on the experience
        - evaluates the model on train_stream and test_stream
    """
    if not validate:
        print("Starting experiment...")

    if strategy_name == "Joint":
        if not validate:
            print("Joint training:")
        cl_strategy.train(
            scenario.train_stream,
            eval_streams=[scenario.train_stream, scenario.test_stream],
        )
        if not validate:
            print("Training completed", "\n\n")

    else:
        for experience in scenario.train_stream:
            if not validate:
                print(
                    f"{strategy_name} - Start of experience: {experience.current_experience}"
                )
            cl_strategy.train(
                experience, eval_streams=[scenario.train_stream, scenario.test_stream]
            )
            if not validate:
                print("Training completed", "\n\n")

    if validate:
        return cl_strategy.evaluator.get_last_metrics()
    else:
        return cl_strategy.evaluator.get_all_metrics()
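
# NB: `get_last_metrics()` returns a flat {metric_name: latest_value} dict
# (enough for a single Ray Tune report), whereas `get_all_metrics()` returns
# the full logged history for every metric, which the saving/plotting code
# consumes.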


def training_loop(
    config,
    data,
    domain,
    outcome,
    model_name,
    strategy_name,
    validate=False,
    checkpoint_dir=None,
):
    """
    Training wrapper:
        - loads data
        - instantiates model
        - equips model with CL strategy
        - trains and evaluates method
        - returns results, or reports metrics to Ray Tune if `validate`
    """

    # Loading data into a 'stream' of 'experiences' (tasks)
    if not validate:
        print("Loading data...")
    scenario, n_tasks, n_timesteps, n_channels, weight = data_processing.load_data(
        data, domain, outcome, validate
    )
    if weight is not None:
        weight = weight.to(DEVICE)
    if not validate:
        print("Data loaded.\n")
        print(f"N timesteps: {n_timesteps}\nN features:  {n_channels}")

    model = models.MODELS[model_name](n_channels, n_timesteps, **config["model"])
    cl_strategy = load_strategy(
        model,
        model_name,
        strategy_name,
        data,
        domain,
        n_tasks=n_tasks,
        weight=weight,
        validate=validate,
        config=config,
        benchmark=scenario,
    )
    results = train_cl_method(cl_strategy, scenario, strategy_name, validate=validate)

    if validate:
        loss = results["Loss_Stream/eval_phase/test_stream/Task000"]
        accuracy = results[
            "Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000"
        ]
        balancedaccuracy = results[
            "BalancedAccuracy_On_Trained_Experiences/eval_phase/test_stream/Task000"
        ]
        # sensitivity = results['Sens_Stream/eval_phase/test_stream/Task000']
        # specificity = results['Spec_Stream/eval_phase/test_stream/Task000']
        # precision = results['Prec_Stream/eval_phase/test_stream/Task000']
        # rocauc = results['ROCAUC_Stream/eval_phase/test_stream/Task000']
        # auprc = results['AUPRC_Stream/eval_phase/test_stream/Task000']

        # WARNING: returning a value here would overwrite the Ray Tune report
        tune.report(
            loss=loss,
            accuracy=accuracy,
            balancedaccuracy=balancedaccuracy,
            # auprc=auprc,
            # rocauc=rocauc,
        )

    else:
        return results


def hyperparam_opt(
    config, data, domain, outcome, model_name, strategy_name, num_samples
):
    """
    Hyperparameter optimisation for the given model/strategy.
    Runs over the validation data for the first 2 tasks.
    """

    reporter = tune.CLIReporter(
        metric_columns=[
            "loss",
            "accuracy",
            "balancedaccuracy",
            # "auprc",
            # "rocauc",
        ]
    )
    resources = {"cpu": 4, "gpu": 0.5} if CUDA else {"cpu": 1}

    result = tune.run(
        partial(
            training_loop,
            data=data,
            domain=domain,
            outcome=outcome,
            model_name=model_name,
            strategy_name=strategy_name,
            validate=True,
        ),
        config=config,
        num_samples=num_samples,
        progress_reporter=reporter,
        raise_on_failed_trial=False,
        resources_per_trial=resources,
        name=f"{model_name}_{strategy_name}",
        local_dir=RESULTS_DIR / "log" / "raytune" / f"{data}_{outcome}_{domain}",
        trial_name_creator=lambda t: f"{model_name}_{strategy_name}_{t.trial_id}",
    )

    best_trial = result.get_best_trial("balancedaccuracy", "max", "last")
    print(f"Best trial config:                             {best_trial.config}")
    print(
        f"Best trial final validation loss:              {best_trial.last_result['loss']}"
    )
    print(
        f"Best trial final validation accuracy:          {best_trial.last_result['accuracy']}"
    )
    print(
        f"Best trial final validation balanced accuracy: {best_trial.last_result['balancedaccuracy']}"
    )

    return best_trial.config
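
# A minimal sketch of the search-space `config` this expects, using Ray Tune
# sampling primitives. Only the "generic"/"model"/"strategy" layout is assumed
# by `training_loop`; the particular hyperparameters and ranges shown are
# hypothetical:
#
#     search_space = {
#         "generic": {
#             "optimizer": tune.choice(["SGD", "Adam"]),
#             "lr": tune.loguniform(1e-4, 1e-1),
#             "train_mb_size": tune.choice([16, 32, 64]),
#         },
#         "model": {"dropout": tune.uniform(0.0, 0.5)},
#         "strategy": {},
#     }
#     best_config = hyperparam_opt(
#         search_space, data, domain, outcome, "CNN", "Naive", num_samples=20
#     )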


def main(
    data,
    domain,
    outcome,
    models,
    strategies,
    dropout=False,
    config_generic=None,
    config_model=None,
    config_cl=None,
    validate=False,
    num_samples=50,
    freeze_model_hp=False,
):
    """
    Main training loop. Defines the dataset for the given outcome/domain and
    evaluates the given models/strategies over their hyperparameters on this
    problem.
    """
    # Avoid mutable default arguments
    config_generic = config_generic if config_generic is not None else {}
    config_model = config_model if config_model is not None else {}
    config_cl = config_cl if config_cl is not None else {}

    # Container for metrics results
    res = {m: {s: [] for s in strategies} for m in models}

    for model in models:
        for strategy in strategies:
            # Free cached GPU memory between runs
            torch.cuda.empty_cache()

            if validate:  # Hyperparam opt over first 2 tasks
                # Load generic tuned hyper-params
                if strategy == "Naive" or not freeze_model_hp:
                    config = {
                        "generic": config_generic,
                        "model": config_model[model],
                        "strategy": config_cl.get(strategy, {}),
                    }
                else:
                    naive_params = load_params(data, domain, outcome, model, "Naive")
                    config = {
                        "generic": naive_params["generic"],
                        "model": naive_params["model"],
                        "strategy": config_cl.get(strategy, {}),
                    }

                # JA: Investigate adding dropout to CNN (final FC layers only?)
                if not dropout and model != "CNN":
                    config["model"]["dropout"] = 0

                best_params = hyperparam_opt(
                    config,
                    data,
                    domain,
                    outcome,
                    model,
                    strategy,
                    num_samples=1 if strategy == "Naive" else num_samples,
                )
                save_params(data, domain, outcome, model, strategy, best_params)

            else:  # Training loop over all tasks
                config = load_params(data, domain, outcome, model, strategy)

                # Multiple runs for confidence intervals (single run for now)
                n_repeats = 1
                for _ in range(n_repeats):
                    curr_results = training_loop(
                        config, data, domain, outcome, model, strategy
                    )
                    res[model][strategy].append(curr_results)

    if not validate:
        save_results(data, outcome, domain, res)
        plotting.plot_all_figs(data, domain, outcome)
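
# A hypothetical invocation, assuming model/strategy names registered in
# utils.models.MODELS and utils.cl_strategies.STRATEGIES ("CNN", "Naive" and
# "Joint" appear above; the data/domain/outcome values are illustrative only):
#
#     if __name__ == "__main__":
#         main(
#             data="mimic3",
#             domain="age",
#             outcome="mortality",
#             models=["CNN"],
#             strategies=["Naive", "Joint"],
#             validate=False,
#         )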