Data: Tabular Time Series Specialty: Endocrinology Laboratory: Blood Tests EHR: Demographics Diagnoses Medications Omics: Genomics Multi-omics Transcriptomics Wearable: Activity Clinical Purpose: Treatment Response Assessment Task: Biomarker Discovery

Switch to side-by-side view

--- a
+++ b/src/move/visualization/latent_space.py
@@ -0,0 +1,223 @@
+__all__ = ["plot_latent_space_with_cat", "plot_latent_space_with_con"]
+
+from typing import Any
+
+import matplotlib.cm as cm
+import matplotlib.figure
+import matplotlib.pyplot as plt
+import matplotlib.style
+import numpy as np
+from matplotlib.colors import Normalize, TwoSlopeNorm
+
+from move.core.typing import BoolArray, FloatArray
+from move.visualization.style import (
+    DEFAULT_DIVERGING_PALETTE,
+    DEFAULT_PLOT_STYLE,
+    DEFAULT_QUALITATIVE_PALETTE,
+    color_cycle,
+    style_settings,
+)
+
+
+def plot_latent_space_with_cat(
+    latent_space: FloatArray,
+    feature_name: str,
+    feature_values: FloatArray,
+    feature_mapping: dict[str, Any],
+    is_nan: BoolArray,
+    style: str = DEFAULT_PLOT_STYLE,
+    colormap: str = DEFAULT_QUALITATIVE_PALETTE,
+) -> matplotlib.figure.Figure:
+    """Plot a 2D latent space together with a legend mapping the latent
+    space to the values of a discrete feature.
+
+    Args:
+        latent_space:
+            Embedding, a ND array with at least two dimensions.
+        feature_name:
+            Name of categorical feature
+        feature_values:
+            Values of categorical feature
+        feature_mapping:
+            Mapping of codes to categories for the categorical feature
+        is_nan:
+            Array of bool values indicating which feature values are NaNs
+        style:
+            Name of style to apply to the plot
+        colormap:
+            Name of qualitative colormap to use for each category
+
+    Raises:
+        ValueError: If latent space does not have at least two dimensions.
+
+    Returns:
+        Figure
+    """
+    if latent_space.ndim < 2:
+        raise ValueError("Expected at least two dimensions in latent space.")
+    with style_settings(style), color_cycle(colormap):
+        fig, ax = plt.subplots()
+        codes = np.unique(feature_values)
+        for code in codes:
+            category = feature_mapping[str(code)]
+            is_category = (feature_values == code) & ~is_nan
+            dims = np.take(latent_space.compress(is_category, axis=0), [0, 1], axis=1).T
+            ax.scatter(*dims, label=category)
+        dims = np.take(latent_space.compress(is_nan, axis=0), [0, 1], axis=1).T
+        ax.scatter(*dims, label="NaN")
+        ax.set(xlabel="dim 0", ylabel="dim 1")
+        legend = ax.legend()
+        legend.set_title(feature_name)
+    return fig
+
+
+def plot_latent_space_with_con(
+    latent_space: FloatArray,
+    feature_name: str,
+    feature_values: FloatArray,
+    style: str = DEFAULT_PLOT_STYLE,
+    colormap: str = DEFAULT_DIVERGING_PALETTE,
+) -> matplotlib.figure.Figure:
+    """Plot a 2D latent space together with a colorbar mapping the latent
+    space to the values of a continuous feature.
+
+    Args:
+        latent_space: Embedding, a ND array with at least two dimensions.
+        feature_name: Name of continuous feature
+        feature_values: Values of continuous feature
+        style: Name of style to apply to the plot
+        colormap: Name of colormap to use for the colorbar
+
+    Raises:
+        ValueError: If latent space does not have at least two dimensions.
+
+    Returns:
+        Figure
+    """
+    if latent_space.ndim < 2:
+        raise ValueError("Expected at least two dimensions in latent space.")
+    norm = TwoSlopeNorm(0.0, min(feature_values), max(feature_values))
+    with style_settings(style):
+        fig, ax = plt.subplots()
+        dims = latent_space[:, 0], latent_space[:, 1]
+        pts = ax.scatter(*dims, c=feature_values, cmap=colormap, norm=norm)
+        cbar = fig.colorbar(pts, ax=ax)
+        cbar.ax.set(ylabel=feature_name)
+        ax.set(xlabel="dim 0", ylabel="dim 1")
+    return fig
+
+
+def plot_3D_latent_and_displacement(
+    mu_baseline: FloatArray,
+    mu_perturbed: FloatArray,
+    feature_values: FloatArray,
+    feature_name: str,
+    show_baseline: bool = True,
+    show_perturbed: bool = True,
+    show_arrows: bool = True,
+    step: int = 1,
+    altitude: int = 30,
+    azimuth: int = 45,
+) -> matplotlib.figure.Figure:
+    """
+    Plot the movement of the samples in the 3D latent space after perturbing one
+    input variable.
+
+    Args:
+        mu_baseline:
+            ND array with dimensions n_samples x n_latent_nodes containing
+            the latent representation of each sample
+        mu_perturbed:
+            ND array with dimensions n_samples x n_latent_nodes containing
+            the latent representation of each sample after perturbing the input
+        feature_values:
+            1D array with feature values to map to a colormap ("bwr"). Each sample is
+            colored according to its value for the feature of interest.
+        feature_name:
+            name of the feature mapped to a colormap
+        show_baseline:
+            plot orginal location of the samples in the latent space
+        show_perturbed:
+            plot final location (after perturbation) of the samples in latent space
+        show_arrows:
+            plot arrows from original to final location of each sample
+        angle:
+            elevation from dim1-dim2 plane for the visualization of latent space.
+
+    Raises:
+        ValueError: If latent space is not 3-dimensional (3 hidden nodes).
+    Returns:
+        Figure
+    """
+    if [np.shape(mu_baseline)[1], np.shape(mu_perturbed)[1]] != [3, 3]:
+        raise ValueError(
+            " The latent space must be 3-dimensional. Redefine num_latent to 3."
+        )
+
+    fig = plt.figure(layout="constrained", figsize=(10, 10))
+    ax = fig.add_subplot(projection="3d")
+    ax.view_init(altitude, azimuth)
+
+    if show_baseline:
+        # vmin, vmax = np.min(feature_values[::step]), np.max(feature_values[::step])
+        # abs_max = np.max([abs(vmin), abs(vmax)])
+        ax.scatter(
+            mu_baseline[::step, 0],
+            mu_baseline[::step, 1],
+            mu_baseline[::step, 2],
+            marker=".",
+            c=feature_values[::step],
+            s=10,
+            lw=0,
+            cmap="seismic",
+            vmin=-2,
+            vmax=2,
+        )
+        ax.set_title(feature_name)
+        fig.colorbar(
+            cm.ScalarMappable(cmap="seismic", norm=Normalize(-2, 2)), ax=ax
+        )  # Normalize(min(feature_values[::step]),max(feature_values[::step]))), ax=ax)
+    if show_perturbed:
+        ax.scatter(
+            mu_perturbed[::step, 0],
+            mu_perturbed[::step, 1],
+            mu_perturbed[::step, 2],
+            marker=".",
+            color="lightblue",
+            label="perturbed",
+            lw=0.5,
+        )
+    if show_arrows:
+        u = mu_perturbed[::step, 0] - mu_baseline[::step, 0]
+        v = mu_perturbed[::step, 1] - mu_baseline[::step, 1]
+        w = mu_perturbed[::step, 2] - mu_baseline[::step, 2]
+
+        # module = np.sqrt(u * u + v * v + w * w)
+
+        max_u, max_v, max_w = np.max(abs(u)), np.max(abs(v)), np.max(abs(w))
+        # Arrow colors will be weighted contributions of
+        # red -> dim1,
+        # green -> dim2,
+        # and blue-> dim3.
+        # I.e. purple arrow means movement in dims 1 and 3
+        colors = [
+            (abs(du) / max_u, abs(dv) / max_v, abs(dw) / max_w, 0.7)
+            for du, dv, dw in zip(u, v, w)
+        ]
+        ax.quiver(
+            mu_baseline[::step, 0],
+            mu_baseline[::step, 1],
+            mu_baseline[::step, 2],
+            u,
+            v,
+            w,
+            color=colors,
+            lw=0.8,
+        )  # alpha=(1-module/np.max(module))**6, arrow_length_ratio=0)
+        # help(ax.quiver)
+    ax.set_xlabel("Dim 1")
+    ax.set_ylabel("Dim 2")
+    ax.set_zlabel("Dim 3")
+    # ax.set_axis_off()
+
+    return fig