|
a |
|
b/src/evaluation/plots/barplots.py |
|
|
1 |
# Base Dependencies |
|
|
2 |
# ----------------- |
|
|
3 |
from pathlib import Path |
|
|
4 |
from os.path import join as pjoin |
|
|
5 |
|
|
|
6 |
# Local Dependencies |
|
|
7 |
# ------------------ |
|
|
8 |
from evaluation.io import collect_step_times |
|
|
9 |
|
|
|
10 |
# 3rd-Party Dependencies |
|
|
11 |
# ---------------------- |
|
|
12 |
import matplotlib.pyplot as plt |
|
|
13 |
import pandas as pd |
|
|
14 |
import seaborn as sns |
|
|
15 |
|
|
|
16 |
# Constants |
|
|
17 |
# --------- |
|
|
18 |
COLOR_PALETTE = "Set2" |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
# Auxiliar Functions |
|
|
22 |
# ------------------ |
|
|
23 |
def rename_strategies(x): |
|
|
24 |
if x == "BatchLC" or x == "BatchBALD": |
|
|
25 |
return "BatchLC / BatchBALD" |
|
|
26 |
else: |
|
|
27 |
return x |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
# Main Functions |
|
|
31 |
# -------------- |
|
|
32 |
def step_time_barplot(): |
|
|
33 |
"""Plots a strip plot showing the step time results for the DDI Extraction corpus""" |
|
|
34 |
title = "Average Step Time per Method and Query Strategy" |
|
|
35 |
output_file = Path(pjoin("results", "step_time_barplot.png")) |
|
|
36 |
|
|
|
37 |
# collect results |
|
|
38 |
n2c2_results = collect_step_times(Path(pjoin("results", "n2c2", "all"))) |
|
|
39 |
ddi_results = collect_step_times(Path(pjoin("results", "ddi"))) |
|
|
40 |
n2c2_results["Corpus"] = ["n2c2"] * len(n2c2_results) |
|
|
41 |
ddi_results["Corpus"] = ["DDI"] * len(ddi_results) |
|
|
42 |
results = pd.concat([n2c2_results, ddi_results], ignore_index=True) |
|
|
43 |
|
|
|
44 |
# convert step time to minutes |
|
|
45 |
results["iter_time (average)"] = results["iter_time (average)"].apply( |
|
|
46 |
lambda x: x / 60 |
|
|
47 |
) |
|
|
48 |
results["strategy"] = results["strategy"].apply(rename_strategies) |
|
|
49 |
|
|
|
50 |
# plot |
|
|
51 |
sns.set() |
|
|
52 |
sns.set_style("whitegrid") # grid style |
|
|
53 |
colors = sns.color_palette(COLOR_PALETTE) |
|
|
54 |
fig, axes = plt.subplots(2) |
|
|
55 |
g = sns.FacetGrid(results, row="Corpus", aspect=2, legend_out=True, sharex=False) |
|
|
56 |
g.map(sns.barplot, "iter_time (average)", "method", "strategy", palette=colors) |
|
|
57 |
g.add_legend(title="Query Strategy") |
|
|
58 |
g.set_ylabels("Method") |
|
|
59 |
g.set_xlabels("Step Time (minutes)") |
|
|
60 |
g.set_yticklabels(["RF", "BiLSTM", "CBERT", "CBERT-pairs"]) |
|
|
61 |
|
|
|
62 |
sns.move_legend( |
|
|
63 |
g, |
|
|
64 |
"lower center", |
|
|
65 |
bbox_to_anchor=(0.45, 1), |
|
|
66 |
ncol=3, |
|
|
67 |
title=None, |
|
|
68 |
frameon=True, |
|
|
69 |
) |
|
|
70 |
|
|
|
71 |
# ax.bar_label(ax.containers[0], fmt='%.f%%') |
|
|
72 |
|
|
|
73 |
if title: |
|
|
74 |
# plt.title(title) |
|
|
75 |
pass |
|
|
76 |
if output_file: |
|
|
77 |
plt.savefig(output_file, bbox_inches="tight") |
|
|
78 |
else: |
|
|
79 |
plt.show() |
|
|
80 |
|
|
|
81 |
plt.clf() |