a b/src/move/visualization/dataset_distributions.py
1
__all__ = ["plot_value_distributions"]
2
3
import matplotlib
4
import matplotlib.figure
5
import matplotlib.pyplot as plt
6
import networkx as nx
7
import numpy as np
8
9
from move.core.typing import FloatArray
10
from move.visualization.style import (
11
    DEFAULT_DIVERGING_PALETTE,
12
    DEFAULT_PLOT_STYLE,
13
    style_settings,
14
)
15
16
17
def plot_value_distributions(
    feature_values: FloatArray,
    style: str = "fast",
    nbins: int = 100,
    colormap: str = DEFAULT_DIVERGING_PALETTE,
) -> matplotlib.figure.Figure:
    """
    Given a certain dataset, plot its distribution of values.

    One histogram is computed per feature (column of `feature_values`), all
    over the same value range, and the stacked histograms are drawn as a 3D
    surface (feature value x feature index x frequency).

    Args:
        feature_values:
            Values of the features, a 2D array (`num_samples` x `num_features`).
        style:
            Name of style to apply to the plot.
        nbins:
            Number of histogram bins, shared by all features.
        colormap:
            Name of colormap to apply to the surface.

    Returns:
        Figure
    """
    # NaN-ignoring extrema define one common bin range for every feature.
    vmin, vmax = np.nanmin(feature_values), np.nanmax(feature_values)
    num_features = np.shape(feature_values)[1]
    with style_settings(style):
        fig = plt.figure(layout="constrained", figsize=(7, 7))
        ax = fig.add_subplot(projection="3d")

        histograms = [
            np.histogram(feature_values[:, i], bins=nbins, range=(vmin, vmax))
            for i in range(num_features)
        ]
        counts = np.array([feat_hist for feat_hist, _ in histograms])

        # All histograms share the same edges; place each count at the centre
        # of its bin so the surface lines up with the values it summarises.
        bin_edges = histograms[0][1]
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        x_val, y_val = np.meshgrid(bin_centers, np.arange(num_features))

        ax.plot_surface(x_val, y_val, counts, cmap=colormap)
        ax.set_xlabel("Feature value")
        ax.set_ylabel("Feature ID number")
        ax.set_zlabel("Frequency")
    return fig
60
61
62
def plot_reconstruction_diff(
    diff_array: FloatArray,
    vmin=None,
    vmax=None,
    style: str = DEFAULT_PLOT_STYLE,
    colormap: str = DEFAULT_DIVERGING_PALETTE,
) -> matplotlib.figure.Figure:
    """
    Plot the reconstruction differences as a heatmap.

    Args:
        diff_array:
            2D array of differences; rows are labelled "Sample" and columns
            "Feature" on the plot.
        vmin:
            Lower limit of the color scale. Defaults to the minimum of
            `diff_array`.
        vmax:
            Upper limit of the color scale. Defaults to the maximum of
            `diff_array`.
        style:
            Name of style to apply to the plot.
        colormap:
            Name of colormap to apply to the heatmap.

    Returns:
        Figure
    """
    with style_settings(style):
        # The two limits are independent: each one falls back to the data
        # extremum on its own when it is not supplied.
        if vmin is None:
            vmin = np.min(diff_array)
        if vmax is None:
            vmax = np.max(diff_array)
        fig = plt.figure(layout="constrained", figsize=(7, 7))
        plt.imshow(diff_array, cmap=colormap, vmin=vmin, vmax=vmax)
        plt.xlabel("Feature")
        plt.ylabel("Sample")
        plt.colorbar()

    return fig
