--- a
+++ b/src/utils/plotting.py
@@ -0,0 +1,583 @@
+"""
+Functions for plotting results and descriptive analysis of data.
+"""
+
+import time
+import json
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from pathlib import Path
+from datetime import datetime
+
+ROOT_DIR = Path(__file__).parents[2]
+RESULTS_DIR = ROOT_DIR / "results"
+
+METRIC_FULL_NAME = {
+    "Top1_Acc": "Accuracy",
+    "BalAcc": "Balanced Accuracy",
+    "Loss": "Loss",
+}
+
+STRATEGY_CATEGORY = {
+    "Naive": "Baseline",
+    "Cumulative": "Baseline",
+    "EWC": "Regularization",
+    "OnlineEWC": "Regularization",
+    "SI": "Regularization",
+    "LwF": "Regularization",
+    "Replay": "Rehearsal",
+    "GEM": "Rehearsal",
+    "AGEM": "Rehearsal",
+    "GDumb": "Rehearsal",
+}
+
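+# Colour families mirror STRATEGY_CATEGORY: blues for baselines, warm tones
+# for regularization methods, greens for rehearsal methods.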
+STRATEGY_COLOURS = {
+    "Naive": "dodgerblue",
+    "Cumulative": "deepskyblue",
+    "EWC": "orange",
+    "OnlineEWC": "gold",
+    "SI": "tomato",
+    "LwF": "peru",
+    "Replay": "forestgreen",
+    "GEM": "limegreen",
+    "AGEM": "yellowgreen",
+    "GDumb": "palegreen",
+}
+
+
+def get_timestamp():
+    """
+    Returns the current timestamp as a string, e.g. "2024-01-31-12-00-00".
+    """
+    ts = time.time()
+    return datetime.fromtimestamp(ts).strftime("%Y-%m-%d-%H-%M-%S")
+
+
+###################################
+# Plot figs (metrics over epoch)
+###################################
+
+
+def stack_results(results, metric, mode):
+    """
+    Stacks results for multiple experiments along the same axis in a df.
+
+    Either stacks:
+    - multiple experiences' metrics for the same model/strategy, or
+    - multiple strategies' [avg/stream] metrics for the same model
+    """
+
+    results_dfs = []
+
+    # Get metrics for each training "experience"'s test set
+    n_repeats = len(results)
+
+    for i in range(n_repeats):
+        metric_dict = {}
+        for k, v in results[i].items():
+            if f"{metric}_Exp/eval_phase/{mode}_stream" in k:
+                new_k = (
+                    k.split("/")[-1].replace("Exp00", "Task ").replace("Exp0", "Task ")
+                )
+                metric_dict[new_k] = v[1]
+
+        df = pd.DataFrame.from_dict(metric_dict)
+        df.index.rename("Epoch", inplace=True)
+        stacked = df.stack().reset_index()
+        stacked.rename(
+            columns={"level_1": "Task", 0: METRIC_FULL_NAME[metric]}, inplace=True
+        )
+
+        results_dfs.append(stacked)
+
+    stacked = pd.concat(results_dfs, sort=False)
+
+    return stacked
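+
+
+# Hedged sketch of the assumed input format for `stack_results` (the real
+# logs come from the training loop's evaluator; key names and values below
+# are fabricated for illustration):
+#
+#     fake_run = {
+#         "Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000": ([0, 1], [0.50, 0.72]),
+#         "Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001": ([0, 1], [0.45, 0.68]),
+#     }
+#     stack_results([fake_run], "Top1_Acc", "test")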
+
+
+def stack_avg_results(results_strats, metric, mode):
+    """
+    Stack avg results for multiple strategies across epoch.
+    """
+    results_dfs = []
+
+    # Get metrics for each training "experience"'s test set
+    n_repeats = len(next(iter(results_strats.values())))
+    for i in range(n_repeats):
+        metric_dict = {}
+
+        # Get avg (stream) metrics for each strategy
+        for strat, metrics in results_strats.items():
+            for k, v in metrics[i].items():
+                # Prefer the "<Metric>_On_Trained_Experiences" key (average
+                # over the experiences seen so far) when it is logged...
+                if (
+                    f"{METRIC_FULL_NAME[metric].replace(' ', '')}_On_Trained_Experiences/eval_phase/{mode}_stream"
+                    in k
+                ):
+                    # NOTE (JA): early stopping means uneven-length arrays;
+                    # must subsample at n_tasks.
+                    metric_dict[strat] = v[1]
+                    break
+                # ...otherwise fall back to the plain "<metric>_Stream" key.
+                elif f"{metric}_Stream/eval_phase/{mode}_stream" in k:
+                    metric_dict[strat] = v[1]
+
+        df = pd.DataFrame.from_dict(metric_dict)
+        df.index.rename("Epoch", inplace=True)
+        stacked = df.stack().reset_index()
+        stacked.rename(
+            columns={"level_1": "Strategy", 0: METRIC_FULL_NAME[metric]}, inplace=True
+        )
+
+        results_dfs.append(stacked)
+
+    stacked = pd.concat(results_dfs, sort=False)
+
+    return stacked
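+
+
+# Hedged example of the per-strategy log format assumed above (fabricated
+# values); the "*_On_Trained_Experiences" key is preferred over "*_Stream"
+# when both are logged:
+#
+#     fake = {"Naive": [{"BalAcc_Stream/eval_phase/test_stream": ([0, 1], [0.60, 0.70])}]}
+#     stack_avg_results(fake, "BalAcc", "test")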
+
+
+def plot_metric(method, model, results, mode, metric, ax=None):
+    """
+    Plots given metric from dict.
+    Stacks multiple plots (i.e. different per-task metrics) over training time.
+
+    `mode`: ['train','test'] (which stream to plot)
+    """
+    ax = ax or plt.gca()
+
+    stacked = stack_results(results, metric, mode)
+
+    # Only plot a task's metric once training has reached that task
+    # NOTE (JA): this filtering can misbehave when plotting CIs across repeats
+    tasks = stacked["Task"].str.split(" ", expand=True)[1].astype(int)
+    n_epochs_per_task = (stacked["Epoch"].max() + 1) // stacked["Task"].nunique()
+    stacked = stacked[tasks * n_epochs_per_task <= stacked["Epoch"].astype(int)]
+
+    sns.lineplot(data=stacked, x="Epoch", y=METRIC_FULL_NAME[metric], hue="Task", ax=ax)
+    ax.set_title(method, size=10)
+    ax.set_ylabel(model)
+    ax.set_xlabel("")
+
+
+def plot_avg_metric(model, results, mode, metric, ax=None):
+    """
+    Plots given metric from dict.
+    Stacks multiple plots (i.e. different strategies' metrics) over training time.
+
+    `mode`: ['train','test'] (which stream to plot)
+    """
+    ax = ax or plt.gca()
+
+    stacked = stack_avg_results(results, metric, mode)
+
+    sns.lineplot(
+        data=stacked,
+        x="Epoch",
+        y=METRIC_FULL_NAME[metric],
+        hue="Strategy",
+        ax=ax,
+        palette=STRATEGY_COLOURS,
+    )
+    ax.set_title("Average performance over all tasks", size=10)
+    ax.set_ylabel(model)
+    ax.set_xlabel("")
+
+
+def barplot_avg_metric(model, results, mode, metric, ax=None):
+    ax = ax or plt.gca()
+
+    stacked = stack_avg_results(results, metric, mode)
+    stacked = stacked[stacked["Epoch"] == stacked["Epoch"].max()]
+
+    sns.barplot(
+        data=stacked,
+        x="Strategy",
+        y=METRIC_FULL_NAME[metric],
+        ax=ax,
+        palette=STRATEGY_COLOURS,
+    )
+    ax.set_title("Final average performance over all tasks", size=10)
+    ax.set_xlabel("")
+
+
+###################################
+# Clean up plots
+###################################
+
+
+def clean_subplot(i, j, axes, metric):
+    """Removes top and rights spines, titles, legend. Fixes y limits."""
+    ax = axes[i, j]
+
+    ax.spines[["top", "right"]].set_visible(False)
+
+    if i > 0:
+        ax.set_title("")
+    if i > 0 or j > 0:
+        try:
+            ax.get_legend().remove()
+        except AttributeError:
+            pass
+
+    if metric == "Loss":
+        ylim = (0, 4)
+    elif metric == "BalAcc":
+        ylim = (0.5, 1)
+        plt.setp(axes, ylim=ylim)
+    else:
+        ylim = (0.5, 1)
+
+    # plt.setp(axes, ylim=ylim)
+
+
+def clean_plot(fig, axes, metric):
+    """Cleans all subpots. Removes duplicate legends."""
+    for i in range(len(axes)):
+        for j in range(len(axes[0])):
+            clean_subplot(i, j, axes, metric)
+
+    handles, labels = axes[0, 0].get_legend_handles_labels()
+    axes[0, 0].get_legend().remove()
+    fig.legend(handles, labels, loc="center right", title="Task")
+
+
+def annotate_plot(fig, domain, outcome, metric):
+    """Adds x/y labels and suptitles."""
+    fig.supxlabel("Epoch")
+    fig.supylabel(METRIC_FULL_NAME[metric], x=0)
+
+    fig.suptitle(
+        f"Continual Learning model comparison \n"
+        f"Outcome: {outcome} | Domain Increment: {domain}",
+        y=1.1,
+    )
+
+
+###################################
+# Decorating functions for plotting everything
+###################################
+
+
+def plot_all_model_strats(data, domain, outcome, mode, metric, timestamp, savefig=True):
+    """Grid of metric plots for all models (rows) x strategies (columns)."""
+
+    # Load results
+    with open(
+        RESULTS_DIR / f"results_{data}_{outcome}_{domain}.json", encoding="utf-8"
+    ) as handle:
+        res = json.load(handle)
+
+    models = res.keys()
+    strategies = next(iter(res.values())).keys()
+
+    n_rows = len(models)
+    n_cols = len(strategies)
+
+    # Experience plots
+    fig, axes = plt.subplots(
+        n_rows,
+        n_cols,
+        sharex=True,
+        sharey=True,
+        figsize=(2 * 20 * 4 / n_cols, 20 * n_rows / n_cols),
+        squeeze=False,
+        dpi=250,
+    )
+
+    for i, model in enumerate(models):
+        for j, strategy in enumerate(strategies):
+            plot_metric(strategy, model, res[model][strategy], mode, metric, axes[i, j])
+
+    clean_plot(fig, axes, metric)
+    annotate_plot(fig, domain, outcome, metric)
+
+    if savefig:
+        file_loc = RESULTS_DIR / "figs" / data / outcome / domain / timestamp / mode
+        file_loc.mkdir(parents=True, exist_ok=True)
+        plt.savefig(file_loc / f"Exp_{metric}.png")
+        plt.close(fig)  # free memory between figure grids
+
+    # Stream plots
+    fig, axes = plt.subplots(
+        n_rows,
+        2,
+        sharex=False,
+        sharey=True,
+        figsize=(20, 20 * n_rows / n_cols),
+        squeeze=False,
+        dpi=250,
+    )
+
+    for i, model in enumerate(models):
+        plot_avg_metric(model, res[model], mode, metric, axes[i, 0])
+        barplot_avg_metric(model, res[model], mode, metric, axes[i, 1])
+
+    clean_plot(fig, axes, metric)
+    annotate_plot(fig, domain, outcome, metric)
+
+    if savefig:
+        file_loc = RESULTS_DIR / "figs" / data / outcome / domain / timestamp / mode
+        file_loc.mkdir(parents=True, exist_ok=True)
+        plt.savefig(file_loc / f"Stream_{metric}.png")
+        plt.close(fig)  # free memory between figure grids
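+
+
+# Hedged usage sketch (argument values illustrative): reads
+# results/results_mimic3_mortality_48h_age.json and writes PNGs under
+# results/figs/mimic3/mortality_48h/age/<timestamp>/test/:
+#
+#     plot_all_model_strats("mimic3", "age", "mortality_48h", "test", "BalAcc", get_timestamp())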
+
+
+def results_to_latex():
+    """Returns results in LaTeX format for paper tables."""
+    raise NotImplementedError
+
+
+def plot_all_figs(data, domain, outcome):
+    """Plots all results figs for paper."""
+    timestamp = get_timestamp()
+
+    for mode in ["train", "test"]:
+        for metric in ["Loss", "Top1_Acc", "BalAcc"]:
+            plot_all_model_strats(data, domain, outcome, mode, metric, timestamp)
+
+
+#####################
+# DESCRIPTIVE PLOTS
+#####################
+
+
+def plot_demographics():
+    """
+    Plots demographic information for the eICU dataset.
+
+    NOTE: the data loader is stubbed out; restore the
+    `data_processing.load_eicu(drop_dupes=True)` call (and its import)
+    before use, otherwise the column lookups below raise KeyError.
+    """
+
+    df = pd.DataFrame()  # data_processing.load_eicu(drop_dupes=True)
+    _, axes = plt.subplots(3, 2, sharey=True, figsize=(18, 18), squeeze=False)
+
+    df["gender"].value_counts().plot.bar(ax=axes[0, 0], rot=0, title="Gender")
+    df["ethnicity"].value_counts().plot.bar(ax=axes[1, 0], rot=0, title="Ethnicity")
+    df["ethnicity_coarse"].value_counts().plot.bar(
+        ax=axes[1, 1], rot=0, title="Ethnicity (coarse)"
+    )
+    df["age"].plot.hist(bins=20, label="age", ax=axes[0, 1], title="Age")
+    df["region"].value_counts().plot.bar(
+        ax=axes[2, 0], rot=0, title="Region (North America)"
+    )
+    df["hospitaldischargestatus"].value_counts().plot.bar(
+        ax=axes[2, 1], rot=0, title="Outcome"
+    )
+    plt.show()
+    plt.close()
+
+
+########################
+# LATEX TABLES
+########################
+
+
+def ci_bound(std, count, ci=0.95):
+    """Return the confidence-interval half-width (radius) for the mean.
+
+    Note: (1 + ci) approximates the normal critical value only for
+    ci=0.95 (1.95 vs. the exact 1.96); other levels need a proper
+    z-score (see the sketch below).
+    """
+    return (1 + ci) * std / np.sqrt(count)
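+
+
+# A hedged sketch of the exact alternative mentioned in `ci_bound`'s
+# docstring (assumes scipy is available); not used by the tables below:
+def ci_bound_exact(std, count, ci=0.95):
+    """Like `ci_bound`, but with the exact normal critical value."""
+    from scipy.stats import norm
+
+    return norm.ppf((1 + ci) / 2) * std / np.sqrt(count)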
+
+
+def results_to_table(data, domain, outcome, mode, metric, verbose=False, n="max"):
+    """Table of final per-model/strategy performance with 95% CIs."""
+
+    # Load results
+    with open(
+        RESULTS_DIR / f"results_{data}_{outcome}_{domain}.json", encoding="utf-8"
+    ) as handle:
+        res = json.load(handle)
+
+    models = [k for k in res.keys() if k in ["MLP", "CNN", "LSTM", "Transformer"]]
+    dfs = []
+
+    for model in models:
+        df = stack_avg_results(res[model], metric, mode)
+        df["Model"] = model
+        dfs.append(df)
+
+    df = pd.concat(dfs)
+
+    # Get final performance val
+    if n == "max":
+        df = df[df["Epoch"] == df["Epoch"].max()]
+        domain_col = domain
+    else:
+        df = df[df["Epoch"] == n]
+        domain_col = f"{domain} ({n})"
+
+    stats = df.groupby(["Model", "Strategy"])[METRIC_FULL_NAME[metric]].agg(
+        ["mean", "count", "std"]
+    )
+
+    stats["ci95"] = ci_bound(stats["std"], stats["count"])
+
+    if verbose:
+        stats["ci95_lo"] = stats["mean"] - stats["ci95"]
+        stats["ci95_hi"] = stats["mean"] + stats["ci95"]
+        stats[domain_col] = stats.apply(
+            lambda x: f"{x['mean']:.3f} ({x.ci95_lo:.3f}, {x.ci95_hi:.3f})", axis=1
+        )
+    else:
+        # Raw f-string so the "\p" in "\pm" is not treated as an escape
+        stats[domain_col] = stats.apply(
+            lambda x: rf"{100 * x['mean']:.1f}$_{{\pm{100 * x.ci95:.1f}}}$", axis=1
+        )
+
+    stats = pd.DataFrame(stats[domain_col])
+    stats.reset_index(inplace=True)
+
+    stats["Category"] = stats["Strategy"].apply(lambda x: STRATEGY_CATEGORY[x])
+    stats = stats.pivot(index=["Category", "Strategy"], columns="Model")
+
+    return stats
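+
+
+# Hedged usage sketch (requires the corresponding results JSON on disk;
+# values illustrative):
+#
+#     stats = results_to_table("mimic3", "age", "mortality_48h", "test", "BalAcc")
+#     print(stats.to_latex(escape=False))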
+
+
+def generate_table_results(
+    data="mimic3", outcome="mortality_48h", mode="test", metric="BalAcc", latex=False
+):
+    """
+    Latex table of main results
+    """
+    domains = ["age", "ethnicity_coarse", "ward", "time_season"]
+    dfs = []
+
+    for domain in domains:
+        try:
+            dfs.append(results_to_table(data, domain, outcome, mode, metric))
+        except FileNotFoundError:
+            # Skip domains whose results file has not been generated yet
+            continue
+
+    df = pd.concat(dfs, axis=1)
+
+    if latex:
+        idx = pd.IndexSlice
+        sub_idx = idx["Regularization":"Rehearsal", :]
+        df = df.style.highlight_max(
+            axis=0,
+            props="bfseries: ;",
+            subset=sub_idx,
+        ).to_latex()
+        return df
+    else:
+        return df
+
+
+def generate_hp_table_super(outcome="mortality_48h"):
+    """
+    Combines all hyperparameter tables into a single LaTeX table.
+    """
+
+    prefix = r"""
+\begin{table}[h]
+\centering
+
+"""
+
+    box_prefix = r"""
+\begin{adjustbox}{max width=\columnwidth}
+
+"""
+    old = r"""\begin{tabular}{lllllll}"""
+    repl = r"""\begin{tabular}{lllllll}
+\multicolumn{7}{c}{\textsc{Age}} \\
+
+"""
+    box_suffix = r"""
+\end{adjustbox}
+
+"""
+    suffix = rf"""
+\caption{{Tuned hyperparameters for main experiments (outcome of {outcome}).}}
+\label{{tab:hyperparameters}}
+\end{{table}}
+
+"""
+
+    latex = (
+        prefix
+        + box_prefix
+        + generate_hp_table(outcome=outcome, domain="age").to_latex().replace(old, repl)
+        + generate_hp_table(outcome=outcome, domain="ethnicity_coarse")
+        .to_latex()
+        .replace(old, repl.replace("Age", "Ethnicity (broad)"))
+        + box_suffix
+        + box_prefix
+        + generate_hp_table(outcome=outcome, domain="time_season")
+        .to_latex()
+        .replace(old, repl.replace("Age", "Time (season)"))
+        + generate_hp_table(outcome=outcome, domain="ward")
+        .to_latex()
+        .replace(old, repl.replace("Age", "ICU Ward"))
+        + box_suffix
+        + suffix
+    )
+
+    return latex
+
+
+def generate_table_hospitals(
+    outcome="ARF_4h",
+    mode="test",
+    metric="BalAcc",
+    hospitals=(6, 12, 18, 24, 30, 36),
+    latex=False,
+):
+    """
+    Results table over increasing numbers of eICU training hospitals;
+    rendered as a LaTeX string when `latex=True`.
+    """
+
+    dfs = [
+        results_to_table("eicu", "hospital", outcome, mode, metric, n=n)
+        for n in hospitals
+    ]
+
+    df = pd.concat(dfs, axis=1)
+
+    if latex:
+        idx = pd.IndexSlice
+        sub_idx = idx["Regularization":"Rehearsal", :]
+        df = df.style.highlight_max(
+            axis=0,
+            props="bfseries: ;",
+            subset=sub_idx,
+        ).to_latex()
+        return df
+    else:
+        return df
+
+
+def generate_hp_table(data="mimic3", outcome="mortality_48h", domain="age"):
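+    """Collects each strategy's tuned hyperparameters from the per-experiment
+    config JSONs into a single table (columns renamed for display)."""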
+    models = ["MLP", "CNN", "LSTM", "Transformer"]
+    strategies = ["EWC", "OnlineEWC", "LwF", "SI", "Replay", "AGEM", "GEM"]
+    dfs = []
+    col_rename_map = {
+        "ewc_lambda": "lambda",
+        "alpha": "lambda",
+        "si_lambda": "lambda",
+        "memory_strength": "temperature",
+        "mem_size": "sample_size",
+    }
+
+    for model in models:
+        for strategy in strategies:
+            try:
+                with open(
+                    ROOT_DIR
+                    / "config"
+                    / data
+                    / outcome
+                    / domain
+                    / f"config_{model}_{strategy}.json",
+                    encoding="utf-8",
+                ) as handle:
+                    res = json.load(handle)["strategy"]
+
+                df = pd.DataFrame([res]).rename(columns=col_rename_map)
+                df["Model"] = model
+                df["Strategy"] = strategy
+
+                dfs.append(df)
+            except FileNotFoundError:
+                # Skip model/strategy combinations without a tuned config
+                continue
+    df = pd.concat(dfs)
+    df = df.set_index(["Model", "Strategy"])
+    df = df.replace(np.nan, "")  # np.NaN was removed in NumPy 2.0
+    df = df.drop("mode", axis=1)
+
+    return df
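+
+
+if __name__ == "__main__":
+    # Hedged usage sketch: regenerate every figure for one experiment. The
+    # argument values are illustrative and assume the matching results JSON
+    # exists under RESULTS_DIR.
+    plot_all_figs(data="mimic3", domain="age", outcome="mortality_48h")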