a b/scvae/analyses/figures/scatter.py
1
# ======================================================================== #
2
#
3
# Copyright (c) 2017 - 2020 scVAE authors
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#    http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
#
17
# ======================================================================== #
18
19
import numpy
20
import scipy
21
import seaborn
22
from matplotlib import pyplot
23
24
from scvae.analyses.figures import saving, style
25
from scvae.analyses.figures.utilities import _covariance_matrix_as_ellipse
26
from scvae.utilities import normalise_string, capitalise_string
27
28
29
def plot_values(values, colour_coding=None, colouring_data_set=None,
                centroids=None, sampled_values=None, class_name=None,
                feature_index=None, figure_labels=None, example_tag=None,
                name="scatter"):
    """Scatter-plot the first two columns of ``values``.

    Examples are shuffled with a fixed seed (117) before plotting so that
    the original order of the examples does not decide which points are
    drawn on top.

    Arguments:
        values: 2-D array or SciPy sparse matrix; only its first two
            columns are plotted.
        colour_coding: How to colour the points, or ``None`` for plain
            black points. After normalisation, recognised codings are
            ``"count_sum"``, ``"feature"``, ``"batches"``, and any coding
            containing ``"labels"``, ``"ids"`` or ``"class"``.
        colouring_data_set: Data set supplying labels, count sums,
            feature values, etc. for the colour coding. Required whenever
            ``colour_coding`` is given.
        centroids: Optional mapping whose ``"prior"`` entry (if truthy)
            holds ``"probabilities"``, ``"means"`` and
            ``"covariance_matrices"``; centroids are drawn only when there
            are more than one.
        sampled_values: Optional values whose first two columns are drawn
            as a hexbin density beneath the points. When given, points are
            half-transparent and ``-samples`` is added to the figure name.
        class_name: Class to highlight when the colour coding contains
            ``"class"`` (and not ``"labels"`` or ``"ids"``).
        feature_index: Index of the feature used by the ``"feature"``
            colour coding.
        figure_labels: Optional mapping with ``"title"``, ``"x label"`` and
            ``"y label"`` entries.
        example_tag: Not used by this function; kept for interface
            compatibility.
        name: Base name of the figure.

    Returns:
        Tuple of the Matplotlib figure and its figure name.

    Raises:
        ValueError: If the colouring data set is missing, the colour coding
            is unknown, or the feature index is missing or out of range.
    """

    figure_name = name

    if figure_labels:
        title = figure_labels.get("title")
        x_label = figure_labels.get("x label")
        y_label = figure_labels.get("y label")
    else:
        title = "none"
        x_label = "$x$"
        y_label = "$y$"

    if not title:
        title = "none"

    figure_name += "-" + normalise_string(title)

    if colour_coding:
        colour_coding = normalise_string(colour_coding)
        # Checked before the data set is dereferenced below, so a missing
        # data set yields this ValueError rather than an AttributeError.
        if colouring_data_set is None:
            raise ValueError("Colouring data set not given.")
        figure_name += "-" + colour_coding
        if "predicted" in colour_coding:
            if colouring_data_set.prediction_specifications:
                figure_name += "-" + (
                    colouring_data_set.prediction_specifications.name)
            else:
                figure_name += "-unknown_prediction_method"

    # Validate the colour coding before the global plot style is changed
    # and a figure is created, so invalid input leaves no side effects.
    if colour_coding == "feature":
        if feature_index is None:
            raise ValueError("Feature number not given.")
        # Valid indices are 0, ..., number_of_features - 1.
        if feature_index >= colouring_data_set.number_of_features:
            raise ValueError("Feature number higher than number of features.")
    elif not (_is_class_colour_coding(colour_coding)
              or colour_coding == "count_sum"
              or colour_coding is None):
        raise ValueError(
            "Colour coding `{}` not found.".format(colour_coding))

    if sampled_values is not None:
        figure_name += "-samples"

    values = _dense_first_two_columns(values)

    # Randomise examples in values to remove any prior order
    n_examples, __ = values.shape
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    values = values[shuffled_indices]

    # Make points translucent when sample densities are drawn beneath them
    alpha = 0.5 if sampled_values is not None else 1

    # Adjust marker size based on number of examples
    style._adjust_marker_size_for_scatter_plots(n_examples)

    try:
        figure = pyplot.figure()
        axis = figure.add_subplot(1, 1, 1)
        seaborn.despine()

        axis.set_xlabel(x_label)
        axis.set_ylabel(y_label)

        colour_map = seaborn.dark_palette(
            style.STANDARD_PALETTE[0], as_cmap=True)

        if _is_class_colour_coding(colour_coding):
            figure_name += _plot_class_colour_coding(
                axis, values, shuffled_indices, colour_coding,
                colouring_data_set, class_name, alpha)

        elif colour_coding == "count_sum":
            count_sums = colouring_data_set.count_sum[
                shuffled_indices].flatten()
            scatter_plot = axis.scatter(
                values[:, 0],
                values[:, 1],
                c=count_sums,
                cmap=colour_map,
                alpha=alpha
            )
            colour_bar = figure.colorbar(scatter_plot)
            colour_bar.outline.set_linewidth(0)
            colour_bar.set_label("Total number of {}s per {}".format(
                colouring_data_set.terms["item"],
                colouring_data_set.terms["example"]
            ))

        elif colour_coding == "feature":
            feature_name = colouring_data_set.feature_names[feature_index]
            figure_name += "-{}".format(normalise_string(feature_name))

            feature_values = colouring_data_set.values[
                shuffled_indices, feature_index]
            if scipy.sparse.issparse(feature_values):
                feature_values = feature_values.toarray()
            feature_values = feature_values.squeeze()

            scatter_plot = axis.scatter(
                values[:, 0],
                values[:, 1],
                c=feature_values,
                cmap=colour_map,
                alpha=alpha
            )
            colour_bar = figure.colorbar(scatter_plot)
            colour_bar.outline.set_linewidth(0)
            colour_bar.set_label(feature_name)

        else:
            # No colour coding (anything else was rejected above)
            axis.scatter(
                values[:, 0], values[:, 1], c="k",
                alpha=alpha, edgecolors="none")

        if centroids:
            _plot_centroids(axis, centroids)

        if sampled_values is not None:
            _plot_sampled_values(axis, sampled_values)

    finally:
        # Reset marker size, also when plotting fails
        style.reset_plot_look()

    return figure, figure_name


