"""
Contains functions for running hyperparameter sweep and
Continual Learning model-training and evaluation.
"""

import json
import warnings
from pathlib import Path
from functools import partial

import torch
from ray import tune
from torch import nn, optim

from avalanche.logging import InteractiveLogger, TensorboardLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.plugins.early_stopping import EarlyStoppingPlugin
from avalanche.evaluation.metrics import (
    accuracy_metrics,
    loss_metrics,
    StreamConfusionMatrix,
)

# Local imports
from utils import models, plotting, data_processing, cl_strategies
from utils.metrics import (
    balancedaccuracy_metrics,
    sensitivity_metrics,
    specificity_metrics,
    precision_metrics,
    rocauc_metrics,
    auprc_metrics,
)

# Show the (erroneous) MaxPool1d named-tensors UserWarning only once per location
warnings.filterwarnings("once", category=UserWarning)

# GLOBALS
RESULTS_DIR = Path(__file__).parents[1] / "results"
CONFIG_DIR = Path(__file__).parents[1] / "config"
CUDA = torch.cuda.is_available()
DEVICE = "cuda" if CUDA else "cpu"

# Reproducibility: fix the torch RNG so repeated runs are comparable.
SEED = 12345
torch.manual_seed(SEED)


def save_params(data, domain, outcome, model, strategy, best_params):
    """Save a tuned hyper-parameter config to JSON.

    Writes to ``<CONFIG_DIR>/<data>/<outcome>/<domain>/config_<model>_<strategy>.json``,
    creating the directory tree if necessary.

    Args:
        data, domain, outcome: dataset/problem identifiers used in the path.
        model, strategy: model and CL-strategy names used in the filename.
        best_params: JSON-serialisable dict of hyper-parameters.
    """
    file_loc = CONFIG_DIR / data / outcome / domain
    file_loc.mkdir(parents=True, exist_ok=True)

    with open(
        file_loc / f"config_{model}_{strategy}.json", "w", encoding="utf-8"
    ) as json_file:
        # indent for human-readable config files; json.load() is unaffected
        json.dump(best_params, json_file, indent=2)


def load_params(data, domain, outcome, model, strategy):
    """Load a hyper-parameter config previously saved by `save_params`.

    Returns:
        dict: the saved hyper-parameter config.

    Raises:
        FileNotFoundError: if no config exists for this combination.
    """
    file_loc = CONFIG_DIR / data / outcome / domain

    with open(
        file_loc / f"config_{model}_{strategy}.json", encoding="utf-8"
    ) as json_file:
        return json.load(json_file)
def save_results(data, outcome, domain, res):
    """Save results to .json, dropping non-serialisable confusion-matrix tensors.

    Args:
        data, outcome, domain: identifiers used in the output filename.
        res: nested results, ``{model: {strategy: [run_metrics, ...]}}`` where
            each ``run_metrics`` maps metric names to values.
    """
    res_no_tensors = {
        model_name: {
            strat_name: [
                {
                    metric: value
                    for metric, value in run.items()
                    if "Confusion" not in metric  # tensors: not JSON-serialisable
                }
                for run in runs
            ]
            for strat_name, runs in strats.items()
        }
        for model_name, strats in res.items()
    }
    with open(
        RESULTS_DIR / f"results_{data}_{outcome}_{domain}.json", "w", encoding="utf-8"
    ) as handle:
        json.dump(res_no_tensors, handle)


