|
a |
|
b/scvae/analyses/figures/scatter.py |
|
|
1 |
# ======================================================================== # |
|
|
2 |
# |
|
|
3 |
# Copyright (c) 2017 - 2020 scVAE authors |
|
|
4 |
# |
|
|
5 |
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
6 |
# you may not use this file except in compliance with the License. |
|
|
7 |
# You may obtain a copy of the License at |
|
|
8 |
# |
|
|
9 |
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
10 |
# |
|
|
11 |
# Unless required by applicable law or agreed to in writing, software |
|
|
12 |
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
14 |
# See the License for the specific language governing permissions and |
|
|
15 |
# limitations under the License. |
|
|
16 |
# |
|
|
17 |
# ======================================================================== # |
|
|
18 |
|
|
|
19 |
import numpy |
|
|
20 |
import scipy |
|
|
21 |
import seaborn |
|
|
22 |
from matplotlib import pyplot |
|
|
23 |
|
|
|
24 |
from scvae.analyses.figures import saving, style |
|
|
25 |
from scvae.analyses.figures.utilities import _covariance_matrix_as_ellipse |
|
|
26 |
from scvae.utilities import normalise_string, capitalise_string |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
def plot_values(values, colour_coding=None, colouring_data_set=None,
                centroids=None, sampled_values=None, class_name=None,
                feature_index=None, figure_labels=None, example_tag=None,
                name="scatter"):
    """Scatter-plot the first two dimensions of ``values``.

    Examples are shuffled with a fixed seed (117) before plotting so that
    any prior ordering of the data does not decide which points are drawn
    on top.

    Arguments:
        values: Two-dimensional NumPy array or SciPy sparse matrix/array;
            only the first two columns are plotted.
        colour_coding: Optional colour coding (normalised before use).
            Codings containing ``"labels"``, ``"ids"`` or ``"class"`` and
            ``"batches"`` colour by label; ``"count_sum"`` and
            ``"feature"`` colour by a continuous value. ``None`` (or an
            empty string) plots all points in black.
        colouring_data_set: Data set supplying labels, count sums, or
            feature values. Required whenever ``colour_coding`` is given.
        centroids: Optional mapping whose ``"prior"`` entry, when set,
            holds ``"probabilities"``, ``"means"`` and
            ``"covariance_matrices"``; components are drawn only when
            there is more than one.
        sampled_values: Optional array of samples drawn as a hexagonal-bin
            density behind the points (axis limits are preserved).
        class_name: Class to highlight for ``"class"`` colour codings.
        feature_index: Index of the feature used for ``"feature"``
            colouring.
        figure_labels: Optional mapping with ``"title"``, ``"x label"``
            and ``"y label"`` entries.
        example_tag: Not used by this function.
        name: Base name of the figure.

    Returns:
        Tuple ``(figure, figure_name)``.

    Raises:
        ValueError: If the colouring data set is missing for a colour
            coding, the feature index is missing or out of range, or the
            colour coding is not supported.
    """

    figure_name = name

    if figure_labels:
        title = figure_labels.get("title")
        x_label = figure_labels.get("x label")
        y_label = figure_labels.get("y label")
    else:
        title = "none"
        x_label = "$x$"
        y_label = "$y$"

    if not title:
        title = "none"

    figure_name += "-" + normalise_string(title)

    if colour_coding:
        colour_coding = normalise_string(colour_coding)
        # Validate before the data set is dereferenced below.
        if colouring_data_set is None:
            raise ValueError("Colouring data set not given.")
        figure_name += "-" + colour_coding
        if "predicted" in colour_coding:
            if colouring_data_set.prediction_specifications:
                figure_name += "-" + (
                    colouring_data_set.prediction_specifications.name)
            else:
                figure_name += "-unknown_prediction_method"

    if sampled_values is not None:
        figure_name += "-samples"

    values = _dense_first_two_columns(values)

    # Randomise examples in values to remove any prior order
    n_examples = values.shape[0]
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    values = values[shuffled_indices]

    # Points are translucent when a sample density is drawn behind them
    alpha = 0.5 if sampled_values is not None else 1

    # Adjust marker size based on number of examples
    style._adjust_marker_size_for_scatter_plots(n_examples)

    figure = None
    try:
        figure = pyplot.figure()
        axis = figure.add_subplot(1, 1, 1)
        seaborn.despine()

        axis.set_xlabel(x_label)
        axis.set_ylabel(y_label)

        colour_map = seaborn.dark_palette(
            style.STANDARD_PALETTE[0], as_cmap=True)

        if colour_coding and (
                "labels" in colour_coding
                or "ids" in colour_coding
                or "class" in colour_coding
                or colour_coding == "batches"):

            labels, number_of_classes, class_palette, label_sorter = (
                _class_colouring(colour_coding, colouring_data_set))

            # Examples are shuffled, so should their labels be
            labels = labels[shuffled_indices]

            if ("labels" in colour_coding or "ids" in colour_coding
                    or colour_coding == "batches"):
                _plot_coloured_by_label(
                    axis, values, labels, class_palette, label_sorter,
                    number_of_classes, alpha)
            else:
                # Only "class" codings remain: highlight a single class
                figure_name += "-" + normalise_string(str(class_name))
                _plot_highlighted_class(
                    axis, values, labels, class_name, class_palette,
                    label_sorter, alpha)

        elif colour_coding == "count_sum":
            count_sums = colouring_data_set.count_sum[
                shuffled_indices].flatten()
            scatter_plot = axis.scatter(
                values[:, 0],
                values[:, 1],
                c=count_sums,
                cmap=colour_map,
                alpha=alpha
            )
            colour_bar = figure.colorbar(scatter_plot)
            colour_bar.outline.set_linewidth(0)
            colour_bar.set_label("Total number of {}s per {}".format(
                colouring_data_set.terms["item"],
                colouring_data_set.terms["example"]
            ))

        elif colour_coding == "feature":
            if feature_index is None:
                raise ValueError("Feature number not given.")
            # Valid indices run from 0 to number_of_features - 1
            if feature_index >= colouring_data_set.number_of_features:
                raise ValueError(
                    "Feature number higher than number of features.")

            feature_name = colouring_data_set.feature_names[feature_index]
            figure_name += "-{}".format(normalise_string(feature_name))

            feature_values = colouring_data_set.values[
                shuffled_indices, feature_index]
            if scipy.sparse.issparse(feature_values):
                feature_values = feature_values.toarray()
            # One value per example (sparse matrices index to a column)
            feature_values = numpy.ravel(feature_values)

            scatter_plot = axis.scatter(
                values[:, 0],
                values[:, 1],
                c=feature_values,
                cmap=colour_map,
                alpha=alpha
            )
            colour_bar = figure.colorbar(scatter_plot)
            colour_bar.outline.set_linewidth(0)
            colour_bar.set_label(feature_name)

        elif not colour_coding:
            axis.scatter(
                values[:, 0], values[:, 1], c="k",
                alpha=alpha, edgecolors="none")

        else:
            raise ValueError(
                "Colour coding `{}` not found.".format(colour_coding))

        if centroids:
            _plot_prior_centroids(axis, centroids["prior"])

        if sampled_values is not None:
            _plot_sample_density(axis, sampled_values)

    except Exception:
        # Do not leave a half-drawn figure registered with pyplot
        if figure is not None:
            pyplot.close(figure)
        raise

    finally:
        # Reset marker size, also when plotting failed
        style.reset_plot_look()

    return figure, figure_name