84
85
86
def plot_feature_association_graph(
    association_df, output_path, layout="circular", style: str = DEFAULT_PLOT_STYLE
) -> matplotlib.figure.Figure:
    """
    Plot a graph where each node corresponds to a feature and the edges
    represent the associations between features. Edge color (on the "Purples"
    colormap) encodes the association weight, not the association's effect
    size.

    Args:
        association_df: pandas dataframe containing the following columns:
                            - feature_a_name: source node
                            - feature_b_name: target node
                            - feature_b_dataset: dataset of the target node,
                              used to color the nodes
                            - one of p_value / proba / ks_distance, from
                              which the edge weight is derived
                              (`1 - p_value`, `proba` or `ks_distance`).
                              If none of them is present, an existing
                              `weight` column is used as is.
                        The dataframe passed in is not modified.
        output_path: Path object of the directory where the picture will be
                     stored.
        layout: graph layout, either "circular" or "spring".
        style: Name of style to apply to the plot.

    Returns:
        Figure. As a side effect, the picture is saved as
        `Feature_association_graph_{layout}.png` inside `output_path`.

    Raises:
        ValueError: if `layout` is neither "circular" nor "spring".
    """
    # Derive the edge weight on a copy (`assign` returns a new dataframe) so
    # the caller's dataframe is left untouched.
    if "p_value" in association_df.columns:
        association_df = association_df.assign(weight=1 - association_df["p_value"])
    elif "proba" in association_df.columns:
        association_df = association_df.assign(weight=association_df["proba"])
    elif "ks_distance" in association_df.columns:
        association_df = association_df.assign(weight=association_df["ks_distance"])

    with style_settings(style):
        fig = plt.figure(figsize=(45, 45))
        G = nx.from_pandas_edgelist(
            association_df,
            source="feature_a_name",
            target="feature_b_name",
            edge_attr="weight",
        )

        nodes = list(G.nodes)

        # One random RGB color per dataset; colors differ between calls.
        datasets = association_df["feature_b_dataset"].unique()
        color_map = {
            dataset: (np.random.uniform(), np.random.uniform(), np.random.uniform())
            for dataset in datasets
        }
        # Only target features have a known dataset; any other node is
        # drawn in white below.
        node_dataset_map = dict(
            zip(association_df["feature_b_name"], association_df["feature_b_dataset"])
        )

        if layout == "spring":
            pos = nx.spring_layout(G)
            with_labels = True
        elif layout == "circular":
            pos = nx.circular_layout(G)
            # Labels are drawn by hand so each one can be rotated according
            # to the node's position around the circle.
            for i, node in enumerate(nodes):
                plt.text(
                    pos[node][0],
                    pos[node][1],
                    node,
                    rotation=(i / float(len(nodes))) * 360,
                    fontsize=10,
                    horizontalalignment="center",
                    verticalalignment="center",
                )
            with_labels = False
        else:
            raise ValueError(
                "Graph layout (layout argument) must be either 'circular' or 'spring'."
            )

        nx.draw(
            G,
            pos=pos,
            with_labels=with_labels,
            node_size=2000,
            node_color=[
                (
                    color_map[node_dataset_map[feature]]
                    if feature in node_dataset_map
                    else "white"
                )
                for feature in G.nodes
            ],
            edge_color=list(nx.get_edge_attributes(G, "weight").values()),
            font_color="black",
            font_size=10,
            edge_cmap=matplotlib.colormaps["Purples"],
            connectionstyle="arc3, rad=1",
        )

        plt.tight_layout()
        fig.savefig(
            output_path / f"Feature_association_graph_{layout}.png", format="png"
        )
    return fig
187
188
189
def plot_feature_mean_median(
    array: FloatArray, axis=0, style: str = DEFAULT_PLOT_STYLE
) -> matplotlib.figure.Figure:
    """
    Plot the mean, median, max and min of `array` along `axis`, one marker
    series per statistic, against the position index.

    Args:
        array: values to summarise.
        axis: axis along which the statistics are computed.
        style: Name of style to apply to the plot.

    Returns:
        Figure
    """
    # (reduction, matplotlib format string, legend label), in drawing order.
    summaries = (
        (np.mean, "bo", "mean"),
        (np.median, "ro", "median"),
        (np.max, "go", "max"),
        (np.min, "yo", "min"),
    )
    with style_settings(style):
        fig = plt.figure(figsize=(15, 3))
        for reduce_fn, fmt, name in summaries:
            values = reduce_fn(array, axis=axis)
            plt.plot(np.arange(len(values)), values, fmt, label=name)
        plt.legend()
        plt.xlabel("feature")
        plt.ylabel("mean/median/min/max")

    return fig
211
212
213
def plot_reconstruction_movement(
    baseline_recon: FloatArray,
    perturb_recon: FloatArray,
    k: int,
    style: str = DEFAULT_PLOT_STYLE,
) -> matplotlib.figure.Figure:
    """
    Plot, for each sample, the change in value from the unperturbed reconstruction to
    the perturbed reconstruction. Each sample is drawn as a horizontal arrow going
    from its baseline value to its perturbed value. Red arrows are right/positive
    shifts, blue arrows are left/negative (or null) shifts.

    Args:
        baseline_recon: baseline reconstruction array with s samples
                        and k features (s,k).
        perturb_recon:  perturbed reconstruction array, indexed the same way
                        as `baseline_recon`.
        k: feature index. The shift (movement) of this feature's reconstruction
                          will be plotted for all samples s.

    Returns:
        Figure
    """
    with style_settings(style):
        # Feature changes
        fig = plt.figure(figsize=(25, 25))
        for s in range(np.shape(baseline_recon)[0]):
            start = baseline_recon[s, k]
            end = perturb_recon[s, k]
            plt.arrow(
                start,
                s / 100,  # vertical position; axis is labelled in hundreds
                # plt.arrow takes a displacement (dx), not an end point, so
                # the arrow tip lands exactly on the perturbed value.
                end - start,
                0,
                length_includes_head=True,
                color="r" if start < end else "b",
            )
        plt.ylabel("Sample (e2)", size=40)
        plt.xlabel("Feature_value", size=40)
    return fig