def load_strategy(
    model,
    model_name,
    strategy_name,
    data="",
    domain="",
    n_tasks=0,
    weight=None,
    validate=False,
    config=None,
    benchmark=None,
    early_stopping=False,
):
    """Build an Avalanche CL strategy around `model` with an evaluation plugin.

    Metric granularity:
    - `stream` Avg accuracy over all experiences (may rely on tasks being roughly same size?)
    - `experience` Accuracy for each experience (disabled while validating)

    Args:
        model: torch model to train.
        model_name, strategy_name: names used for strategy lookup and log dirs.
        data, domain: identifiers used only for the TensorBoard log path.
        n_tasks: number of tasks (currently only referenced by the commented-out
            eval_every heuristic below).
        weight: optional class weights passed to CrossEntropyLoss.
        validate: hyper-param-tuning mode — no loggers, stream metrics only.
        config: dict with "generic" (optimizer, lr, train_mb_size) and
            "strategy" (strategy-specific kwargs) entries.
        benchmark: the Avalanche benchmark/scenario, given to the evaluator.
        early_stopping: if True, stop after 5 stagnant evaluations of
            train-stream balanced accuracy.

    Raises:
        KeyError: unknown `strategy_name`.
        ValueError: unknown optimizer name in `config`.
    """
    strategy = cl_strategies.STRATEGIES[strategy_name]
    criterion = nn.CrossEntropyLoss(weight=weight)

    optimizer_name = config["generic"]["optimizer"]
    if optimizer_name == "SGD":
        optimizer = optim.SGD(
            model.parameters(), lr=config["generic"]["lr"], momentum=0.9
        )
    elif optimizer_name == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=config["generic"]["lr"])
    else:
        # Previously fell through silently and raised a confusing NameError below.
        raise ValueError(f"Unknown optimizer '{optimizer_name}'")

    if validate:
        # No logging during sweeps: Ray Tune captures the reported metrics.
        loggers = []
    else:
        timestamp = plotting.get_timestamp()
        log_dir = (
            RESULTS_DIR
            / "log"
            / "tensorboard"
            / f"{data}_{domain}_{timestamp}"
            / model_name
            / strategy_name
        )
        loggers = [InteractiveLogger(), TensorboardLogger(tb_log_dir=log_dir)]

    eval_plugin = EvaluationPlugin(
        StreamConfusionMatrix(save_image=False),
        loss_metrics(stream=True, experience=not validate),
        accuracy_metrics(trained_experience=True, stream=True, experience=not validate),
        balancedaccuracy_metrics(
            trained_experience=True, stream=True, experience=not validate
        ),
        specificity_metrics(
            trained_experience=True, stream=True, experience=not validate
        ),
        sensitivity_metrics(
            trained_experience=True, stream=True, experience=not validate
        ),
        precision_metrics(
            trained_experience=True, stream=True, experience=not validate
        ),
        # rocauc_metrics(trained_experience=True, stream=True, experience=not validate),
        # auprc_metrics(trained_experience=True, stream=True, experience=not validate),
        loggers=loggers,
        benchmark=benchmark,
    )

    if early_stopping:
        # Local name kept distinct so the boolean parameter isn't shadowed.
        stopper = EarlyStoppingPlugin(
            patience=5,
            val_stream_name="train_stream/Task000",
            metric_name="BalancedAccuracy_On_Trained_Experiences",
        )
        plugins = [stopper]
    else:
        plugins = None

    # NOTE(review): the original set `eval_every = None` for "Joint" but still
    # passed a hard-coded 0 below, leaving that variable unused. The effective
    # behaviour (eval_every=0) is preserved — confirm the intended frequency.
    cl_strategy = strategy(
        model,
        optimizer=optimizer,
        device=DEVICE,
        criterion=criterion,
        eval_mb_size=1024,
        eval_every=0,  # if validate or n_tasks > 5 else 1,
        evaluator=eval_plugin,
        train_epochs=15,
        train_mb_size=config["generic"]["train_mb_size"],
        plugins=plugins,
        **config["strategy"],
    )

    return cl_strategy


