Data: Tabular Time Series Specialty: Endocrinology Laboratory: Blood Tests EHR: Demographics Diagnoses Medications Omics: Genomics Multi-omics Transcriptomics Wearable: Activity Clinical Purpose: Treatment Response Assessment Task: Biomarker Discovery
[c23b31]: / src / move / visualization / latent_space.py

Download this file

224 lines (201 with data), 7.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
__all__ = ["plot_latent_space_with_cat", "plot_latent_space_with_con"]
from typing import Any
import matplotlib.cm as cm
import matplotlib.figure
import matplotlib.pyplot as plt
import matplotlib.style
import numpy as np
from matplotlib.colors import Normalize, TwoSlopeNorm
from move.core.typing import BoolArray, FloatArray
from move.visualization.style import (
DEFAULT_DIVERGING_PALETTE,
DEFAULT_PLOT_STYLE,
DEFAULT_QUALITATIVE_PALETTE,
color_cycle,
style_settings,
)
def plot_latent_space_with_cat(
    latent_space: FloatArray,
    feature_name: str,
    feature_values: FloatArray,
    feature_mapping: dict[str, Any],
    is_nan: BoolArray,
    style: str = DEFAULT_PLOT_STYLE,
    colormap: str = DEFAULT_QUALITATIVE_PALETTE,
) -> matplotlib.figure.Figure:
    """Plot a 2D latent space together with a legend mapping the latent
    space to the values of a discrete feature.

    Only the first two latent dimensions (columns 0 and 1) are drawn. One
    scatter group is added per unique code in `feature_values`, followed by a
    final group labelled "NaN" holding the samples flagged in `is_nan`.

    Args:
        latent_space:
            Embedding, a ND array with at least two dimensions (samples along
            the first axis, latent dimensions along the second).
        feature_name:
            Name of categorical feature, used as the legend title
        feature_values:
            Values (codes) of categorical feature, one per sample
        feature_mapping:
            Mapping of codes to categories for the categorical feature. Keys
            are looked up as `str(code)`.
        is_nan:
            Array of bool values indicating which feature values are NaNs
        style:
            Name of style to apply to the plot
        colormap:
            Name of qualitative colormap to use for each category

    Raises:
        ValueError: If latent space does not have at least two dimensions.

    Returns:
        Figure
    """
    # Checking the rank alone lets an (n, 1) embedding through, which would
    # then fail with an IndexError when the second column is selected.
    if latent_space.ndim < 2 or latent_space.shape[1] < 2:
        raise ValueError("Expected at least two dimensions in latent space.")

    with style_settings(style), color_cycle(colormap):
        fig, ax = plt.subplots()

        for code in np.unique(feature_values):
            category = feature_mapping[str(code)]
            # Samples flagged as NaN are excluded here and drawn separately.
            is_category = (feature_values == code) & ~is_nan
            ax.scatter(
                latent_space[is_category, 0],
                latent_space[is_category, 1],
                label=category,
            )

        # Drawn last so the missing-value group takes the colour after all
        # categories in the cycle.
        ax.scatter(latent_space[is_nan, 0], latent_space[is_nan, 1], label="NaN")

        ax.set(xlabel="dim 0", ylabel="dim 1")
        legend = ax.legend()
        legend.set_title(feature_name)
    return fig
def plot_latent_space_with_con(
    latent_space: FloatArray,
    feature_name: str,
    feature_values: FloatArray,
    style: str = DEFAULT_PLOT_STYLE,
    colormap: str = DEFAULT_DIVERGING_PALETTE,
) -> matplotlib.figure.Figure:
    """Plot a 2D latent space together with a colorbar mapping the latent
    space to the values of a continuous feature.

    Only the first two latent dimensions (columns 0 and 1) are drawn. When the
    feature values straddle zero, colors are normalized with a two-slope norm
    centered at 0.0; otherwise a plain linear norm over the observed range is
    used, because a two-slope norm requires `vmin < 0 < vmax`.

    Args:
        latent_space: Embedding, a ND array with at least two dimensions.
        feature_name: Name of continuous feature, used as the colorbar label
        feature_values: Values of continuous feature, one per sample
        style: Name of style to apply to the plot
        colormap: Name of colormap to use for the colorbar

    Raises:
        ValueError: If latent space does not have at least two dimensions.

    Returns:
        Figure
    """
    # Checking the rank alone lets an (n, 1) embedding through, which would
    # then fail with an IndexError when the second column is selected.
    if latent_space.ndim < 2 or latent_space.shape[1] < 2:
        raise ValueError("Expected at least two dimensions in latent space.")

    # NaN-aware reductions: the builtin min/max iterate element by element in
    # Python and give order-dependent results when NaNs are present.
    vmin = np.nanmin(feature_values)
    vmax = np.nanmax(feature_values)

    norm: Normalize
    if vmin < 0.0 < vmax:
        norm = TwoSlopeNorm(0.0, vmin, vmax)
    else:
        # TwoSlopeNorm raises unless vmin < vcenter < vmax, so one-sided or
        # constant data falls back to a linear norm instead of crashing.
        norm = Normalize(vmin, vmax)

    with style_settings(style):
        fig, ax = plt.subplots()
        pts = ax.scatter(
            latent_space[:, 0],
            latent_space[:, 1],
            c=feature_values,
            cmap=colormap,
            norm=norm,
        )
        cbar = fig.colorbar(pts, ax=ax)
        cbar.ax.set(ylabel=feature_name)
        ax.set(xlabel="dim 0", ylabel="dim 1")
    return fig