def _is_class_colour_coding(colour_coding):
    """Return whether a (normalised) colour coding colours points by class."""
    return bool(colour_coding) and (
        "labels" in colour_coding
        or "ids" in colour_coding
        or "class" in colour_coding
        or colour_coding == "batches"
    )


def _dense_first_two_columns(values):
    """Return the first two columns of ``values`` as a dense array."""
    values = values[:, :2]
    if scipy.sparse.issparse(values):
        values = values.toarray()
    return values


def _sorted_legend_entries(labels, handles, label_sorter=None):
    """Return legend labels and handles sorted by label.

    Labels are ordered by ``label_sorter`` when given, otherwise by the
    labels themselves (handles are never compared).
    """
    if label_sorter:
        def sort_key(entry):
            return label_sorter(entry[0])
    else:
        def sort_key(entry):
            return entry[0]

    entries = sorted(zip(labels, handles), key=sort_key)
    if not entries:
        return (), ()
    sorted_labels, sorted_handles = zip(*entries)
    return sorted_labels, sorted_handles


def _class_colouring(colour_coding, colouring_data_set):
    """Return labels, class count, class palette and label sorter.

    The attributes are read from ``colouring_data_set`` according to the
    class-based ``colour_coding``. A palette is generated from the class
    names when the data set provides none.
    """
    if colour_coding == "predicted_cluster_ids":
        labels = colouring_data_set.predicted_cluster_ids
        class_names = numpy.unique(labels).tolist()
        number_of_classes = len(class_names)
        class_palette = None
        label_sorter = None
    elif colour_coding == "predicted_labels":
        labels = colouring_data_set.predicted_labels
        class_names = colouring_data_set.predicted_class_names
        number_of_classes = colouring_data_set.number_of_predicted_classes
        class_palette = colouring_data_set.predicted_class_palette
        label_sorter = colouring_data_set.predicted_label_sorter
    elif colour_coding == "predicted_superset_labels":
        labels = colouring_data_set.predicted_superset_labels
        class_names = colouring_data_set.predicted_superset_class_names
        number_of_classes = (
            colouring_data_set.number_of_predicted_superset_classes)
        class_palette = colouring_data_set.predicted_superset_class_palette
        label_sorter = colouring_data_set.predicted_superset_label_sorter
    elif "superset" in colour_coding:
        labels = colouring_data_set.superset_labels
        class_names = colouring_data_set.superset_class_names
        number_of_classes = colouring_data_set.number_of_superset_classes
        class_palette = colouring_data_set.superset_class_palette
        label_sorter = colouring_data_set.superset_label_sorter
    elif colour_coding == "batches":
        labels = colouring_data_set.batch_indices.flatten()
        class_names = colouring_data_set.batch_names
        number_of_classes = colouring_data_set.number_of_batches
        class_palette = None
        label_sorter = None
    else:
        labels = colouring_data_set.labels
        class_names = colouring_data_set.class_names
        number_of_classes = colouring_data_set.number_of_classes
        class_palette = colouring_data_set.class_palette
        label_sorter = colouring_data_set.label_sorter

    if not class_palette:
        index_palette = style.lighter_palette(number_of_classes)
        class_palette = {
            class_name: index_palette[i] for i, class_name in
            enumerate(sorted(class_names, key=label_sorter))
        }

    return labels, number_of_classes, class_palette, label_sorter