def train_cl_method(cl_strategy, scenario, strategy_name, validate=False):
    """
    Avalanche CL training loop.
    For each 'experience' in scenario's train_stream:

    - Trains method on experience
    - evaluates model on train_stream and test_stream

    "Joint" trains on the whole train_stream at once (upper-bound baseline);
    every other strategy trains experience-by-experience.

    Returns:
        Last metrics when `validate` (for Ray Tune), otherwise all metrics.
    """
    if not validate:
        print("Starting experiment...")

    if strategy_name == "Joint":
        if not validate:
            print("Joint training:")
        cl_strategy.train(
            scenario.train_stream,
            eval_streams=[scenario.train_stream, scenario.test_stream],
        )
        if not validate:
            print("Training completed", "\n\n")
    else:
        for experience in scenario.train_stream:
            if not validate:
                print(
                    f"{strategy_name} - Start of experience: {experience.current_experience}"
                )
            cl_strategy.train(
                experience, eval_streams=[scenario.train_stream, scenario.test_stream]
            )
            if not validate:
                print("Training completed", "\n\n")

    if validate:
        return cl_strategy.evaluator.get_last_metrics()
    return cl_strategy.evaluator.get_all_metrics()


def training_loop(
    config,
    data,
    domain,
    outcome,
    model_name,
    strategy_name,
    validate=False,
    checkpoint_dir=None,
):
    """
    Training wrapper:
    - loads data
    - instantiates model
    - equips model with CL strategy
    - trains and evaluates method
    - returns results, or reports to Ray Tune when `validate`

    Args:
        config: {"generic": ..., "model": ..., "strategy": ...} hyper-params.
        checkpoint_dir: unused; presumably accepted for Ray Tune's trainable
            signature — TODO confirm.
    """
    # Loading data into 'stream' of 'experiences' (tasks)
    if not validate:
        print("Loading data...")
    scenario, n_tasks, n_timesteps, n_channels, weight = data_processing.load_data(
        data, domain, outcome, validate
    )
    if weight is not None:
        weight = weight.to(DEVICE)
    if not validate:
        print("Data loaded.\n")
        print(f"N timesteps: {n_timesteps}\nN features: {n_channels}")

    model = models.MODELS[model_name](n_channels, n_timesteps, **config["model"])
    cl_strategy = load_strategy(
        model,
        model_name,
        strategy_name,
        data,
        domain,
        n_tasks=n_tasks,
        weight=weight,
        validate=validate,
        config=config,
        benchmark=scenario,
    )
    results = train_cl_method(cl_strategy, scenario, strategy_name, validate=validate)

    if validate:
        loss = results["Loss_Stream/eval_phase/test_stream/Task000"]
        accuracy = results[
            "Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000"
        ]
        balancedaccuracy = results[
            "BalancedAccuracy_On_Trained_Experiences/eval_phase/test_stream/Task000"
        ]
        # sensitivity = results['Sens_Stream/eval_phase/test_stream/Task000']
        # specificity = results['Spec_Stream/eval_phase/test_stream/Task000']
        # precision = results['Prec_Stream/eval_phase/test_stream/Task000']
        # rocauc = results['ROCAUC_Stream/eval_phase/test_stream/Task000']
        # auprc = results['AUPRC_Stream/eval_phase/test_stream/Task000']

        # WARNING: `return` overwrites raytune report
        tune.report(
            loss=loss,
            accuracy=accuracy,
            balancedaccuracy=balancedaccuracy,
            # auprc=auprc,
            # rocauc=rocauc
        )
    else:
        return results


