--- a +++ b/scvae/analyses/figures/scatter.py @@ -0,0 +1,476 @@ +# ======================================================================== # +# +# Copyright (c) 2017 - 2020 scVAE authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ======================================================================== # + +import numpy +import scipy +import seaborn +from matplotlib import pyplot + +from scvae.analyses.figures import saving, style +from scvae.analyses.figures.utilities import _covariance_matrix_as_ellipse +from scvae.utilities import normalise_string, capitalise_string + + +def plot_values(values, colour_coding=None, colouring_data_set=None, + centroids=None, sampled_values=None, class_name=None, + feature_index=None, figure_labels=None, example_tag=None, + name="scatter"): + + figure_name = name + + if figure_labels: + title = figure_labels.get("title") + x_label = figure_labels.get("x label") + y_label = figure_labels.get("y label") + else: + title = "none" + x_label = "$x$" + y_label = "$y$" + + if not title: + title = "none" + + figure_name += "-" + normalise_string(title) + + if colour_coding: + colour_coding = normalise_string(colour_coding) + figure_name += "-" + colour_coding + if "predicted" in colour_coding: + if colouring_data_set.prediction_specifications: + figure_name += "-" + ( + colouring_data_set.prediction_specifications.name) + else: + figure_name += "unknown_prediction_method" + if colouring_data_set is None: + raise ValueError("Colouring data set not given.") + + if sampled_values is not None: + figure_name += "-samples" + + values = values.copy()[:, :2] + if scipy.sparse.issparse(values): + values = values.A + + # Randomise examples in values to remove any prior order + n_examples, __ = values.shape + random_state = numpy.random.RandomState(117) + shuffled_indices = random_state.permutation(n_examples) + values = values[shuffled_indices] + + # Adjust marker size based on number of examples + style._adjust_marker_size_for_scatter_plots(n_examples) + + figure = pyplot.figure() + axis = figure.add_subplot(1, 1, 1) + seaborn.despine() + + axis.set_xlabel(x_label) + axis.set_ylabel(y_label) + + colour_map = seaborn.dark_palette(style.STANDARD_PALETTE[0], as_cmap=True) + + alpha = 1 + if sampled_values is not None: + alpha = 0.5 + + if colour_coding and ( + "labels" in colour_coding + or "ids" in colour_coding + or "class" in colour_coding + or colour_coding == "batches"): + + if colour_coding == "predicted_cluster_ids": + labels = colouring_data_set.predicted_cluster_ids + class_names = numpy.unique(labels).tolist() + number_of_classes = len(class_names) + class_palette = None + label_sorter = None + elif colour_coding == "predicted_labels": + labels = colouring_data_set.predicted_labels + class_names = colouring_data_set.predicted_class_names + number_of_classes = colouring_data_set.number_of_predicted_classes + class_palette = colouring_data_set.predicted_class_palette + label_sorter = colouring_data_set.predicted_label_sorter + elif colour_coding == "predicted_superset_labels": + labels = colouring_data_set.predicted_superset_labels + class_names = colouring_data_set.predicted_superset_class_names + number_of_classes = ( + colouring_data_set.number_of_predicted_superset_classes) + class_palette = colouring_data_set.predicted_superset_class_palette + label_sorter = colouring_data_set.predicted_superset_label_sorter + elif "superset" in colour_coding: + labels = colouring_data_set.superset_labels + class_names = colouring_data_set.superset_class_names + number_of_classes = colouring_data_set.number_of_superset_classes + class_palette = colouring_data_set.superset_class_palette + label_sorter = colouring_data_set.superset_label_sorter + elif colour_coding == "batches": + labels = colouring_data_set.batch_indices.flatten() + class_names = colouring_data_set.batch_names + number_of_classes = colouring_data_set.number_of_batches + class_palette = None + label_sorter = None + else: + labels = colouring_data_set.labels + class_names = colouring_data_set.class_names + number_of_classes = colouring_data_set.number_of_classes + class_palette = colouring_data_set.class_palette + label_sorter = colouring_data_set.label_sorter + + if not class_palette: + index_palette = style.lighter_palette(number_of_classes) + class_palette = { + class_name: index_palette[i] for i, class_name in + enumerate(sorted(class_names, key=label_sorter)) + } + + # Examples are shuffled, so should their labels be + labels = labels[shuffled_indices] + + if ("labels" in colour_coding or "ids" in colour_coding + or colour_coding == "batches"): + colours = [] + classes = set() + + for i, label in enumerate(labels): + colour = class_palette[label] + colours.append(colour) + + # Plot one example for each class to add labels + if label not in classes: + classes.add(label) + axis.scatter( + values[i, 0], + values[i, 1], + color=colour, + label=label, + alpha=alpha + ) + + axis.scatter(values[:, 0], values[:, 1], c=colours, alpha=alpha) + + class_handles, class_labels = axis.get_legend_handles_labels() + + if class_labels: + class_labels, class_handles = zip(*sorted( + zip(class_labels, class_handles), + key=( + lambda t: label_sorter(t[0])) if label_sorter else None + )) + class_label_maximum_width = max(map(len, class_labels)) + if class_label_maximum_width <= 5 and number_of_classes <= 20: + axis.legend( + class_handles, class_labels, + loc="best" + ) + else: + if number_of_classes <= 20: + class_label_columns = 2 + else: + class_label_columns = 3 + axis.legend( + class_handles, + class_labels, + bbox_to_anchor=(-0.1, 1.05, 1.1, 0.95), + loc="lower left", + ncol=class_label_columns, + mode="expand", + borderaxespad=0., + ) + + elif "class" in colour_coding: + colours = [] + figure_name += "-" + normalise_string(str(class_name)) + ordered_indices_set = { + str(class_name): [], + "Remaining": [] + } + + for i, label in enumerate(labels): + if label == class_name: + colour = class_palette[label] + ordered_indices_set[str(class_name)].append(i) + else: + colour = style.NEUTRAL_COLOUR + ordered_indices_set["Remaining"].append(i) + colours.append(colour) + + colours = numpy.array(colours) + + z_order_index = 1 + for label, ordered_indices in sorted(ordered_indices_set.items()): + if label == "Remaining": + z_order = 0 + else: + z_order = z_order_index + z_order_index += 1 + ordered_values = values[ordered_indices] + ordered_colours = colours[ordered_indices] + axis.scatter( + ordered_values[:, 0], + ordered_values[:, 1], + c=ordered_colours, + label=label, + alpha=alpha, + zorder=z_order + ) + + handles, labels = axis.get_legend_handles_labels() + labels, handles = zip(*sorted( + zip(labels, handles), + key=lambda t: label_sorter(t[0]) if label_sorter else None + )) + axis.legend( + handles, + labels, + bbox_to_anchor=(-0.1, 1.05, 1.1, 0.95), + loc="lower left", + ncol=2, + mode="expand", + borderaxespad=0. + ) + + elif colour_coding == "count_sum": + + n = colouring_data_set.count_sum[shuffled_indices].flatten() + scatter_plot = axis.scatter( + values[:, 0], + values[:, 1], + c=n, + cmap=colour_map, + alpha=alpha + ) + colour_bar = figure.colorbar(scatter_plot) + colour_bar.outline.set_linewidth(0) + colour_bar.set_label("Total number of {}s per {}".format( + colouring_data_set.terms["item"], + colouring_data_set.terms["example"] + )) + + elif colour_coding == "feature": + if feature_index is None: + raise ValueError("Feature number not given.") + if feature_index > colouring_data_set.number_of_features: + raise ValueError("Feature number higher than number of features.") + + feature_name = colouring_data_set.feature_names[feature_index] + figure_name += "-{}".format(normalise_string(feature_name)) + + f = colouring_data_set.values[shuffled_indices, feature_index] + if scipy.sparse.issparse(f): + f = f.A + f = f.squeeze() + + scatter_plot = axis.scatter( + values[:, 0], + values[:, 1], + c=f, + cmap=colour_map, + alpha=alpha + ) + colour_bar = figure.colorbar(scatter_plot) + colour_bar.outline.set_linewidth(0) + colour_bar.set_label(feature_name) + + elif colour_coding is None: + axis.scatter( + values[:, 0], values[:, 1], c="k", + alpha=alpha, edgecolors="none") + + else: + raise ValueError( + "Colour coding `{}` not found.".format(colour_coding)) + + if centroids: + prior_centroids = centroids["prior"] + + if prior_centroids: + n_centroids = prior_centroids["probabilities"].shape[0] + else: + n_centroids = 0 + + if n_centroids > 1: + centroids_palette = style.darker_palette(n_centroids) + classes = numpy.arange(n_centroids) + + means = prior_centroids["means"] + covariance_matrices = prior_centroids["covariance_matrices"] + + for k in range(n_centroids): + axis.scatter( + means[k, 0], + means[k, 1], + s=60, + marker="x", + color="black", + linewidth=3 + ) + axis.scatter( + means[k, 0], + means[k, 1], + marker="x", + facecolor=centroids_palette[k], + edgecolors="black" + ) + ellipse_fill, ellipse_edge = _covariance_matrix_as_ellipse( + covariance_matrices[k], + means[k], + colour=centroids_palette[k] + ) + axis.add_patch(ellipse_edge) + axis.add_patch(ellipse_fill) + + if sampled_values is not None: + + sampled_values = sampled_values.copy()[:, :2] + if scipy.sparse.issparse(sampled_values): + sampled_values = sampled_values.A + + sample_colour_map = seaborn.blend_palette( + ("white", "purple"), as_cmap=True) + + x_limits = axis.get_xlim() + y_limits = axis.get_ylim() + + axis.hexbin( + sampled_values[:, 0], sampled_values[:, 1], + gridsize=75, + cmap=sample_colour_map, + linewidths=0., edgecolors="none", + zorder=-100 + ) + + axis.set_xlim(x_limits) + axis.set_ylim(y_limits) + + # Reset marker size + style.reset_plot_look() + + return figure, figure_name + + +def plot_variable_correlations(values, variable_names=None, + colouring_data_set=None, + name="variable_correlations"): + + figure_name = saving.build_figure_name(name) + n_examples, n_features = values.shape + + random_state = numpy.random.RandomState(117) + shuffled_indices = random_state.permutation(n_examples) + values = values[shuffled_indices] + + if colouring_data_set: + labels = colouring_data_set.labels + class_names = colouring_data_set.class_names + number_of_classes = colouring_data_set.number_of_classes + class_palette = colouring_data_set.class_palette + label_sorter = colouring_data_set.label_sorter + + if not class_palette: + index_palette = style.lighter_palette(number_of_classes) + class_palette = { + class_name: index_palette[i] for i, class_name in + enumerate(sorted(class_names, key=label_sorter)) + } + + labels = labels[shuffled_indices] + + colours = [] + + for label in labels: + colour = class_palette[label] + colours.append(colour) + + else: + colours = style.NEUTRAL_COLOUR + + figure, axes = pyplot.subplots( + nrows=n_features, + ncols=n_features, + figsize=[1.5 * n_features] * 2 + ) + + for i in range(n_features): + for j in range(n_features): + axes[i, j].scatter(values[:, i], values[:, j], c=colours, s=1) + + axes[i, j].set_xticks([]) + axes[i, j].set_yticks([]) + + if i == n_features - 1: + axes[i, j].set_xlabel(variable_names[j]) + + axes[i, 0].set_ylabel(variable_names[i]) + + return figure, figure_name + + +def plot_variable_label_correlations(variable_vector, variable_name, + colouring_data_set, + name="variable_label_correlations"): + + figure_name = saving.build_figure_name(name) + n_examples = variable_vector.shape[0] + + class_names_to_class_ids = numpy.vectorize( + lambda class_name: + colouring_data_set.class_name_to_class_id[class_name] + ) + class_ids_to_class_names = numpy.vectorize( + lambda class_name: + colouring_data_set.class_id_to_class_name[class_name] + ) + + labels = colouring_data_set.labels + class_names = colouring_data_set.class_names + number_of_classes = colouring_data_set.number_of_classes + class_palette = colouring_data_set.class_palette + label_sorter = colouring_data_set.label_sorter + + if not class_palette: + index_palette = style.lighter_palette(number_of_classes) + class_palette = { + class_name: index_palette[i] for i, class_name in + enumerate(sorted(class_names, key=label_sorter)) + } + + random_state = numpy.random.RandomState(117) + shuffled_indices = random_state.permutation(n_examples) + variable_vector = variable_vector[shuffled_indices] + + labels = labels[shuffled_indices] + label_ids = numpy.expand_dims(class_names_to_class_ids(labels), axis=-1) + colours = [class_palette[label] for label in labels] + + unique_class_ids = numpy.unique(label_ids) + unique_class_names = class_ids_to_class_names(unique_class_ids) + + figure = pyplot.figure() + axis = figure.add_subplot(1, 1, 1) + seaborn.despine() + + axis.scatter(variable_vector, label_ids, c=colours, s=1) + + axis.set_yticks(unique_class_ids) + axis.set_yticklabels(unique_class_names) + + axis.set_xlabel(variable_name) + axis.set_ylabel(capitalise_string(colouring_data_set.terms["class"])) + + return figure, figure_name