|
a |
|
b/src/move/visualization/dataset_distributions.py |
|
|
1 |
__all__ = ["plot_value_distributions"] |
|
|
2 |
|
|
|
3 |
import matplotlib |
|
|
4 |
import matplotlib.figure |
|
|
5 |
import matplotlib.pyplot as plt |
|
|
6 |
import networkx as nx |
|
|
7 |
import numpy as np |
|
|
8 |
|
|
|
9 |
from move.core.typing import FloatArray |
|
|
10 |
from move.visualization.style import ( |
|
|
11 |
DEFAULT_DIVERGING_PALETTE, |
|
|
12 |
DEFAULT_PLOT_STYLE, |
|
|
13 |
style_settings, |
|
|
14 |
) |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
def plot_value_distributions(
    feature_values: FloatArray,
    style: str = "fast",
    nbins: int = 100,
    colormap: str = DEFAULT_DIVERGING_PALETTE,
) -> matplotlib.figure.Figure:
    """
    Given a certain dataset, plot its distribution of values.

    One histogram is computed per feature (column). All histograms share the
    same bins, which span the global minimum and maximum of the dataset (NaNs
    are ignored when computing these limits). The histograms are stacked along
    the feature axis and drawn as a 3D surface.

    Args:
        feature_values:
            Values of the features, a 2D array (`num_samples` x `num_features`).
        style:
            Name of style to apply to the plot.
        nbins:
            Number of equal-width histogram bins shared by all features.
        colormap:
            Name of colormap used to color the surface.

    Returns:
        Figure
    """
    vmin, vmax = np.nanmin(feature_values), np.nanmax(feature_values)
    num_features = np.shape(feature_values)[1]
    with style_settings(style):
        fig = plt.figure(layout="constrained", figsize=(7, 7))
        ax = fig.add_subplot(projection="3d")

        histogram = []
        for i in range(num_features):
            feat_hist, bin_edges = np.histogram(
                feature_values[:, i], bins=nbins, range=(vmin, vmax)
            )
            histogram.append(feat_hist)

        # Every call above uses the same `bins` and `range`, so the edges of the
        # last call describe all histograms. Counts belong to bins (there are
        # `nbins` counts but `nbins + 1` edges), hence each count is drawn at
        # the center of its bin.
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        x_val, y_val = np.meshgrid(bin_centers, np.arange(num_features))

        ax.plot_surface(x_val, y_val, np.array(histogram), cmap=colormap)
        ax.set_xlabel("Feature value")
        ax.set_ylabel("Feature ID number")
        ax.set_zlabel("Frequency")
    return fig
|
|
60 |
|
|
|
61 |
|
|
|
62 |
def plot_reconstruction_diff(
    diff_array: FloatArray,
    vmin=None,
    vmax=None,
    style: str = DEFAULT_PLOT_STYLE,
    colormap: str = DEFAULT_DIVERGING_PALETTE,
) -> matplotlib.figure.Figure:
    """
    Plot the reconstruction differences as a heatmap.

    Args:
        diff_array:
            Array of differences, drawn as an image whose vertical axis is
            labeled "Sample" and whose horizontal axis is labeled "Feature".
        vmin:
            Lower limit of the color scale. If None, the minimum of
            `diff_array` is used.
        vmax:
            Upper limit of the color scale. If None, the maximum of
            `diff_array` is used.
        style:
            Name of style to apply to the plot.
        colormap:
            Name of colormap to apply to the heatmap and its colorbar.

    Returns:
        Figure
    """
    with style_settings(style):
        # The two limits default independently, so a caller may fix only one
        # of them (or neither).
        if vmin is None:
            vmin = np.min(diff_array)
        if vmax is None:
            vmax = np.max(diff_array)
        fig = plt.figure(layout="constrained", figsize=(7, 7))
        plt.imshow(diff_array, cmap=colormap, vmin=vmin, vmax=vmax)
        plt.xlabel("Feature")
        plt.ylabel("Sample")
        plt.colorbar()

    return fig