def hyperparam_opt(
    config, data, domain, outcome, model_name, strategy_name, num_samples
):
    """
    Hyperparameter optimisation for the given model/strategy.
    Runs over the validation data for the first 2 tasks.

    Args:
        config: Ray Tune search space ({"generic": ..., "model": ..., "strategy": ...}).
        num_samples: number of Ray Tune trials.

    Returns:
        dict: config of the best trial by final balanced accuracy.
    """
    reporter = tune.CLIReporter(
        metric_columns=[
            "loss",
            "accuracy",
            "balancedaccuracy",
            #'auprc',
            #'rocauc'
        ]
    )
    # Two concurrent trials per GPU when CUDA is available.
    resources = {"cpu": 4, "gpu": 0.5} if CUDA else {"cpu": 1}

    result = tune.run(
        partial(
            training_loop,
            data=data,
            domain=domain,
            outcome=outcome,
            model_name=model_name,
            strategy_name=strategy_name,
            validate=True,
        ),
        config=config,
        num_samples=num_samples,
        progress_reporter=reporter,
        raise_on_failed_trial=False,
        resources_per_trial=resources,
        name=f"{model_name}_{strategy_name}",
        local_dir=RESULTS_DIR / "log" / "raytune" / f"{data}_{outcome}_{domain}",
        trial_name_creator=lambda t: f"{model_name}_{strategy_name}_{t.trial_id}",
    )

    best_trial = result.get_best_trial("balancedaccuracy", "max", "last")
    print(f"Best trial config: {best_trial.config}")
    print(
        f"Best trial final validation loss: {best_trial.last_result['loss']}"
    )
    print(
        f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}"
    )
    print(
        f"Best trial final validation balanced accuracy: {best_trial.last_result['balancedaccuracy']}"
    )

    return best_trial.config


def main(
    data,
    domain,
    outcome,
    models,
    strategies,
    dropout=False,
    config_generic=None,
    config_model=None,
    config_cl=None,
    validate=False,
    num_samples=50,
    freeze_model_hp=False,
):
    """
    Main training loop. Defines dataset given outcome/domain
    and evaluates model/strategies over given hyperparams over this problem.

    Args:
        models, strategies: iterables of model / CL-strategy names.
            (NOTE: `models` shadows the `utils.models` module inside this
            function; kept for caller compatibility.)
        dropout: if False, dropout is zeroed for all non-CNN models.
        config_generic, config_model, config_cl: search spaces used when
            `validate`; default to empty dicts (the former `={}` defaults were
            mutable-default footguns).
        validate: run hyper-parameter optimisation instead of full training.
        freeze_model_hp: reuse "Naive"-tuned generic/model params for other
            strategies instead of re-searching them.
    """
    config_generic = {} if config_generic is None else config_generic
    config_model = {} if config_model is None else config_model
    config_cl = {} if config_cl is None else config_cl

    # Container for metrics results
    res = {m: {s: [] for s in strategies} for m in models}

    for model in models:
        for strategy in strategies:
            # Garbage collection
            torch.cuda.empty_cache()

            if validate:  # Hyperparam opt over first 2 tasks
                # Load generic tuned hyper-params
                if strategy == "Naive" or not freeze_model_hp:
                    config = {
                        "generic": config_generic,
                        "model": config_model[model],
                        "strategy": config_cl.get(strategy, {}),
                    }
                else:
                    naive_params = load_params(data, domain, outcome, model, "Naive")
                    config = {
                        "generic": naive_params["generic"],
                        "model": naive_params["model"],
                        "strategy": config_cl.get(strategy, {}),
                    }

                # JA: Investigate adding dropout to CNN (final FC layers only?)
                if not dropout and model != "CNN":
                    # Copy before overriding so the caller's config_model
                    # entry is not mutated in place.
                    config["model"] = {**config["model"], "dropout": 0}

                best_params = hyperparam_opt(
                    config,
                    data,
                    domain,
                    outcome,
                    model,
                    strategy,
                    num_samples=1 if strategy == "Naive" else num_samples,
                )
                save_params(data, domain, outcome, model, strategy, best_params)

            else:  # Training loop over all tasks
                config = load_params(data, domain, outcome, model, strategy)

                # Multiple runs for Confidence Intervals (currently a single run)
                n_repeats = 1
                for _ in range(n_repeats):
                    curr_results = training_loop(
                        config, data, domain, outcome, model, strategy
                    )
                    res[model][strategy].append(curr_results)

    if not validate:
        save_results(data, outcome, domain, res)
        plotting.plot_all_figs(data, domain, outcome)