def _plot_class_colour_coding(axis, values, shuffled_indices, colour_coding,
                              colouring_data_set, class_name, alpha):
    """Plot points coloured by class; return the figure-name suffix."""
    labels, number_of_classes, class_palette, label_sorter = (
        _class_colouring(colour_coding, colouring_data_set))

    # Examples are shuffled, so should their labels be
    labels = labels[shuffled_indices]

    if ("labels" in colour_coding or "ids" in colour_coding
            or colour_coding == "batches"):
        _plot_all_classes(
            axis, values, labels, class_palette, label_sorter,
            number_of_classes, alpha)
        return ""

    # Remaining class-based codings contain "class": highlight one class
    _plot_highlighted_class(
        axis, values, labels, class_name, class_palette, label_sorter, alpha)
    return "-" + normalise_string(str(class_name))


def _plot_all_classes(axis, values, labels, class_palette, label_sorter,
                      number_of_classes, alpha):
    """Plot every example in its class colour with one legend entry per class."""
    colours = []
    seen_labels = set()

    for i, label in enumerate(labels):
        colour = class_palette[label]
        colours.append(colour)

        # Plot one example for each class to add labels
        if label not in seen_labels:
            seen_labels.add(label)
            axis.scatter(
                values[i, 0],
                values[i, 1],
                color=colour,
                label=label,
                alpha=alpha
            )

    axis.scatter(values[:, 0], values[:, 1], c=colours, alpha=alpha)

    class_handles, class_labels = axis.get_legend_handles_labels()
    if not class_labels:
        return

    class_labels, class_handles = _sorted_legend_entries(
        class_labels, class_handles, label_sorter)

    class_label_maximum_width = max(map(len, class_labels))
    if class_label_maximum_width <= 5 and number_of_classes <= 20:
        axis.legend(
            class_handles, class_labels,
            loc="best"
        )
    else:
        class_label_columns = 2 if number_of_classes <= 20 else 3
        # Wide or many labels: place the legend above the axes
        axis.legend(
            class_handles,
            class_labels,
            bbox_to_anchor=(-0.1, 1.05, 1.1, 0.95),
            loc="lower left",
            ncol=class_label_columns,
            mode="expand",
            borderaxespad=0.,
        )


