a b/src/move/visualization/latent_space.py
1
__all__ = [
    "plot_latent_space_with_cat",
    "plot_latent_space_with_con",
    "plot_3D_latent_and_displacement",
]
2
3
from typing import Any
4
5
import matplotlib.cm as cm
6
import matplotlib.figure
7
import matplotlib.pyplot as plt
8
import matplotlib.style
9
import numpy as np
10
from matplotlib.colors import Normalize, TwoSlopeNorm
11
12
from move.core.typing import BoolArray, FloatArray
13
from move.visualization.style import (
14
    DEFAULT_DIVERGING_PALETTE,
15
    DEFAULT_PLOT_STYLE,
16
    DEFAULT_QUALITATIVE_PALETTE,
17
    color_cycle,
18
    style_settings,
19
)
20
21
22
def plot_latent_space_with_cat(
    latent_space: FloatArray,
    feature_name: str,
    feature_values: FloatArray,
    feature_mapping: dict[str, Any],
    is_nan: BoolArray,
    style: str = DEFAULT_PLOT_STYLE,
    colormap: str = DEFAULT_QUALITATIVE_PALETTE,
) -> matplotlib.figure.Figure:
    """Plot a 2D latent space together with a legend mapping the latent
    space to the values of a discrete feature.

    Only the first two latent dimensions (columns 0 and 1) are plotted. Each
    category observed among the non-missing samples gets its own scatter
    series; samples flagged in ``is_nan`` are drawn as an extra "NaN" series
    (only if there is at least one such sample).

    Args:
        latent_space:
            Embedding, an array with samples along axis 0 and at least two
            latent dimensions along axis 1.
        feature_name:
            Name of categorical feature, used as the legend title
        feature_values:
            Values (codes) of categorical feature, one per sample
        feature_mapping:
            Mapping of codes (as strings, i.e. ``str(code)``) to categories
            for the categorical feature
        is_nan:
            Array of bool values indicating which feature values are NaNs
        style:
            Name of style to apply to the plot
        colormap:
            Name of qualitative colormap to use for each category

    Raises:
        ValueError: If latent space does not have at least two dimensions.

    Returns:
        Figure
    """
    if latent_space.ndim < 2 or latent_space.shape[1] < 2:
        raise ValueError("Expected at least two dimensions in latent space.")
    feature_values = np.asarray(feature_values)
    # Force a boolean mask so that ``~is_nan`` is a logical (not bitwise) NOT.
    is_nan = np.asarray(is_nan, dtype=bool)
    with style_settings(style), color_cycle(colormap):
        fig, ax = plt.subplots()
        # Only consider codes of non-missing samples: codes stored at NaN
        # positions may be placeholders absent from ``feature_mapping`` and
        # would otherwise produce empty legend entries (or a KeyError).
        codes = np.unique(feature_values[~is_nan])
        for code in codes:
            category = feature_mapping[str(code)]
            is_category = (feature_values == code) & ~is_nan
            dims = np.take(latent_space.compress(is_category, axis=0), [0, 1], axis=1).T
            ax.scatter(*dims, label=category)
        if is_nan.any():
            dims = np.take(latent_space.compress(is_nan, axis=0), [0, 1], axis=1).T
            ax.scatter(*dims, label="NaN")
        ax.set(xlabel="dim 0", ylabel="dim 1")
        legend = ax.legend()
        legend.set_title(feature_name)
    return fig
72
73
74
def plot_latent_space_with_con(
    latent_space: FloatArray,
    feature_name: str,
    feature_values: FloatArray,
    style: str = DEFAULT_PLOT_STYLE,
    colormap: str = DEFAULT_DIVERGING_PALETTE,
) -> matplotlib.figure.Figure:
    """Plot a 2D latent space together with a colorbar mapping the latent
    space to the values of a continuous feature.

    Only the first two latent dimensions (columns 0 and 1) are plotted. When
    the feature values span zero, the colormap is centred at 0 with a
    ``TwoSlopeNorm``; otherwise a plain linear ``Normalize`` over the data
    range is used. NaN feature values are ignored when computing the range.

    Args:
        latent_space: Embedding, an array with samples along axis 0 and at
            least two latent dimensions along axis 1.
        feature_name: Name of continuous feature, used as the colorbar label
        feature_values: Values of continuous feature, one per sample
        style: Name of style to apply to the plot
        colormap: Name of colormap to use for the colorbar

    Raises:
        ValueError: If latent space does not have at least two dimensions.

    Returns:
        Figure
    """
    if latent_space.ndim < 2 or latent_space.shape[1] < 2:
        raise ValueError("Expected at least two dimensions in latent space.")
    vmin = float(np.nanmin(feature_values))
    vmax = float(np.nanmax(feature_values))
    # TwoSlopeNorm requires vmin < vcenter < vmax and raises otherwise (e.g.
    # for an all-positive feature), so only centre at 0 when 0 is inside the
    # data range.
    norm: Normalize
    if vmin < 0.0 < vmax:
        norm = TwoSlopeNorm(0.0, vmin, vmax)
    else:
        norm = Normalize(vmin, vmax)
    with style_settings(style):
        fig, ax = plt.subplots()
        dims = latent_space[:, 0], latent_space[:, 1]
        pts = ax.scatter(*dims, c=feature_values, cmap=colormap, norm=norm)
        cbar = fig.colorbar(pts, ax=ax)
        cbar.ax.set(ylabel=feature_name)
        ax.set(xlabel="dim 0", ylabel="dim 1")
    return fig
