"""
Functions for plotting results and descriptive analysis of data.
"""

import time
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from pathlib import Path
from datetime import datetime
from collections import defaultdict

ROOT_DIR = Path(__file__).parents[2]
RESULTS_DIR = ROOT_DIR / "results"

# Short metric key (as logged by the training loop) -> human-readable name.
METRIC_FULL_NAME = {
    "Top1_Acc": "Accuracy",
    "BalAcc": "Balanced Accuracy",
    "Loss": "Loss",
}

# Continual-learning strategy -> methodological family (used for table grouping).
STRATEGY_CATEGORY = {
    "Naive": "Baseline",
    "Cumulative": "Baseline",
    "EWC": "Regularization",
    "OnlineEWC": "Regularization",
    "SI": "Regularization",
    "LwF": "Regularization",
    "Replay": "Rehearsal",
    "GEM": "Rehearsal",
    "AGEM": "Rehearsal",
    "GDumb": "Rehearsal",
}

# Fixed colour per strategy so hues stay consistent across all figures.
STRATEGY_COLOURS = {
    "Naive": "dodgerblue",
    "Cumulative": "deepskyblue",
    "EWC": "orange",
    "OnlineEWC": "gold",
    "SI": "tomato",
    "LwF": "peru",
    "Replay": "forestgreen",
    "GEM": "limegreen",
    "AGEM": "yellowgreen",
    "GDumb": "palegreen",
}


def get_timestamp():
    """
    Returns current timestamp as string ('YYYY-MM-DD-HH-MM-SS', local time).
    """
    ts = time.time()
    return datetime.fromtimestamp(ts).strftime("%Y-%m-%d-%H-%M-%S")


###################################
# Plot figs (metrics over epoch)
###################################


def stack_results(results, metric, mode, type="experience"):
    """
    Stacks results for multiple experiments along same axis in df.

    Either stacks:
    - multiple experiences' metric for same model/strategy, or
    - multiple strategies' [avg/stream] metrics for same model

    Args:
        results: list (one entry per repeat) of metric dicts mapping logged
            key -> (steps, values); only ``values`` (index 1) is used.
        metric: short metric key, e.g. "BalAcc" (see METRIC_FULL_NAME).
        mode: 'train' or 'test' -- which evaluation stream to extract.
        type: unused; retained for backward compatibility with callers.

    Returns:
        Long-format DataFrame with columns [Epoch, Task, <metric full name>],
        with one copy of the data per repeat concatenated row-wise.
    """
    results_dfs = []

    # Build one long-format frame per repeat, then concatenate them so that
    # seaborn can aggregate across repeats (CIs) downstream.
    for repeat in results:
        metric_dict = defaultdict(list)
        for key, values in repeat.items():
            if f"{metric}_Exp/eval_phase/{mode}_stream" in key:
                # e.g. ".../Exp003" -> "Task 3" (the double replace handles
                # both two- and one-zero padding in the experience id).
                task_name = (
                    key.split("/")[-1].replace("Exp00", "Task ").replace("Exp0", "Task ")
                )
                metric_dict[task_name] = values[1]

        df = pd.DataFrame.from_dict(metric_dict)
        df.index.rename("Epoch", inplace=True)
        stacked = df.stack().reset_index()
        stacked.rename(
            columns={"level_1": "Task", 0: METRIC_FULL_NAME[metric]}, inplace=True
        )
        results_dfs.append(stacked)

    return pd.concat(results_dfs, sort=False)


def stack_avg_results(results_strats, metric, mode):
    """
    Stack avg results for multiple strategies across epoch.

    Args:
        results_strats: dict mapping strategy name -> list (per repeat) of
            metric dicts (key -> (steps, values)).
        metric: short metric key, e.g. "BalAcc" (see METRIC_FULL_NAME).
        mode: 'train' or 'test' -- which evaluation stream to extract.

    Returns:
        Long-format DataFrame with columns [Epoch, Strategy, <metric full name>].
    """
    results_dfs = []

    # All strategies are assumed to have the same number of repeats.
    n_repeats = len(list(results_strats.values())[0])
    for i in range(n_repeats):
        metric_dict = defaultdict(list)

        # Get avg (stream) metrics for each strategy. Prefer the
        # "*_On_Trained_Experiences" key (average over tasks seen so far);
        # fall back to the plain stream-level metric.
        for strat, metrics in results_strats.items():
            for key, values in metrics[i].items():
                if (
                    f"{METRIC_FULL_NAME[metric].replace(' ', '')}_On_Trained_Experiences/eval_phase/{mode}_stream"
                    in key
                ):
                    # JA: early stopping means uneven length arrays.
                    # Must subsample at n_tasks
                    metric_dict[strat] = values[1]
                    break
                elif f"{metric}_Stream/eval_phase/{mode}_stream" in key:
                    metric_dict[strat] = values[1]

        df = pd.DataFrame.from_dict(metric_dict)
        df.index.rename("Epoch", inplace=True)
        stacked = df.stack().reset_index()
        stacked.rename(
            columns={"level_1": "Strategy", 0: METRIC_FULL_NAME[metric]}, inplace=True
        )
        results_dfs.append(stacked)

    return pd.concat(results_dfs, sort=False)