def _plot_highlighted_class(axis, values, labels, class_name, class_palette,
                            label_sorter, alpha):
    """Plot ``class_name`` in its colour on top of all other examples."""
    highlighted_label = str(class_name)
    colours = []
    ordered_indices_set = {
        highlighted_label: [],
        "Remaining": []
    }

    for i, label in enumerate(labels):
        if label == class_name:
            colour = class_palette[label]
            ordered_indices_set[highlighted_label].append(i)
        else:
            colour = style.NEUTRAL_COLOUR
            ordered_indices_set["Remaining"].append(i)
        colours.append(colour)

    colours = numpy.array(colours)

    # Draw the remaining examples at the bottom (z-order 0)
    z_order_index = 1
    for label, ordered_indices in sorted(ordered_indices_set.items()):
        if label == "Remaining":
            z_order = 0
        else:
            z_order = z_order_index
            z_order_index += 1
        ordered_values = values[ordered_indices]
        ordered_colours = colours[ordered_indices]
        axis.scatter(
            ordered_values[:, 0],
            ordered_values[:, 1],
            c=ordered_colours,
            label=label,
            alpha=alpha,
            zorder=z_order
        )

    handles, legend_labels = axis.get_legend_handles_labels()
    legend_labels, handles = _sorted_legend_entries(
        legend_labels, handles, label_sorter)
    axis.legend(
        handles,
        legend_labels,
        bbox_to_anchor=(-0.1, 1.05, 1.1, 0.95),
        loc="lower left",
        ncol=2,
        mode="expand",
        borderaxespad=0.
    )


def _plot_centroids(axis, centroids):
    """Draw prior centroid means as crosses with covariance ellipses.

    Nothing is drawn unless ``centroids["prior"]`` describes more than one
    centroid.
    """
    prior_centroids = centroids["prior"]

    if prior_centroids:
        n_centroids = prior_centroids["probabilities"].shape[0]
    else:
        n_centroids = 0

    if n_centroids <= 1:
        return

    centroids_palette = style.darker_palette(n_centroids)
    means = prior_centroids["means"]
    covariance_matrices = prior_centroids["covariance_matrices"]

    for k in range(n_centroids):
        axis.scatter(
            means[k, 0],
            means[k, 1],
            s=60,
            marker="x",
            color="black",
            linewidth=3
        )
        axis.scatter(
            means[k, 0],
            means[k, 1],
            marker="x",
            facecolor=centroids_palette[k],
            edgecolors="black"
        )
        ellipse_fill, ellipse_edge = _covariance_matrix_as_ellipse(
            covariance_matrices[k],
            means[k],
            colour=centroids_palette[k]
        )
        axis.add_patch(ellipse_edge)
        axis.add_patch(ellipse_fill)


def _plot_sampled_values(axis, sampled_values):
    """Draw sampled values as a hexbin density beneath the existing points."""
    sampled_values = _dense_first_two_columns(sampled_values)

    sample_colour_map = seaborn.blend_palette(
        ("white", "purple"), as_cmap=True)

    # Keep the limits set by the plotted points so that samples outside
    # them do not rescale the axes
    x_limits = axis.get_xlim()
    y_limits = axis.get_ylim()

    axis.hexbin(
        sampled_values[:, 0], sampled_values[:, 1],
        gridsize=75,
        cmap=sample_colour_map,
        linewidths=0., edgecolors="none",
        zorder=-100
    )

    axis.set_xlim(x_limits)
    axis.set_ylim(y_limits)