def _dense_first_two_columns(values):
    """Return the first two columns of ``values`` as a dense NumPy array.

    Slicing happens before densifying, so a sparse input is never
    converted in full.
    """
    values = values[:, :2]
    if scipy.sparse.issparse(values):
        values = values.toarray()
    return numpy.asarray(values)


def _class_colouring(colour_coding, colouring_data_set):
    """Return ``(labels, number_of_classes, class_palette, label_sorter)``
    for a label-based ``colour_coding``.

    When the data set provides no class palette, one is built with
    ``style.lighter_palette`` over the class names in sorted order.
    """
    if colour_coding == "predicted_cluster_ids":
        labels = colouring_data_set.predicted_cluster_ids
        class_names = numpy.unique(labels).tolist()
        number_of_classes = len(class_names)
        class_palette = None
        label_sorter = None
    elif colour_coding == "predicted_labels":
        labels = colouring_data_set.predicted_labels
        class_names = colouring_data_set.predicted_class_names
        number_of_classes = colouring_data_set.number_of_predicted_classes
        class_palette = colouring_data_set.predicted_class_palette
        label_sorter = colouring_data_set.predicted_label_sorter
    elif colour_coding == "predicted_superset_labels":
        labels = colouring_data_set.predicted_superset_labels
        class_names = colouring_data_set.predicted_superset_class_names
        number_of_classes = (
            colouring_data_set.number_of_predicted_superset_classes)
        class_palette = colouring_data_set.predicted_superset_class_palette
        label_sorter = colouring_data_set.predicted_superset_label_sorter
    elif "superset" in colour_coding:
        labels = colouring_data_set.superset_labels
        class_names = colouring_data_set.superset_class_names
        number_of_classes = colouring_data_set.number_of_superset_classes
        class_palette = colouring_data_set.superset_class_palette
        label_sorter = colouring_data_set.superset_label_sorter
    elif colour_coding == "batches":
        # NOTE(review): labels are batch indices while a default palette is
        # keyed by batch names; confirm that these coincide.
        labels = colouring_data_set.batch_indices.flatten()
        class_names = colouring_data_set.batch_names
        number_of_classes = colouring_data_set.number_of_batches
        class_palette = None
        label_sorter = None
    else:
        labels = colouring_data_set.labels
        class_names = colouring_data_set.class_names
        number_of_classes = colouring_data_set.number_of_classes
        class_palette = colouring_data_set.class_palette
        label_sorter = colouring_data_set.label_sorter

    if not class_palette:
        index_palette = style.lighter_palette(number_of_classes)
        class_palette = {
            each_name: index_palette[i] for i, each_name in
            enumerate(sorted(class_names, key=label_sorter))
        }

    return labels, number_of_classes, class_palette, label_sorter


