--- a
+++ b/src/move/visualization/dataset_distributions.py
@@ -0,0 +1,333 @@
+__all__ = ["plot_value_distributions"]
+
+import matplotlib
+import matplotlib.figure
+import matplotlib.pyplot as plt
+import networkx as nx
+import numpy as np
+
+from move.core.typing import FloatArray
+from move.visualization.style import (
+    DEFAULT_DIVERGING_PALETTE,
+    DEFAULT_PLOT_STYLE,
+    style_settings,
+)
+
+
+def plot_value_distributions(
+    feature_values: FloatArray,
+    style: str = "fast",
+    nbins: int = 100,
+    colormap: str = DEFAULT_DIVERGING_PALETTE,
+) -> matplotlib.figure.Figure:
+    """
+    Given a certain dataset, plot its distribution of values.
+
+
+    Args:
+        feature_values:
+            Values of the features, a 2D array (`num_samples` x `num_features`).
+        style:
+            Name of style to apply to the plot.
+        colormap:
+            Name of colormap to apply to the colorbar.
+
+    Returns:
+        Figure
+    """
+    vmin, vmax = np.nanmin(feature_values), np.nanmax(feature_values)
+    with style_settings(style):
+        fig = plt.figure(layout="constrained", figsize=(7, 7))
+        ax = fig.add_subplot(projection="3d")
+        x_val = np.linspace(vmin, vmax, nbins)
+        y_val = np.arange(np.shape(feature_values)[1])
+        x_val, y_val = np.meshgrid(x_val, y_val)
+
+        histogram = []
+        for i in range(np.shape(feature_values)[1]):
+            # Histogram of feature i over the shared value range
+            feat_hist, _ = np.histogram(
+                feature_values[:, i], bins=nbins, range=(vmin, vmax)
+            )
+            histogram.append(feat_hist)
+
+        ax.plot_surface(x_val, y_val, np.array(histogram), cmap=colormap)
+        ax.set_xlabel("Feature value")
+        ax.set_ylabel("Feature ID number")
+        ax.set_zlabel("Frequency")
+        # ax.legend()
+    return fig
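+
+# Usage sketch (illustrative only; the random `values` array below is a hypothetical
+# (num_samples x num_features) matrix, not part of this module):
+#
+#   values = np.random.default_rng(0).normal(size=(100, 20))
+#   fig = plot_value_distributions(values, nbins=50)
+#   fig.savefig("value_distributions.png")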
+
+
+def plot_reconstruction_diff(
+    diff_array: FloatArray,
+    vmin=None,
+    vmax=None,
+    style: str = DEFAULT_PLOT_STYLE,
+    colormap: str = DEFAULT_DIVERGING_PALETTE,
+) -> matplotlib.figure.Figure:
+    """
+    Plot the reconstruction differences as a heatmap.
+    """
+    with style_settings(style):
+        if vmin is None:
+            vmin = np.min(diff_array)
+        if vmax is None:
+            vmax = np.max(diff_array)
+        fig = plt.figure(layout="constrained", figsize=(7, 7))
+        plt.imshow(diff_array, cmap=colormap, vmin=vmin, vmax=vmax)
+        plt.xlabel("Feature")
+        plt.ylabel("Sample")
+        plt.colorbar()
+
+    return fig
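+
+# Usage sketch (illustrative only; `original` and `reconstruction` stand for
+# hypothetical (num_samples x num_features) arrays):
+#
+#   diff = original - reconstruction
+#   fig = plot_reconstruction_diff(diff, vmin=-1.0, vmax=1.0)
+#   fig.savefig("reconstruction_diff.png")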
+
+
+def plot_feature_association_graph(
+    association_df, output_path, layout="circular", style: str = DEFAULT_PLOT_STYLE
+) -> matplotlib.figure.Figure:
+    """
+    This function plots a graph where each node corresponds to a feature and the edges
+    represent the associations between features. Edge width represents the probability
+    of said association, not the association's effect size.
+
+    Input:
+        association_df: pandas dataframe containing the following columns:
+                            - feature_a: source node
+                            - feature_b: target node
+                            - p_value/bayes_score: edge weight
+        output_path: Path object where the picture will be stored.
+
+    Output:
+        Feature_association_graph.png: picture of the graph
+
+    """
+
+    if "p_value" in association_df.columns:
+        association_df["weight"] = 1 - association_df["p_value"]
+
+    elif "proba" in association_df.columns:
+        association_df["weight"] = association_df["proba"]
+
+    elif "ks_distance" in association_df.columns:
+        association_df["weight"] = association_df["ks_distance"]
+
+    with style_settings(style):
+        fig = plt.figure(figsize=(45, 45))
+        G = nx.from_pandas_edgelist(
+            association_df,
+            source="feature_a_name",
+            target="feature_b_name",
+            edge_attr="weight",
+        )
+
+        nodes = list(G.nodes)
+
+        datasets = association_df["feature_b_dataset"].unique()
+        color_map = {
+            dataset: (np.random.uniform(), np.random.uniform(), np.random.uniform())
+            for dataset in datasets
+        }
+        node_dataset_map = {
+            target_feature: dataset
+            for (target_feature, dataset) in zip(
+                association_df["feature_b_name"], association_df["feature_b_dataset"]
+            )
+        }
+
+        if layout == "spring":
+            pos = nx.spring_layout(G)
+            with_labels = True
+        elif layout == "circular":
+            pos = nx.circular_layout(G)
+            _ = [
+                plt.text(
+                    pos[node][0],
+                    pos[node][1],
+                    nodes[i],
+                    rotation=(i / float(len(nodes))) * 360,
+                    fontsize=10,
+                    horizontalalignment="center",
+                    verticalalignment="center",
+                )
+                for i, node in enumerate(nodes)
+            ]
+            with_labels = False
+
+        else:
+            raise ValueError(
+                "Graph layout (layout argument) must be either 'circular' or 'spring'."
+            )
+
+        nx.draw(
+            G,
+            pos=pos,
+            with_labels=with_labels,
+            node_size=2000,
+            node_color=[
+                (
+                    color_map[node_dataset_map[feature]]
+                    if feature in node_dataset_map.keys()
+                    else "white"
+                )
+                for feature in G.nodes
+            ],
+            edge_color=list(nx.get_edge_attributes(G, "weight").values()),
+            font_color="black",
+            font_size=10,
+            edge_cmap=matplotlib.colormaps["Purples"],
+            connectionstyle="arc3, rad=1",
+        )
+
+        plt.tight_layout()
+        fig.savefig(
+            output_path / f"Feature_association_graph_{layout}.png", format="png"
+        )
+    return fig
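+
+# Usage sketch (illustrative only; the toy dataframe and "results" directory below
+# are hypothetical, and the output directory must already exist):
+#
+#   import pandas as pd
+#   from pathlib import Path
+#
+#   df = pd.DataFrame(
+#       {
+#           "feature_a_name": ["geneA", "geneA", "geneB"],
+#           "feature_b_name": ["drugX", "drugY", "drugX"],
+#           "feature_b_dataset": ["drugs", "drugs", "drugs"],
+#           "proba": [0.9, 0.7, 0.4],
+#       }
+#   )
+#   fig = plot_feature_association_graph(df, Path("results"), layout="circular")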
+
+
+def plot_feature_mean_median(
+    array: FloatArray, axis=0, style: str = DEFAULT_PLOT_STYLE
+) -> matplotlib.figure.Figure:
+    """
+    Plot feature values together with the mean, median, min and max values
+    at each array position.
+    """
+    with style_settings(style):
+        fig = plt.figure(figsize=(15, 3))
+        y = np.mean(array, axis=axis)
+        y_2 = np.median(array, axis=axis)
+        y_3 = np.max(array, axis=axis)
+        y_4 = np.min(array, axis=axis)
+        plt.plot(np.arange(len(y)), y, "bo", label="mean")
+        plt.plot(np.arange(len(y_2)), y_2, "ro", label="median")
+        plt.plot(np.arange(len(y_3)), y_3, "go", label="max")
+        plt.plot(np.arange(len(y_4)), y_4, "yo", label="min")
+        plt.legend()
+        plt.xlabel("feature")
+        plt.ylabel("mean/median/min/max")
+
+    return fig
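+
+# Usage sketch (illustrative only; `values` is a hypothetical
+# (num_samples x num_features) array):
+#
+#   fig = plot_feature_mean_median(values, axis=0)
+#   fig.savefig("feature_mean_median.png")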
+
+
+def plot_reconstruction_movement(
+    baseline_recon: FloatArray,
+    perturb_recon: FloatArray,
+    k: int,
+    style: str = DEFAULT_PLOT_STYLE,
+) -> matplotlib.figure.Figure:
+    """
+    Plot, for each sample, the change in value from the unperturbed reconstruction to
+    the perturbed reconstruction. Blue lines are left/negative shifts,
+    red lines are right/positive shifts.
+
+    Args:
+        baseline_recon: baseline reconstruction array with s samples
+                        and k features (s,k).
+        perturb_recon:  perturbed
+        k: feature index. The shift (movement) of this feature's reconstruction
+                          will be plotted for all samples s.
+    """
+    with style_settings(style):
+        # Feature changes
+        fig = plt.figure(figsize=(25, 25))
+        for s in range(np.shape(baseline_recon)[0]):
+            # Arrow from the baseline value to the perturbed value for sample s
+            plt.arrow(
+                baseline_recon[s, k],
+                s / 100,
+                perturb_recon[s, k] - baseline_recon[s, k],
+                0,
+                length_includes_head=True,
+                color="r" if baseline_recon[s, k] < perturb_recon[s, k] else "b",
+            )
+        plt.ylabel("Sample index / 100", size=40)
+        plt.xlabel("Feature value", size=40)
+    return fig
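+
+# Usage sketch (illustrative only; `baseline_recon` and `perturb_recon` stand for
+# hypothetical reconstruction arrays of shape (num_samples, num_features)):
+#
+#   fig = plot_reconstruction_movement(baseline_recon, perturb_recon, k=0)
+#   fig.savefig("reconstruction_movement.png")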
+
+
+def plot_cumulative_distributions(
+    edges: FloatArray,
+    hist_base: FloatArray,
+    hist_pert: FloatArray,
+    title: str,
+    style: str = DEFAULT_PLOT_STYLE,
+) -> matplotlib.figure.Figure:
+    """
+    Plot the cumulative distribution of the histograms for the baseline
+    and perturbed reconstructions. This is useful to visually assess the
+    magnitude of the shift corresponding to the KS score.
+
+    """
+
+    with style_settings(style):
+        # Cumulative distribution:
+        fig = plt.figure(figsize=(7, 7))
+        plt.plot(
+            (edges[:-1] + edges[1:]) / 2,
+            np.cumsum(hist_base),
+            color="blue",
+            label="baseline",
+            alpha=0.5,
+        )
+        plt.plot(
+            (edges[:-1] + edges[1:]) / 2,
+            np.cumsum(hist_pert),
+            color="red",
+            label="Perturbed",
+            alpha=0.5,
+        )
+
+        plt.title(title)
+        plt.xlabel("Feature value")
+        plt.ylabel("Cumulative distribution")
+        plt.legend()
+
+    return fig
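+
+# Usage sketch (illustrative only; `baseline_recon` and `perturb_recon` are
+# hypothetical reconstruction arrays, and feature 0 is compared):
+#
+#   hist_base, edges = np.histogram(baseline_recon[:, 0], bins=50)
+#   hist_pert, _ = np.histogram(perturb_recon[:, 0], bins=edges)
+#   fig = plot_cumulative_distributions(edges, hist_base, hist_pert, title="feature_0")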
+
+
+def plot_correlations(
+    x: FloatArray,
+    y: FloatArray,
+    x_pol: FloatArray,
+    y_pol: FloatArray,
+    a2: float,
+    a1: float,
+    a: float,
+    k: int,
+    style: str = DEFAULT_PLOT_STYLE,
+) -> matplotlib.figure.Figure:
+    """
+    Plot y vs x and the corresponding polynomial fit.
+    """
+    with style_settings(style):
+        # Plot correlations
+        fig = plt.figure(figsize=(7, 7))
+        plt.plot(x, y, marker=".", lw=0, markersize=1, color="red")
+        plt.plot(
+            x_pol,
+            y_pol,
+            color="blue",
+            label="{0:.2f}x^2 {1:.2f}x {2:.2f}".format(a2, a1, a),
+            lw=1,
+        )
+        plt.plot(x_pol, x_pol, lw=1, color="k")
+        plt.xlabel(f"Feature {k} baseline values ")
+        plt.ylabel(f"Feature {k} baseline  value reconstruction")
+        plt.legend()
+
+    return fig
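+
+# Usage sketch (illustrative only; `baseline_values` and `recon_values` stand for
+# hypothetical (num_samples x num_features) arrays, and feature index k=0 is plotted):
+#
+#   x, y = baseline_values[:, 0], recon_values[:, 0]
+#   x_pol, y_pol, (a2, a1, a) = get_2nd_order_polynomial(x, y)
+#   fig = plot_correlations(x, y, x_pol, y_pol, a2, a1, a, k=0)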
+
+
+def get_2nd_order_polynomial(x_array, y_array, n_points=100):
+    """
+    Given a set of x an y values, find the 2nd oder polynomial fitting best the data.
+    Returns:
+        x_pol: x coordinates for the polynomial function evaluation.
+        y_pol: y coordinates for the polynomial function evaluation.
+    """
+    a2, a1, a = np.polyfit(x_array, y_array, deg=2)
+
+    x_pol = np.linspace(np.min(x_array), np.max(x_array), n_points)
+    y_pol = a2 * x_pol**2 + a1 * x_pol + a
+
+    return x_pol, y_pol, (a2, a1, a)
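+
+# Usage sketch (illustrative only; synthetic data with known coefficients):
+#
+#   rng = np.random.default_rng(0)
+#   x = rng.uniform(-1.0, 1.0, 200)
+#   y = 0.5 * x**2 - 0.2 * x + 0.1 + rng.normal(scale=0.05, size=200)
+#   x_pol, y_pol, (a2, a1, a) = get_2nd_order_polynomial(x, y)
+#   # a2, a1, a should be close to 0.5, -0.2, 0.1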