|
|
84 |
|
|
|
85 |
|
|
|
86 |
def plot_feature_association_graph(
    association_df, output_path, layout="circular", style: str = DEFAULT_PLOT_STYLE
) -> matplotlib.figure.Figure:
    """
    Plot and save a graph where each node corresponds to a feature and the edges
    represent the associations between features. Edge color encodes the weight
    of said association, not the association's effect size.

    The weight is derived from the first of these columns found in
    `association_df`: `p_value` (weight is `1 - p_value`), `proba` or
    `ks_distance`. If none is present, the dataframe must already provide a
    `weight` column.

    Nodes are colored by dataset, using one random color per dataset. Nodes
    that never appear in the `feature_b_name` column are drawn in white.

    Args:
        association_df: pandas dataframe containing the following columns:
            - feature_a_name: source node
            - feature_b_name: target node
            - feature_b_dataset: dataset of the target node
            - p_value, proba or ks_distance: used to compute the edge weight
            The dataframe passed by the caller is not modified.
        output_path: Path object of the directory where the picture will be
            stored.
        layout: Graph layout, either "circular" or "spring".
        style: Name of style to apply to the plot.

    Returns:
        Figure, which is also saved as `Feature_association_graph_{layout}.png`
        inside `output_path`.

    Raises:
        ValueError: If `layout` is neither "circular" nor "spring".
    """
    # Validate before any figure or graph is created.
    if layout not in ("circular", "spring"):
        raise ValueError(
            "Graph layout (layout argument) must be either 'circular' or 'spring'."
        )

    # `assign` returns a new dataframe, leaving the caller's one untouched.
    if "p_value" in association_df.columns:
        association_df = association_df.assign(weight=1 - association_df["p_value"])
    elif "proba" in association_df.columns:
        association_df = association_df.assign(weight=association_df["proba"])
    elif "ks_distance" in association_df.columns:
        association_df = association_df.assign(weight=association_df["ks_distance"])

    with style_settings(style):
        fig = plt.figure(figsize=(45, 45))
        G = nx.from_pandas_edgelist(
            association_df,
            source="feature_a_name",
            target="feature_b_name",
            edge_attr="weight",
        )

        nodes = list(G.nodes)

        datasets = association_df["feature_b_dataset"].unique()
        color_map = {
            dataset: (np.random.uniform(), np.random.uniform(), np.random.uniform())
            for dataset in datasets
        }
        node_dataset_map = dict(
            zip(association_df["feature_b_name"], association_df["feature_b_dataset"])
        )

        if layout == "spring":
            pos = nx.spring_layout(G)
            with_labels = True
        else:
            pos = nx.circular_layout(G)
            # Labels are written by hand so that each one can be rotated
            # according to the position of its node around the circle.
            for i, node in enumerate(nodes):
                plt.text(
                    pos[node][0],
                    pos[node][1],
                    node,
                    rotation=(i / len(nodes)) * 360,
                    fontsize=10,
                    horizontalalignment="center",
                    verticalalignment="center",
                )
            with_labels = False

        node_colors = [
            color_map[node_dataset_map[feature]]
            if feature in node_dataset_map
            else "white"
            for feature in G.nodes
        ]

        nx.draw(
            G,
            pos=pos,
            with_labels=with_labels,
            node_size=2000,
            node_color=node_colors,
            edge_color=list(nx.get_edge_attributes(G, "weight").values()),
            font_color="black",
            font_size=10,
            edge_cmap=matplotlib.colormaps["Purples"],
            connectionstyle="arc3, rad=1",
        )

        plt.tight_layout()
        fig.savefig(
            output_path / f"Feature_association_graph_{layout}.png", format="png"
        )
    return fig
|
|
187 |
|
|
|
188 |
|
|
|
189 |
def plot_feature_mean_median(
    array: FloatArray, axis=0, style: str = DEFAULT_PLOT_STYLE
) -> matplotlib.figure.Figure:
    """
    Plot the mean, median, max and min values of an array, each one computed
    along `axis` and drawn as a series of dots against the position index.

    Args:
        array: Array whose summary statistics are plotted.
        axis: Axis along which the statistics are computed.
        style: Name of style to apply to the plot.

    Returns:
        Figure
    """
    # (reducer, matplotlib format string, legend label), in drawing order.
    series = (
        (np.mean, "bo", "mean"),
        (np.median, "ro", "median"),
        (np.max, "go", "max"),
        (np.min, "yo", "min"),
    )
    with style_settings(style):
        fig = plt.figure(figsize=(15, 3))
        summaries = [
            (reducer(array, axis=axis), fmt, label) for reducer, fmt, label in series
        ]
        for values, fmt, label in summaries:
            plt.plot(np.arange(len(values)), values, fmt, label=label)
        plt.legend()
        plt.xlabel("feature")
        plt.ylabel("mean/median/min/max")

    return fig
|
|
211 |
|
|
|
212 |
|
|
|
213 |
def plot_reconstruction_movement(
    baseline_recon: FloatArray,
    perturb_recon: FloatArray,
    k: int,
    style: str = DEFAULT_PLOT_STYLE,
) -> matplotlib.figure.Figure:
    """
    Plot, for each sample, the change in value from the unperturbed reconstruction to
    the perturbed reconstruction. Blue lines are left/negative shifts,
    red lines are right/positive shifts.

    Each sample is drawn as a horizontal arrow that starts at the baseline
    value of feature `k` and ends at its perturbed value.

    Args:
        baseline_recon: baseline reconstruction array with s samples
                        and k features (s,k).
        perturb_recon: perturbed reconstruction array, indexed in the same
                       way as `baseline_recon`.
        k: feature index. The shift (movement) of this feature's reconstruction
           will be plotted for all samples s.
        style: Name of style to apply to the plot.

    Returns:
        Figure
    """
    with style_settings(style):
        # Feature changes
        fig = plt.figure(figsize=(25, 25))
        for s in range(np.shape(baseline_recon)[0]):
            start = baseline_recon[s, k]
            end = perturb_recon[s, k]
            plt.arrow(
                start,
                s / 100,  # vertical position; see the "(e2)" in the axis label
                end - start,  # `arrow` expects a displacement, not an end point
                0,
                length_includes_head=True,
                color="r" if start < end else "b",
            )
        plt.ylabel("Sample (e2)", size=40)
        plt.xlabel("Feature_value", size=40)
    return fig