def _sorted_legend_entries(legend_labels, handles, label_sorter):
    """Return legend labels and handles sorted by label.

    ``label_sorter`` is applied to each label when given; otherwise
    labels are sorted in their natural order.
    """
    def sort_key(entry):
        return label_sorter(entry[0])

    sorted_entries = sorted(
        zip(legend_labels, handles),
        key=sort_key if label_sorter else None
    )
    sorted_labels, sorted_handles = zip(*sorted_entries)
    return sorted_labels, sorted_handles


def _plot_coloured_by_label(axis, values, labels, class_palette,
                            label_sorter, number_of_classes, alpha):
    """Scatter examples coloured by their label and add a label legend."""
    colours = []
    classes = set()

    for i, label in enumerate(labels):
        colour = class_palette[label]
        colours.append(colour)

        # Plot one example for each class to add labels
        if label not in classes:
            classes.add(label)
            axis.scatter(
                values[i, 0],
                values[i, 1],
                color=colour,
                label=label,
                alpha=alpha
            )

    axis.scatter(values[:, 0], values[:, 1], c=colours, alpha=alpha)

    class_handles, class_labels = axis.get_legend_handles_labels()

    if not class_labels:
        return

    class_labels, class_handles = _sorted_legend_entries(
        class_labels, class_handles, label_sorter)
    class_label_maximum_width = max(map(len, class_labels))

    if class_label_maximum_width <= 5 and number_of_classes <= 20:
        axis.legend(
            class_handles, class_labels,
            loc="best"
        )
    else:
        class_label_columns = 2 if number_of_classes <= 20 else 3
        axis.legend(
            class_handles,
            class_labels,
            bbox_to_anchor=(-0.1, 1.05, 1.1, 0.95),
            loc="lower left",
            ncol=class_label_columns,
            mode="expand",
            borderaxespad=0.,
        )