246
247
248
def plot_cumulative_distributions(
    edges: FloatArray,
    hist_base: FloatArray,
    hist_pert: FloatArray,
    title: str,
    style: str = DEFAULT_PLOT_STYLE,
) -> matplotlib.figure.Figure:
    """
    Plot the cumulative sums of the baseline and perturbed reconstruction
    histograms against the bin centres. This is useful to visually assess
    the magnitude of the shift corresponding to the KS score.

    Args:
        edges: histogram bin edges; the curves are drawn at the midpoints of
               consecutive edges.
        hist_base: histogram of the baseline reconstruction.
        hist_pert: histogram of the perturbed reconstruction.
        title: text used (followed by ".png") as the plot title.
        style: Name of style to apply to the plot.

    Returns:
        Figure
    """
    # (histogram, line color, legend label), in drawing order.
    curves = (
        (hist_base, "blue", "baseline"),
        (hist_pert, "red", "Perturbed"),
    )
    with style_settings(style):
        fig = plt.figure(figsize=(7, 7))
        for counts, line_color, legend_name in curves:
            midpoints = (edges[:-1] + edges[1:]) / 2
            plt.plot(
                midpoints,
                np.cumsum(counts),
                color=line_color,
                label=legend_name,
                alpha=0.5,
            )

        plt.title(f"{title}.png")
        plt.xlabel("Feature value")
        plt.ylabel("Cumulative distribution")
        plt.legend()

    return fig
286
287
288
def plot_correlations(
    x: FloatArray,
    y: FloatArray,
    x_pol: FloatArray,
    y_pol: FloatArray,
    a2: float,
    a1: float,
    a: float,
    k: int,
    style: str = DEFAULT_PLOT_STYLE,
) -> matplotlib.figure.Figure:
    """
    Scatter y against x and overlay a polynomial curve plus the identity line.

    Args:
        x: x coordinates of the data points.
        y: y coordinates of the data points.
        x_pol: x coordinates of the polynomial curve (also used to draw the
               y = x reference line).
        y_pol: y coordinates of the polynomial curve.
        a2: quadratic coefficient shown in the legend.
        a1: linear coefficient shown in the legend.
        a: constant coefficient shown in the legend.
        k: feature index used in the axis labels.
        style: Name of style to apply to the plot.

    Returns:
        Figure
    """
    with style_settings(style):
        fig = plt.figure(figsize=(7, 7))

        # Data points: markers only, no connecting line.
        plt.plot(x, y, marker=".", lw=0, markersize=1, color="red")

        # Polynomial curve, labelled with its coefficients.
        fit_label = "{0:.2f}x^2 {1:.2f}x {2:.2f}".format(a2, a1, a)
        plt.plot(x_pol, y_pol, color="blue", label=fit_label, lw=1)

        # Identity line (y = x) as a visual reference.
        plt.plot(x_pol, x_pol, lw=1, color="k")

        plt.xlabel(f"Feature {k} baseline values ")
        plt.ylabel(f"Feature {k} baseline  value reconstruction")
        plt.legend()

    return fig
319
320
321
def get_2nd_order_polynomial(x_array, y_array, n_points=100):
    """
    Given a set of x and y values, find the 2nd order polynomial that best
    fits the data (least squares, via `np.polyfit`) and evaluate it on an
    evenly spaced grid spanning the range of `x_array`.

    Args:
        x_array: x values of the data.
        y_array: y values of the data.
        n_points: number of grid points at which the polynomial is evaluated.

    Returns:
        x_pol: x coordinates for the polynomial function evaluation.
        y_pol: y coordinates for the polynomial function evaluation.
        (a2, a1, a): fitted coefficients, highest degree first.
    """
    quadratic, linear, constant = np.polyfit(x_array, y_array, deg=2)

    lower, upper = np.min(x_array), np.max(x_array)
    grid = np.linspace(lower, upper, n_points)
    fitted = quadratic * grid**2 + linear * grid + constant

    return grid, fitted, (quadratic, linear, constant)