def plot_metric(method, model, results, mode, metric, ax=None):
    """
    Plots given metric from dict.
    Stacks multiple plots (i.e. different per-task metrics) over training time.

    Args:
        method: strategy name, used as subplot title.
        model: model name, used as the y-axis label.
        results: list of per-repeat metric dicts (see stack_results).
        mode: ['train','test'] (which stream to plot)
        metric: short metric key, e.g. "BalAcc".
        ax: matplotlib Axes to draw on (defaults to current axes).
    """
    ax = ax or plt.gca()

    stacked = stack_results(results, metric, mode)

    # Only plot task accuracies after examples have been encountered:
    # task k's curve starts once training has reached task k.
    # JA: this len() etc will screw up when plotting CI's
    tasks = stacked["Task"].str.split(" ", expand=True)[1].astype(int)
    n_epochs_per_task = (stacked["Epoch"].max() + 1) // stacked["Task"].nunique()
    stacked = stacked[tasks * n_epochs_per_task <= stacked["Epoch"].astype(int)]

    sns.lineplot(data=stacked, x="Epoch", y=METRIC_FULL_NAME[metric], hue="Task", ax=ax)
    ax.set_title(method, size=10)
    ax.set_ylabel(model)
    ax.set_xlabel("")


def plot_avg_metric(model, results, mode, metric, ax=None):
    """
    Plots given metric from dict.
    Stacks multiple plots (i.e. different strategies' metrics) over training time.

    Args:
        model: model name, used as the y-axis label.
        results: dict strategy -> list of per-repeat metric dicts.
        mode: ['train','test'] (which stream to plot)
        metric: short metric key, e.g. "BalAcc".
        ax: matplotlib Axes to draw on (defaults to current axes).
    """
    ax = ax or plt.gca()

    stacked = stack_avg_results(results, metric, mode)

    sns.lineplot(
        data=stacked,
        x="Epoch",
        y=METRIC_FULL_NAME[metric],
        hue="Strategy",
        ax=ax,
        palette=STRATEGY_COLOURS,
    )
    ax.set_title("Average performance over all tasks", size=10)
    ax.set_ylabel(model)
    ax.set_xlabel("")


def barplot_avg_metric(model, results, mode, metric, ax=None):
    """
    Bar plot of each strategy's FINAL (last-epoch) average metric.

    Args mirror plot_avg_metric.
    """
    ax = ax or plt.gca()

    stacked = stack_avg_results(results, metric, mode)
    # Keep only the last recorded epoch -> final performance per strategy.
    stacked = stacked[stacked["Epoch"] == stacked["Epoch"].max()]

    sns.barplot(
        data=stacked,
        x="Strategy",
        y=METRIC_FULL_NAME[metric],
        ax=ax,
        palette=STRATEGY_COLOURS,
    )
    ax.set_title("Final average performance over all tasks", size=10)
    ax.set_xlabel("")


###################################
# Clean up plots
###################################


def clean_subplot(i, j, axes, metric):
    """Removes top and right spines, titles, legend. Fixes y limits."""
    ax = axes[i, j]

    ax.spines[["top", "right"]].set_visible(False)

    # Only the first row keeps titles; only (0, 0) keeps its legend
    # (clean_plot promotes it to a single figure-level legend).
    if i > 0:
        ax.set_title("")
    if i > 0 or j > 0:
        try:
            ax.get_legend().remove()
        except AttributeError:
            pass  # subplot has no legend (e.g. bar plots)

    # Shared y-limits per metric; other metrics keep matplotlib's defaults.
    if metric == "Loss":
        plt.setp(axes, ylim=(0, 4))
    elif metric == "BalAcc":
        plt.setp(axes, ylim=(0.5, 1))