def _plot_highlighted_class(axis, values, labels, class_name, class_palette,
                            label_sorter, alpha):
    """Scatter examples of ``class_name`` in its class colour above all
    remaining examples in the neutral colour, and add a legend."""
    highlighted_label = str(class_name)
    ordered_indices_set = {
        highlighted_label: [],
        "Remaining": []
    }
    colours = []

    for i, label in enumerate(labels):
        if label == class_name:
            colours.append(class_palette[label])
            ordered_indices_set[highlighted_label].append(i)
        else:
            colours.append(style.NEUTRAL_COLOUR)
            ordered_indices_set["Remaining"].append(i)

    colours = numpy.array(colours)

    # Remaining examples are drawn underneath the highlighted class
    z_order_index = 1
    for label, ordered_indices in sorted(ordered_indices_set.items()):
        if label == "Remaining":
            z_order = 0
        else:
            z_order = z_order_index
            z_order_index += 1
        ordered_values = values[ordered_indices]
        ordered_colours = colours[ordered_indices]
        axis.scatter(
            ordered_values[:, 0],
            ordered_values[:, 1],
            c=ordered_colours,
            label=label,
            alpha=alpha,
            zorder=z_order
        )

    handles, legend_labels = axis.get_legend_handles_labels()
    legend_labels, handles = _sorted_legend_entries(
        legend_labels, handles, label_sorter)
    axis.legend(
        handles,
        legend_labels,
        bbox_to_anchor=(-0.1, 1.05, 1.1, 0.95),
        loc="lower left",
        ncol=2,
        mode="expand",
        borderaxespad=0.
    )


def _plot_prior_centroids(axis, prior_centroids):
    """Draw means and covariance ellipses of the prior components when
    there is more than one component."""
    if prior_centroids:
        n_centroids = prior_centroids["probabilities"].shape[0]
    else:
        n_centroids = 0

    if n_centroids <= 1:
        return

    centroids_palette = style.darker_palette(n_centroids)
    means = prior_centroids["means"]
    covariance_matrices = prior_centroids["covariance_matrices"]

    for k in range(n_centroids):
        # Thick black cross as an outline for the coloured cross on top
        axis.scatter(
            means[k, 0],
            means[k, 1],
            s=60,
            marker="x",
            color="black",
            linewidth=3
        )
        axis.scatter(
            means[k, 0],
            means[k, 1],
            marker="x",
            color=centroids_palette[k]
        )
        ellipse_fill, ellipse_edge = _covariance_matrix_as_ellipse(
            covariance_matrices[k],
            means[k],
            colour=centroids_palette[k]
        )
        axis.add_patch(ellipse_edge)
        axis.add_patch(ellipse_fill)


def _plot_sample_density(axis, sampled_values):
    """Draw sampled values as a hexagonal-bin density behind everything
    else while keeping the current axis limits."""
    sampled_values = _dense_first_two_columns(sampled_values)

    sample_colour_map = seaborn.blend_palette(
        ("white", "purple"), as_cmap=True)

    x_limits = axis.get_xlim()
    y_limits = axis.get_ylim()

    axis.hexbin(
        sampled_values[:, 0], sampled_values[:, 1],
        gridsize=75,
        cmap=sample_colour_map,
        linewidths=0., edgecolors="none",
        zorder=-100
    )

    axis.set_xlim(x_limits)
    axis.set_ylim(y_limits)