def plot_3D_latent_and_displacement(
    mu_baseline: FloatArray,
    mu_perturbed: FloatArray,
    feature_values: FloatArray,
    feature_name: str,
    show_baseline: bool = True,
    show_perturbed: bool = True,
    show_arrows: bool = True,
    step: int = 1,
    altitude: int = 30,
    azimuth: int = 45,
) -> matplotlib.figure.Figure:
    """Plot the movement of the samples in the 3D latent space after
    perturbing one input variable.

    Args:
        mu_baseline:
            ND array with dimensions n_samples x n_latent_nodes containing
            the latent representation of each sample
        mu_perturbed:
            ND array with dimensions n_samples x n_latent_nodes containing
            the latent representation of each sample after perturbing the input
        feature_values:
            1D array with feature values to map to a colormap ("seismic",
            with a fixed color range of [-2, 2]). Each baseline sample is
            colored according to its value for the feature of interest.
        feature_name:
            Name of the feature mapped to a colormap; used as the plot title
            when the baseline is shown
        show_baseline:
            Plot original location of the samples in the latent space,
            together with the title and the colorbar
        show_perturbed:
            Plot final location (after perturbation) of the samples in latent
            space
        show_arrows:
            Plot arrows from original to final location of each sample
        step:
            Stride used to subsample the samples; every `step`-th sample is
            drawn
        altitude:
            Elevation of the viewpoint, passed as the first argument to
            `view_init`
        azimuth:
            Azimuth of the viewpoint, passed as the second argument to
            `view_init`

    Raises:
        ValueError: If latent space is not 3-dimensional (3 hidden nodes).

    Returns:
        Figure
    """
    if [np.shape(mu_baseline)[1], np.shape(mu_perturbed)[1]] != [3, 3]:
        raise ValueError(
            " The latent space must be 3-dimensional. Redefine num_latent to 3."
        )

    fig = plt.figure(layout="constrained", figsize=(10, 10))
    ax = fig.add_subplot(projection="3d")
    ax.view_init(altitude, azimuth)

    # Subsample once; all three layers draw the same subset of samples.
    origin = mu_baseline[::step]
    target = mu_perturbed[::step]

    if show_baseline:
        ax.scatter(
            origin[:, 0],
            origin[:, 1],
            origin[:, 2],
            marker=".",
            c=feature_values[::step],
            s=10,
            lw=0,
            cmap="seismic",
            vmin=-2,
            vmax=2,
        )
        ax.set_title(feature_name)
        # The colorbar mirrors the fixed color range used by the scatter above.
        fig.colorbar(
            cm.ScalarMappable(cmap="seismic", norm=Normalize(-2, 2)), ax=ax
        )

    if show_perturbed:
        ax.scatter(
            target[:, 0],
            target[:, 1],
            target[:, 2],
            marker=".",
            color="lightblue",
            label="perturbed",
            lw=0.5,
        )

    if show_arrows:
        displacement = target - origin
        magnitude = np.abs(displacement)
        # Per-axis maximum used to scale each color channel into [0, 1]. An
        # axis with no movement at all would give 0 / 0 = NaN, so its divisor
        # is replaced by 1 and that channel stays at 0.
        axis_max = magnitude.max(axis=0)
        axis_max = np.where(axis_max == 0, 1.0, axis_max)
        # Arrow colors are weighted contributions of
        # red -> dim 1, green -> dim 2 and blue -> dim 3, with alpha 0.7.
        # I.e. a purple arrow means movement in dims 1 and 3.
        colors = np.column_stack(
            (magnitude / axis_max, np.full(len(magnitude), 0.7))
        )
        ax.quiver(
            origin[:, 0],
            origin[:, 1],
            origin[:, 2],
            displacement[:, 0],
            displacement[:, 1],
            displacement[:, 2],
            color=colors,
            lw=0.8,
        )

    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")
    ax.set_zlabel("Dim 3")
    return fig