Data: Tabular Time Series Specialty: Endocrinology Laboratory: Blood Tests EHR: Demographics Diagnoses Medications Omics: Genomics Multi-omics Transcriptomics Wearable: Activity Clinical Purpose: Treatment Response Assessment Task: Biomarker Discovery

Switch to unified view

a b/src/move/visualization/loss_curves.py
1
__all__ = ["LOSS_LABELS", "plot_loss_curves"]
2
3
from collections.abc import Sequence
4
5
import matplotlib.figure
6
import matplotlib.pyplot as plt
7
import numpy as np
8
9
from move.visualization.style import (
10
    DEFAULT_PLOT_STYLE,
11
    DEFAULT_QUALITATIVE_PALETTE,
12
    color_cycle,
13
    style_settings,
14
)
15
16
LOSS_LABELS = ("Loss", "Cross-Entropy", "Sum of Squared Errors", "KLD")
17
18
19
def plot_loss_curves(
20
    losses: Sequence[list[float]],
21
    labels: Sequence[str] = LOSS_LABELS,
22
    style: str = DEFAULT_PLOT_STYLE,
23
    colormap: str = DEFAULT_QUALITATIVE_PALETTE,
24
) -> matplotlib.figure.Figure:
25
    """Plot one or more loss curves.
26
27
    Args:
28
        losses: List containing lists of loss values
29
        labels: List containing names of each loss line
30
        style: Name of style to apply to the plot
31
        colormap: Name of colormap to use for the curves
32
33
    Returns:
34
        Figure
35
    """
36
    num_epochs = len(losses[0])
37
    epochs = np.arange(num_epochs)
38
    with style_settings(style), color_cycle(colormap):
39
        fig, ax = plt.subplots()
40
        for loss, label in zip(losses, labels):
41
            ax.plot(epochs, loss, label=label, linestyle="-")
42
        ax.legend()
43
        ax.set(xlabel="Epochs", ylabel="Loss")
44
    return fig