|
|
365 |
|
|
|
366 |
|
|
|
367 |
def plot_variable_correlations(values, variable_names=None,
                               colouring_data_set=None,
                               name="variable_correlations"):
    """Plot a matrix of pairwise scatter plots of the variables in ``values``.

    The subplot in row ``i`` and column ``j`` plots variable ``j`` along
    the horizontal axis against variable ``i`` along the vertical axis, so
    the labels on the bottom row and first column match the plotted data.

    Arguments:
        values: Two-dimensional array (examples by variables) or SciPy
            sparse matrix/array.
        variable_names: Optional sequence with one name per variable, used
            as axis labels. When omitted, the axes are left unlabelled.
        colouring_data_set: Optional data set whose labels colour the
            examples; otherwise all examples use ``style.NEUTRAL_COLOUR``.
        name: Base name of the figure.

    Returns:
        Tuple ``(figure, figure_name)``.
    """

    figure_name = saving.build_figure_name(name)

    if scipy.sparse.issparse(values):
        values = values.toarray()
    n_examples, n_features = values.shape

    # Shuffle examples (fixed seed) to remove any prior order
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    values = values[shuffled_indices]

    if colouring_data_set:
        labels = colouring_data_set.labels
        class_names = colouring_data_set.class_names
        number_of_classes = colouring_data_set.number_of_classes
        class_palette = colouring_data_set.class_palette
        label_sorter = colouring_data_set.label_sorter

        if not class_palette:
            index_palette = style.lighter_palette(number_of_classes)
            class_palette = {
                each_name: index_palette[i] for i, each_name in
                enumerate(sorted(class_names, key=label_sorter))
            }

        # Examples are shuffled, so should their labels be
        labels = labels[shuffled_indices]
        colours = [class_palette[label] for label in labels]

    else:
        colours = style.NEUTRAL_COLOUR

    figure, axes = pyplot.subplots(
        nrows=n_features,
        ncols=n_features,
        figsize=[1.5 * n_features] * 2,
        # Always a 2D array of axes, also for a single variable
        squeeze=False
    )

    for i in range(n_features):
        for j in range(n_features):
            axis = axes[i, j]
            # Column j shows variable j horizontally, row i variable i
            axis.scatter(values[:, j], values[:, i], c=colours, s=1)

            axis.set_xticks([])
            axis.set_yticks([])

            if variable_names is not None and i == n_features - 1:
                axis.set_xlabel(variable_names[j])

        if variable_names is not None:
            axes[i, 0].set_ylabel(variable_names[i])

    return figure, figure_name
|
|
422 |
|
|
|
423 |
|
|
|
424 |
def plot_variable_label_correlations(variable_vector, variable_name,
                                     colouring_data_set,
                                     name="variable_label_correlations"):
    """Plot a variable against the class labels of the examples.

    Each example is drawn at (variable value, class ID) in its class
    colour, and the vertical axis is labelled with the names of the
    classes present.

    Arguments:
        variable_vector: One value per example (indexed along axis 0).
        variable_name: Label for the horizontal axis.
        colouring_data_set: Data set providing the labels, class palette,
            label sorter, and the ``class_name_to_class_id`` /
            ``class_id_to_class_name`` mappings.
        name: Base name of the figure.

    Returns:
        Tuple ``(figure, figure_name)``.
    """

    figure_name = saving.build_figure_name(name)
    n_examples = variable_vector.shape[0]

    class_name_to_class_id = colouring_data_set.class_name_to_class_id
    class_id_to_class_name = colouring_data_set.class_id_to_class_name

    labels = colouring_data_set.labels
    class_names = colouring_data_set.class_names
    number_of_classes = colouring_data_set.number_of_classes
    class_palette = colouring_data_set.class_palette
    label_sorter = colouring_data_set.label_sorter

    if not class_palette:
        index_palette = style.lighter_palette(number_of_classes)
        class_palette = {
            each_name: index_palette[i] for i, each_name in
            enumerate(sorted(class_names, key=label_sorter))
        }

    # Shuffle examples (fixed seed) to remove any prior order
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    variable_vector = variable_vector[shuffled_indices]
    labels = labels[shuffled_indices]

    # Plain lookups rather than numpy.vectorize, which fails on empty input
    label_ids = numpy.array(
        [class_name_to_class_id[label] for label in labels])
    colours = [class_palette[label] for label in labels]

    unique_class_ids = numpy.unique(label_ids)
    unique_class_names = [
        class_id_to_class_name[class_id] for class_id in unique_class_ids]

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.scatter(variable_vector, label_ids, c=colours, s=1)

    axis.set_yticks(unique_class_ids)
    axis.set_yticklabels(unique_class_names)

    axis.set_xlabel(variable_name)
    axis.set_ylabel(capitalise_string(colouring_data_set.terms["class"]))

    return figure, figure_name