108
109
110
def _displacement_colors(u, v, w, alpha: float = 0.7) -> np.ndarray:
    """Return one RGBA row per arrow, coloring movement in dims 1/2/3 as R/G/B."""
    # Arrow colors are weighted contributions of
    # red -> dim1, green -> dim2 and blue -> dim3,
    # i.e. a purple arrow means movement in dims 1 and 3.
    channels = []
    for component in (u, v, w):
        magnitude = np.abs(np.asarray(component, dtype=float))
        peak = magnitude.max() if magnitude.size else 0.0
        # Normalise each channel by its largest displacement. If a dimension
        # did not move at all, 0/0 would give NaN, which matplotlib rejects as
        # a color, so that channel is left at zero instead.
        channels.append(magnitude / peak if peak > 0 else np.zeros_like(magnitude))
    channels.append(np.full(channels[0].shape, alpha))
    return np.column_stack(channels)


def plot_3D_latent_and_displacement(
    mu_baseline: FloatArray,
    mu_perturbed: FloatArray,
    feature_values: FloatArray,
    feature_name: str,
    show_baseline: bool = True,
    show_perturbed: bool = True,
    show_arrows: bool = True,
    step: int = 1,
    altitude: int = 30,
    azimuth: int = 45,
) -> matplotlib.figure.Figure:
    """
    Plot the movement of the samples in the 3D latent space after perturbing one
    input variable.

    Args:
        mu_baseline:
            ND array with dimensions n_samples x n_latent_nodes containing
            the latent representation of each sample
        mu_perturbed:
            ND array with dimensions n_samples x n_latent_nodes containing
            the latent representation of each sample after perturbing the input
        feature_values:
            1D array with feature values to map to a colormap ("seismic", with
            a fixed color range of [-2, 2]). Each sample is colored according
            to its value for the feature of interest.
        feature_name:
            name of the feature mapped to a colormap; used as the plot title
            when ``show_baseline`` is set
        show_baseline:
            plot original location of the samples in the latent space
        show_perturbed:
            plot final location (after perturbation) of the samples in latent space
        show_arrows:
            plot arrows from original to final location of each sample
        step:
            only every ``step``-th sample is plotted (subsampling stride)
        altitude:
            elevation angle (degrees) of the view, passed to ``view_init``
        azimuth:
            azimuthal angle (degrees) of the view, passed to ``view_init``

    Raises:
        ValueError: If latent space is not 3-dimensional (3 hidden nodes).

    Returns:
        Figure
    """
    if any(np.ndim(mu) < 2 or np.shape(mu)[1] != 3 for mu in (mu_baseline, mu_perturbed)):
        raise ValueError(
            " The latent space must be 3-dimensional. Redefine num_latent to 3."
        )

    # Shared by the baseline scatter and its colorbar so they stay in sync.
    cmap_name = "seismic"
    vmin, vmax = -2, 2

    fig = plt.figure(layout="constrained", figsize=(10, 10))
    ax = fig.add_subplot(projection="3d")
    ax.view_init(altitude, azimuth)

    baseline = mu_baseline[::step]
    perturbed = mu_perturbed[::step]

    if show_baseline:
        ax.scatter(
            baseline[:, 0],
            baseline[:, 1],
            baseline[:, 2],
            marker=".",
            c=feature_values[::step],
            s=10,
            lw=0,
            cmap=cmap_name,
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_title(feature_name)
        fig.colorbar(
            cm.ScalarMappable(cmap=cmap_name, norm=Normalize(vmin, vmax)), ax=ax
        )
    if show_perturbed:
        ax.scatter(
            perturbed[:, 0],
            perturbed[:, 1],
            perturbed[:, 2],
            marker=".",
            color="lightblue",
            label="perturbed",
            lw=0.5,
        )
    if show_arrows:
        u = perturbed[:, 0] - baseline[:, 0]
        v = perturbed[:, 1] - baseline[:, 1]
        w = perturbed[:, 2] - baseline[:, 2]
        ax.quiver(
            baseline[:, 0],
            baseline[:, 1],
            baseline[:, 2],
            u,
            v,
            w,
            color=_displacement_colors(u, v, w),
            lw=0.8,
        )
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")
    ax.set_zlabel("Dim 3")

    return fig