def clean_plot(fig, axes, metric):
    """Cleans all subplots. Removes duplicate legends."""
    for i in range(len(axes)):
        for j in range(len(axes[0])):
            clean_subplot(i, j, axes, metric)

    # Replace the (single remaining) subplot legend with one figure legend.
    legend = axes[0, 0].get_legend()
    if legend is not None:
        handles, labels = axes[0, 0].get_legend_handles_labels()
        legend.remove()
        fig.legend(handles, labels, loc="center right", title="Task")


def annotate_plot(fig, domain, outcome, metric):
    """Adds x/y labels and suptitles."""
    fig.supxlabel("Epoch")
    fig.supylabel(METRIC_FULL_NAME[metric], x=0)

    fig.suptitle(
        f"Continual Learning model comparison \n"
        f"Outcome: {outcome} | Domain Increment: {domain}",
        y=1.1,
    )


###################################
# Decorating functions for plotting everything
###################################


def plot_all_model_strats(data, domain, outcome, mode, metric, timestamp, savefig=True):
    """
    Pairplot of all models vs strategies.

    Loads results_{data}_{outcome}_{domain}.json from RESULTS_DIR and draws
    (1) a models x strategies grid of per-task curves and (2) per-model
    average/final-performance plots, optionally saving both as PNGs under
    RESULTS_DIR/figs/.../{timestamp}/{mode}/.
    """
    # Load results
    with open(
        RESULTS_DIR / f"results_{data}_{outcome}_{domain}.json", encoding="utf-8"
    ) as handle:
        res = json.load(handle)

    models = res.keys()
    strategies = next(iter(res.values())).keys()

    n_rows = len(models)
    n_cols = len(strategies)

    # Experience plots: one subplot per (model, strategy) pair.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        sharex=True,
        sharey=True,
        figsize=(2 * 20 * 4 / n_cols, 20 * n_rows / n_cols),
        squeeze=False,
        dpi=250,
    )

    for i, model in enumerate(models):
        for j, strategy in enumerate(strategies):
            plot_metric(strategy, model, res[model][strategy], mode, metric, axes[i, j])

    clean_plot(fig, axes, metric)
    annotate_plot(fig, domain, outcome, metric)

    if savefig:
        file_loc = RESULTS_DIR / "figs" / data / outcome / domain / timestamp / mode
        file_loc.mkdir(parents=True, exist_ok=True)
        plt.savefig(file_loc / f"Exp_{metric}.png")

    # Stream plots: per model, average-over-tasks curve + final bar plot.
    fig, axes = plt.subplots(
        n_rows,
        2,
        sharex=False,
        sharey=True,
        figsize=(20, 20 * n_rows / n_cols),
        squeeze=False,
        dpi=250,
    )

    for i, model in enumerate(models):
        plot_avg_metric(model, res[model], mode, metric, axes[i, 0])
        barplot_avg_metric(model, res[model], mode, metric, axes[i, 1])

    clean_plot(fig, axes, metric)
    annotate_plot(fig, domain, outcome, metric)

    if savefig:
        file_loc = RESULTS_DIR / "figs" / data / outcome / domain / timestamp / mode
        file_loc.mkdir(parents=True, exist_ok=True)
        plt.savefig(file_loc / f"Stream_{metric}.png")


def results_to_latex():
    """Returns results in LaTeX format for paper tables."""
    raise NotImplementedError


def plot_all_figs(data, domain, outcome):
    """Plots all results figs for paper."""
    timestamp = get_timestamp()

    for mode in ["train", "test"]:
        for metric in ["Loss", "Top1_Acc", "BalAcc"]:
            plot_all_model_strats(data, domain, outcome, mode, metric, timestamp)


#####################
# DESCRIPTIVE PLOTS
#####################