365
366
367
def plot_variable_correlations(values, variable_names=None,
                               colouring_data_set=None,
                               name="variable_correlations"):
    """Plot pairwise scatter plots of all columns of ``values`` in a grid.

    The subplot in row ``i`` and column ``j`` shows variable ``j`` along the
    x axis and variable ``i`` along the y axis. This matches the x labels
    on the bottom row and the y labels on the first column.

    Arguments:
        values: 2-D array with examples as rows and variables as columns.
        variable_names: Axis labels for the variables. Defaults to
            ``"Variable 1"``, ``"Variable 2"``, ... when not given.
        colouring_data_set: Optional data set whose labels colour the
            points by class; otherwise all points use the neutral colour.
        name: Base name of the figure.

    Returns:
        Tuple of the Matplotlib figure and its figure name.
    """

    figure_name = saving.build_figure_name(name)
    n_examples, n_features = values.shape

    if variable_names is None:
        variable_names = [
            "Variable {}".format(i + 1) for i in range(n_features)]

    # Randomise examples to remove any prior order
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    values = values[shuffled_indices]

    if colouring_data_set:
        labels = colouring_data_set.labels
        class_names = colouring_data_set.class_names
        number_of_classes = colouring_data_set.number_of_classes
        class_palette = colouring_data_set.class_palette
        label_sorter = colouring_data_set.label_sorter

        if not class_palette:
            index_palette = style.lighter_palette(number_of_classes)
            class_palette = {
                class_name: index_palette[i] for i, class_name in
                enumerate(sorted(class_names, key=label_sorter))
            }

        # Examples are shuffled, so should their labels be
        labels = labels[shuffled_indices]
        colours = [class_palette[label] for label in labels]

    else:
        colours = style.NEUTRAL_COLOUR

    # squeeze=False keeps `axes` two-dimensional, also for a single variable
    figure, axes = pyplot.subplots(
        nrows=n_features,
        ncols=n_features,
        figsize=[1.5 * n_features] * 2,
        squeeze=False
    )

    for i in range(n_features):
        for j in range(n_features):
            axis = axes[i, j]
            # Column j (x axis) is variable j; row i (y axis) is variable i
            axis.scatter(values[:, j], values[:, i], c=colours, s=1)

            axis.set_xticks([])
            axis.set_yticks([])

            if i == n_features - 1:
                axis.set_xlabel(variable_names[j])

        axes[i, 0].set_ylabel(variable_names[i])

    return figure, figure_name
422
423
424
def plot_variable_label_correlations(variable_vector, variable_name,
                                     colouring_data_set,
                                     name="variable_label_correlations"):
    """Scatter-plot a variable against the class of each example.

    Each example is drawn at its variable value along the x axis and at its
    class ID along the y axis, coloured by class. The y ticks are the class
    IDs present, labelled with their class names.

    Arguments:
        variable_vector: Values of the variable, one per example (indexed
            along the first axis).
        variable_name: Label for the x axis.
        colouring_data_set: Data set providing labels, class names, the
            name/ID mappings, and optionally a class palette.
        name: Base name of the figure.

    Returns:
        Tuple of the Matplotlib figure and its figure name.
    """

    figure_name = saving.build_figure_name(name)
    n_examples = variable_vector.shape[0]

    class_name_to_class_id = colouring_data_set.class_name_to_class_id
    class_id_to_class_name = colouring_data_set.class_id_to_class_name

    labels = colouring_data_set.labels
    class_names = colouring_data_set.class_names
    number_of_classes = colouring_data_set.number_of_classes
    class_palette = colouring_data_set.class_palette
    label_sorter = colouring_data_set.label_sorter

    if not class_palette:
        index_palette = style.lighter_palette(number_of_classes)
        class_palette = {
            class_name: index_palette[i] for i, class_name in
            enumerate(sorted(class_names, key=label_sorter))
        }

    # Randomise examples to remove any prior order
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    variable_vector = variable_vector[shuffled_indices]

    labels = labels[shuffled_indices]

    # Plain comprehensions instead of `numpy.vectorize`, which raises on
    # size-0 inputs when no output type is specified.
    label_ids = numpy.expand_dims(
        numpy.array([class_name_to_class_id[label] for label in labels]),
        axis=-1
    )
    colours = [class_palette[label] for label in labels]

    unique_class_ids = numpy.unique(label_ids)
    unique_class_names = [
        class_id_to_class_name[class_id]
        for class_id in unique_class_ids.tolist()
    ]

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.scatter(variable_vector, label_ids, c=colours, s=1)

    axis.set_yticks(unique_class_ids)
    axis.set_yticklabels(unique_class_names)

    axis.set_xlabel(variable_name)
    axis.set_ylabel(capitalise_string(colouring_data_set.terms["class"]))

    return figure, figure_name