Switch to unified view

a b/src/evaluation/plots/barplots.py
1
# Base Dependencies
2
# -----------------
3
from pathlib import Path
4
from os.path import join as pjoin
5
6
# Local Dependencies
7
# ------------------
8
from evaluation.io import collect_step_times
9
10
# 3rd-Party Dependencies
11
# ----------------------
12
import matplotlib.pyplot as plt
13
import pandas as pd
14
import seaborn as sns
15
16
# Constants
17
# ---------
18
COLOR_PALETTE = "Set2"
19
20
21
# Auxiliar Functions
22
# ------------------
23
def rename_strategies(x):
24
    if x == "BatchLC" or x == "BatchBALD":
25
        return "BatchLC / BatchBALD"
26
    else:
27
        return x
28
29
30
# Main Functions
31
# --------------
32
def step_time_barplot():
33
    """Plots a strip plot showing the step time results for the DDI Extraction corpus"""
34
    title = "Average Step Time per Method and Query Strategy"
35
    output_file = Path(pjoin("results", "step_time_barplot.png"))
36
37
    # collect results
38
    n2c2_results = collect_step_times(Path(pjoin("results", "n2c2", "all")))
39
    ddi_results = collect_step_times(Path(pjoin("results", "ddi")))
40
    n2c2_results["Corpus"] = ["n2c2"] * len(n2c2_results)
41
    ddi_results["Corpus"] = ["DDI"] * len(ddi_results)
42
    results = pd.concat([n2c2_results, ddi_results], ignore_index=True)
43
44
    # convert step time to minutes
45
    results["iter_time (average)"] = results["iter_time (average)"].apply(
46
        lambda x: x / 60
47
    )
48
    results["strategy"] = results["strategy"].apply(rename_strategies)
49
50
    # plot
51
    sns.set()
52
    sns.set_style("whitegrid")  # grid style
53
    colors = sns.color_palette(COLOR_PALETTE)
54
    fig, axes = plt.subplots(2)
55
    g = sns.FacetGrid(results, row="Corpus", aspect=2, legend_out=True, sharex=False)
56
    g.map(sns.barplot, "iter_time (average)", "method", "strategy", palette=colors)
57
    g.add_legend(title="Query Strategy")
58
    g.set_ylabels("Method")
59
    g.set_xlabels("Step Time (minutes)")
60
    g.set_yticklabels(["RF", "BiLSTM", "CBERT", "CBERT-pairs"])
61
62
    sns.move_legend(
63
        g,
64
        "lower center",
65
        bbox_to_anchor=(0.45, 1),
66
        ncol=3,
67
        title=None,
68
        frameon=True,
69
    )
70
71
    # ax.bar_label(ax.containers[0], fmt='%.f%%')
72
73
    if title:
74
        # plt.title(title)
75
        pass
76
    if output_file:
77
        plt.savefig(output_file, bbox_inches="tight")
78
    else:
79
        plt.show()
80
81
    plt.clf()