def plot_demographics():
    """
    Plots demographic information of eICU dataset.

    NOTE(review): the data loader is stubbed out below -- with an empty
    DataFrame the column accesses will raise KeyError. Restore the
    data_processing.load_eicu call before using this.
    """
    df = pd.DataFrame()  # data_processing.load_eicu(drop_dupes=True)
    _, axes = plt.subplots(3, 2, sharey=True, figsize=(18, 18), squeeze=False)

    df["gender"].value_counts().plot.bar(ax=axes[0, 0], rot=0, title="Gender")
    df["ethnicity"].value_counts().plot.bar(ax=axes[1, 0], rot=0, title="Ethnicity")
    df["ethnicity_coarse"].value_counts().plot.bar(
        ax=axes[1, 1], rot=0, title="Ethnicity (coarse)"
    )
    df["age"].plot.hist(bins=20, label="age", ax=axes[0, 1], title="Age")
    df["region"].value_counts().plot.bar(
        ax=axes[2, 0], rot=0, title="Region (North America)"
    )
    df["hospitaldischargestatus"].value_counts().plot.bar(
        ax=axes[2, 1], rot=0, title="Outcome"
    )
    plt.show()
    plt.close()


########################
# LATEX TABLES
########################


def ci_bound(std, count, ci=0.95):
    """
    Return Confidence Interval radius.

    Uses (1 + ci) as an approximation of the normal z-score
    (1.95 vs. the exact 1.96 for ci=0.95).
    """
    return (1 + ci) * std / np.sqrt(count)


def results_to_table(data, domain, outcome, mode, metric, verbose=False, n="max"):
    """
    Pairplot of all models vs strategies.

    Builds a (Category, Strategy) x Model table of mean metric values with
    95% CI radii, formatted either verbosely ("mean (lo, hi)") or compactly
    as LaTeX ("mean$_{\\pm ci}$", in percent).

    Args:
        data, domain, outcome: identify the results JSON to load.
        mode: 'train' or 'test'.
        metric: short metric key, e.g. "BalAcc".
        verbose: if True, show explicit CI bounds instead of LaTeX +/-.
        n: "max" for final-epoch performance, or an int epoch index.
    """
    # Load results
    with open(
        RESULTS_DIR / f"results_{data}_{outcome}_{domain}.json", encoding="utf-8"
    ) as handle:
        res = json.load(handle)

    models = [k for k in res.keys() if k in ["MLP", "CNN", "LSTM", "Transformer"]]
    dfs = []

    for model in models:
        df = stack_avg_results(res[model], metric, mode)
        df["Model"] = model
        dfs.append(df)

    df = pd.concat(dfs)

    # Get final performance val (or performance at epoch n)
    if n == "max":
        df = df[df["Epoch"] == df["Epoch"].max()]
        domain_col = domain
    else:
        df = df[df["Epoch"] == n]
        domain_col = f"{domain} ({n})"

    stats = df.groupby(["Model", "Strategy"])[METRIC_FULL_NAME[metric]].agg(
        ["mean", "count", "std"]
    )

    stats["ci95"] = ci_bound(stats["std"], stats["count"])

    if verbose:
        # Lower bound is mean MINUS the radius, upper bound mean PLUS it.
        stats["ci95_lo"] = stats["mean"] - stats["ci95"]
        stats["ci95_hi"] = stats["mean"] + stats["ci95"]
        stats[domain_col] = stats.apply(
            lambda x: f"{x['mean']:.3f} ({x.ci95_lo:.3f}, {x.ci95_hi:.3f})", axis=1
        )
    else:
        # Raw f-string: \pm must reach LaTeX verbatim.
        stats[domain_col] = stats.apply(
            lambda x: rf"{100 * x['mean']:.1f}$_{{\pm{100 * x.ci95:.1f}}}$", axis=1
        )

    stats = pd.DataFrame(stats[domain_col])
    stats.reset_index(inplace=True)

    stats["Category"] = stats["Strategy"].apply(lambda x: STRATEGY_CATEGORY[x])
    # Keyword args: positional pivot() was removed in pandas 2.0.
    stats = stats.pivot(index=["Category", "Strategy"], columns="Model")

    return stats


def generate_table_results(
    data="mimic3", outcome="mortality_48h", mode="test", metric="BalAcc", latex=False
):
    """
    Latex table of main results

    Concatenates one results_to_table column-block per domain increment;
    domains whose results files are missing/malformed are skipped.
    """
    domains = ["age", "ethnicity_coarse", "ward", "time_season"]
    dfs = []

    for domain in domains:
        try:
            dfs.append(results_to_table(data, domain, outcome, mode, metric))
        except (OSError, KeyError, ValueError):
            pass  # results for this domain not available yet

    if not dfs:
        raise FileNotFoundError(
            f"No results found for data={data!r}, outcome={outcome!r}"
        )

    df = pd.concat(dfs, axis=1)

    if latex:
        idx = pd.IndexSlice
        # Bold the best CL strategy per column (exclude Baseline rows).
        sub_idx = idx["Regularization":"Rehearsal", :]
        df = df.style.highlight_max(
            axis=0,
            props="bfseries: ;",
            subset=sub_idx,
        ).to_latex()
        return df
    else:
        return df


