|
a |
|
b/src/move/visualization/metrics.py |
|
|
1 |
__all__ = ["plot_metrics_boxplot"] |
|
|
2 |
|
|
|
3 |
from collections.abc import Sequence |
|
|
4 |
|
|
|
5 |
import matplotlib |
|
|
6 |
import matplotlib.figure |
|
|
7 |
import matplotlib.pyplot as plt |
|
|
8 |
|
|
|
9 |
from move.core.typing import FloatArray |
|
|
10 |
from move.visualization.style import ( |
|
|
11 |
DEFAULT_PLOT_STYLE, |
|
|
12 |
DEFAULT_QUALITATIVE_PALETTE, |
|
|
13 |
color_cycle, |
|
|
14 |
style_settings, |
|
|
15 |
) |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def plot_metrics_boxplot( |
|
|
19 |
scores: Sequence[FloatArray], |
|
|
20 |
labels: Sequence[str], |
|
|
21 |
style: str = DEFAULT_PLOT_STYLE, |
|
|
22 |
colormap: str = DEFAULT_QUALITATIVE_PALETTE, |
|
|
23 |
) -> matplotlib.figure.Figure: |
|
|
24 |
"""Plot a box plot, showing the distribution of metrics per dataset. Each |
|
|
25 |
score corresponds (for example) to a sample. |
|
|
26 |
|
|
|
27 |
Args: |
|
|
28 |
scores: List of dataset metrics |
|
|
29 |
labels: List of dataset names |
|
|
30 |
style: Name of style to apply to the plot |
|
|
31 |
colormap: Name of colormap to use for the boxes |
|
|
32 |
|
|
|
33 |
Returns: |
|
|
34 |
Figure |
|
|
35 |
""" |
|
|
36 |
with style_settings(style), color_cycle(colormap): |
|
|
37 |
labelcolor = matplotlib.rcParams["axes.labelcolor"] |
|
|
38 |
fig, ax = plt.subplots() |
|
|
39 |
comps = ax.boxplot( |
|
|
40 |
scores[::-1], |
|
|
41 |
labels=labels[::-1], |
|
|
42 |
patch_artist=True, |
|
|
43 |
vert=False, |
|
|
44 |
capprops=dict(color=labelcolor), |
|
|
45 |
flierprops=dict( |
|
|
46 |
marker="d", |
|
|
47 |
markersize=5, |
|
|
48 |
markerfacecolor=labelcolor, |
|
|
49 |
markeredgecolor=labelcolor, |
|
|
50 |
), |
|
|
51 |
medianprops=dict(color=labelcolor), |
|
|
52 |
whiskerprops=dict(color=labelcolor), |
|
|
53 |
) |
|
|
54 |
prop_cycle = matplotlib.rcParams["axes.prop_cycle"] |
|
|
55 |
for box, prop in zip(comps["boxes"], prop_cycle()): |
|
|
56 |
box.update(dict(facecolor=prop["color"], edgecolor=labelcolor)) |
|
|
57 |
ax.set(xlim=(-0.05, 1.05), xlabel="Score", ylabel="Dataset") |
|
|
58 |
return fig |