|
|
246 |
|
|
|
247 |
|
|
|
248 |
def plot_cumulative_distributions(
    edges: FloatArray,
    hist_base: FloatArray,
    hist_pert: FloatArray,
    title: str,
    style: str = DEFAULT_PLOT_STYLE,
) -> matplotlib.figure.Figure:
    """
    Plot the cumulative sums of the baseline and perturbed histograms against
    the centers of their bins. This is useful to visually assess the
    magnitude of the shift corresponding to the KS score.

    Args:
        edges: Bin edges shared by both histograms.
        hist_base: Histogram of the baseline reconstruction (blue curve).
        hist_pert: Histogram of the perturbed reconstruction (red curve).
        title: Title of the plot; ".png" is appended to it.
        style: Name of style to apply to the plot.

    Returns:
        Figure
    """
    with style_settings(style):
        fig = plt.figure(figsize=(7, 7))

        # Both curves are drawn at the midpoints of consecutive edges.
        bin_centers = (edges[:-1] + edges[1:]) / 2
        curves = (
            (hist_base, "blue", "baseline"),
            (hist_pert, "red", "Perturbed"),
        )
        for hist, color, label in curves:
            plt.plot(bin_centers, np.cumsum(hist), color=color, label=label, alpha=0.5)

        plt.title(f"{title}.png")
        plt.xlabel("Feature value")
        plt.ylabel("Cumulative distribution")
        plt.legend()

    return fig
|
|
286 |
|
|
|
287 |
|
|
|
288 |
def plot_correlations(
    x: FloatArray,
    y: FloatArray,
    x_pol: FloatArray,
    y_pol: FloatArray,
    a2: float,
    a1: float,
    a: float,
    k: int,
    style: str = DEFAULT_PLOT_STYLE,
) -> matplotlib.figure.Figure:
    """
    Plot y vs x and the corresponding polynomial fit.

    Three series are drawn: the data points (red dots), the fitted polynomial
    (blue line, labeled with its coefficients) and the identity line
    `y = x` evaluated at `x_pol` (black line).

    Args:
        x: x coordinates of the data points.
        y: y coordinates of the data points.
        x_pol: x coordinates at which the polynomial was evaluated.
        y_pol: Values of the polynomial at `x_pol`.
        a2: Coefficient of the quadratic term, shown in the legend.
        a1: Coefficient of the linear term, shown in the legend.
        a: Constant term, shown in the legend.
        k: Feature index, used in the axis labels.
        style: Name of style to apply to the plot.

    Returns:
        Figure
    """
    with style_settings(style):
        # Plot correlations
        fig = plt.figure(figsize=(7, 7))
        plt.plot(x, y, marker=".", lw=0, markersize=1, color="red")
        plt.plot(
            x_pol,
            y_pol,
            color="blue",
            # The explicit sign keeps the expression readable when the linear
            # or constant term is positive, e.g. "1.00x^2 +2.00x -3.00".
            label=f"{a2:.2f}x^2 {a1:+.2f}x {a:+.2f}",
            lw=1,
        )
        plt.plot(x_pol, x_pol, lw=1, color="k")
        plt.xlabel(f"Feature {k} baseline values ")
        plt.ylabel(f"Feature {k} baseline value reconstruction")
        plt.legend()

    return fig
|
|
319 |
|
|
|
320 |
|
|
|
321 |
def get_2nd_order_polynomial(x_array, y_array, n_points=100):
    """
    Fit a 2nd order polynomial to a set of x and y values and evaluate it on
    an evenly spaced grid covering the range of the x values.

    Args:
        x_array: x coordinates of the data.
        y_array: y coordinates of the data.
        n_points: Number of points at which the polynomial is evaluated.

    Returns:
        x_pol: x coordinates for the polynomial function evaluation.
        y_pol: y coordinates for the polynomial function evaluation.
        (a2, a1, a): coefficients of the quadratic, linear and constant terms.
    """
    quadratic, linear, constant = np.polyfit(x_array, y_array, deg=2)

    lower, upper = np.min(x_array), np.max(x_array)
    x_pol = np.linspace(lower, upper, n_points)
    y_pol = quadratic * x_pol**2 + linear * x_pol + constant

    coefficients = (quadratic, linear, constant)
    return x_pol, y_pol, coefficients