def generate_hp_table_super(outcome="mortality_48h"):
    """
    Combines all tables into a nice latex format.

    Wraps the four per-domain hyperparameter tables in a single LaTeX
    table/adjustbox environment, with a \\multicolumn header naming each
    domain increment.
    """
    prefix = r"""
\begin{table}[h]
\centering

"""

    box_prefix = r"""
\begin{adjustbox}{max width=\columnwidth}

"""
    old = r"""\begin{tabular}{lllllll}"""
    repl = r"""\begin{tabular}{lllllll}
\multicolumn{7}{c}{\textsc{Age}} \\

"""
    box_suffix = r"""
\end{adjustbox}

"""
    suffix = rf"""
\caption{{Tuned hyperparameters for main experiments (outcome of {outcome}).}}
\label{{tab:hyperparameters}}
\end{{table}}

"""

    latex = (
        prefix
        + box_prefix
        + generate_hp_table(outcome=outcome, domain="age").to_latex().replace(old, repl)
        + generate_hp_table(outcome=outcome, domain="ethnicity_coarse")
        .to_latex()
        .replace(old, repl.replace("Age", "Ethnicity (broad)"))
        + box_suffix
        + box_prefix
        + generate_hp_table(outcome=outcome, domain="time_season")
        .to_latex()
        .replace(old, repl.replace("Age", "Time (season)"))
        + generate_hp_table(outcome=outcome, domain="ward")
        .to_latex()
        .replace(old, repl.replace("Age", "ICU Ward"))
        + box_suffix
        + suffix
    )

    return latex


def generate_table_hospitals(
    outcome="ARF_4h",
    mode="test",
    metric="BalAcc",
    hospitals=(6, 12, 18, 24, 30, 36),
    latex=False,
):
    """
    Latex table of main results

    One results_to_table column-block per hospital count n (performance at
    epoch n), concatenated side by side. `hospitals` defaults to an immutable
    tuple (avoids the mutable-default pitfall); any iterable of ints works.
    """
    dfs = [
        results_to_table("eicu", "hospital", outcome, mode, metric, n=n)
        for n in hospitals
    ]

    df = pd.concat(dfs, axis=1)

    if latex:
        idx = pd.IndexSlice
        sub_idx = idx["Regularization":"Rehearsal", :]
        df = df.style.highlight_max(
            axis=0,
            props="bfseries: ;",
            subset=sub_idx,
        ).to_latex()
        return df
    else:
        return df


def generate_hp_table(data="mimic3", outcome="mortality_48h", domain="age"):
    """
    Return a (Model, Strategy)-indexed DataFrame of tuned hyperparameters,
    read from config/{data}/{outcome}/{domain}/config_{model}_{strategy}.json.
    Missing config files are skipped silently.
    """
    models = ["MLP", "CNN", "LSTM", "Transformer"]
    strategies = ["EWC", "OnlineEWC", "LwF", "SI", "Replay", "AGEM", "GEM"]
    dfs = []
    # Normalise strategy-specific hyperparameter names to shared columns.
    col_rename_map = {
        "ewc_lambda": "lambda",
        "alpha": "lambda",
        "si_lambda": "lambda",
        "memory_strength": "temperature",
        "mem_size": "sample_size",
    }

    for model in models:
        for strategy in strategies:
            try:
                with open(
                    ROOT_DIR
                    / "config"
                    / data
                    / outcome
                    / domain
                    / f"config_{model}_{strategy}.json",
                    encoding="utf-8",
                ) as handle:
                    res = json.load(handle)["strategy"]

                df = pd.DataFrame([res]).rename(columns=col_rename_map)
                df["Model"] = model
                df["Strategy"] = strategy

                dfs.append(df)
            except (OSError, KeyError, ValueError):
                pass  # config not tuned for this (model, strategy) pair
    df = pd.concat(dfs)
    df = df.set_index(["Model", "Strategy"])
    df = df.replace(np.nan, "")  # np.NaN alias was removed in numpy 2.0
    df = df.drop("mode", axis=1, errors="ignore")

    return df