|
a |
|
b/src/move/visualization/loss_curves.py |
|
|
1 |
__all__ = ["LOSS_LABELS", "plot_loss_curves"] |
|
|
2 |
|
|
|
3 |
from collections.abc import Sequence |
|
|
4 |
|
|
|
5 |
import matplotlib.figure |
|
|
6 |
import matplotlib.pyplot as plt |
|
|
7 |
import numpy as np |
|
|
8 |
|
|
|
9 |
from move.visualization.style import ( |
|
|
10 |
DEFAULT_PLOT_STYLE, |
|
|
11 |
DEFAULT_QUALITATIVE_PALETTE, |
|
|
12 |
color_cycle, |
|
|
13 |
style_settings, |
|
|
14 |
) |
|
|
15 |
|
|
|
16 |
LOSS_LABELS = ("Loss", "Cross-Entropy", "Sum of Squared Errors", "KLD") |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
def plot_loss_curves( |
|
|
20 |
losses: Sequence[list[float]], |
|
|
21 |
labels: Sequence[str] = LOSS_LABELS, |
|
|
22 |
style: str = DEFAULT_PLOT_STYLE, |
|
|
23 |
colormap: str = DEFAULT_QUALITATIVE_PALETTE, |
|
|
24 |
) -> matplotlib.figure.Figure: |
|
|
25 |
"""Plot one or more loss curves. |
|
|
26 |
|
|
|
27 |
Args: |
|
|
28 |
losses: List containing lists of loss values |
|
|
29 |
labels: List containing names of each loss line |
|
|
30 |
style: Name of style to apply to the plot |
|
|
31 |
colormap: Name of colormap to use for the curves |
|
|
32 |
|
|
|
33 |
Returns: |
|
|
34 |
Figure |
|
|
35 |
""" |
|
|
36 |
num_epochs = len(losses[0]) |
|
|
37 |
epochs = np.arange(num_epochs) |
|
|
38 |
with style_settings(style), color_cycle(colormap): |
|
|
39 |
fig, ax = plt.subplots() |
|
|
40 |
for loss, label in zip(losses, labels): |
|
|
41 |
ax.plot(epochs, loss, label=label, linestyle="-") |
|
|
42 |
ax.legend() |
|
|
43 |
ax.set(xlabel="Epochs", ylabel="Loss") |
|
|
44 |
return fig |