Data: Tabular Time Series Specialty: Endocrinology Laboratory: Blood Tests EHR: Demographics Diagnoses Medications Omics: Genomics Multi-omics Transcriptomics Wearable: Activity Clinical Purpose: Treatment Response Assessment Task: Biomarker Discovery

Switch to side-by-side view

--- a
+++ b/src/move/visualization/vae_visualization.py
@@ -0,0 +1,268 @@
+__all__ = ["plot_vae"]
+
+from pathlib import Path
+from typing import Optional
+
+import matplotlib
+import matplotlib.cm as cm
+import matplotlib.figure
+import matplotlib.pyplot as plt
+import networkx as nx
+import numpy as np
+import torch
+
+
+def plot_vae(
+    path: Path,
+    savepath: Path,
+    filename: str,
+    title: str,
+    num_input: int,
+    num_hidden: int,
+    num_latent: int,
+    plot_edges=True,
+    input_sample: Optional[torch.Tensor] = None,
+    output_sample: Optional[torch.Tensor] = None,
+    mu: Optional[torch.Tensor] = None,
+    logvar: Optional[torch.Tensor] = None,
+) -> matplotlib.figure.Figure:
+    """
+    This function is aimed to visualize MOVE's architecture.
+
+    Args:
+        path: path where the trained model with defined weights is to be found
+        filename: name of the model
+        title: title of the figure
+        num_input: number of input nodes
+        num_hidden: number of output nodes
+        num_latent: number of latent nodes
+        plot_edges: plot edges, i.e. connections with assigned weights between nodes
+        input_sample: array with input values to fill with a mapped color value
+        output_sample : "      " output " "
+        mu: "                  " mean (latent) "     "
+        logvar: "              " log variance  "     "
+
+    Returns:
+        figure
+
+    Notes:
+        k: input node index
+        j: hidden node index
+        i: latent node index
+    """
+    model_weights = torch.load(path / filename)
+    G = nx.Graph()
+
+    # Position of the layers:
+    layer_distance = 10
+    node_distance = 550
+    latent_node_distance = 550
+    latent_sep = 5 * latent_node_distance
+
+    # Adding nodes to the graph ##############################
+    # Bias nodes
+    G.add_node(
+        "input_bias",
+        pos=(-6 * layer_distance, -3 * node_distance - num_input * node_distance / 2),
+        color=0.0,
+    )
+    G.add_node(
+        "mu_bias",
+        pos=(
+            -3 * layer_distance,
+            (num_hidden + 3) * node_distance
+            - num_hidden * node_distance / 2
+            + latent_sep / 2,
+        ),
+        color=0.0,
+    )
+    G.add_node(
+        "var_bias",
+        pos=(
+            -3 * layer_distance,
+            -3 * node_distance - num_hidden * node_distance / 2 - latent_sep / 2,
+        ),
+        color=0.0,
+    )
+    G.add_node(
+        "sam_bias",
+        pos=(
+            0.5 * layer_distance,
+            -3 * latent_node_distance - num_latent * latent_node_distance / 2,
+        ),
+        color=0.0,
+    )
+    G.add_node(
+        "out_bias",
+        pos=(3 * layer_distance, -3 * node_distance - num_hidden * node_distance / 2),
+        color=0.0,
+    )
+
+    # Actual nodes
+    for k in range(num_input):
+        G.add_node(
+            f"input_{k}",
+            pos=(
+                -6 * layer_distance,
+                k * node_distance - num_input * node_distance / 2,
+            ),
+            color=[input_sample[k] if input_sample is not None else 0.0][0],
+        )
+        G.add_node(
+            f"output_{k}",
+            pos=(6 * layer_distance, k * node_distance - num_input * node_distance / 2),
+            color=[output_sample[k] if output_sample is not None else 0.0][0],
+        )
+    for j in range(num_hidden):
+        G.add_node(
+            f"encoder_hidden_{j}",
+            pos=(
+                -3 * layer_distance,
+                j * node_distance - num_hidden * node_distance / 2,
+            ),
+            color=0.0,
+        )
+        G.add_node(
+            f"decoder_hidden_{j}",
+            pos=(
+                3 * layer_distance,
+                j * node_distance - num_hidden * node_distance / 2,
+            ),
+            color=0.0,
+        )
+    for i in range(num_latent):
+        G.add_node(
+            f"mu_{i}",
+            pos=(0 * layer_distance, i * latent_node_distance + latent_sep / 2),
+            color=[mu[i] if mu is not None else 0.0][0],
+        )
+        G.add_node(
+            f"var_{i}",
+            pos=(0 * layer_distance, -i * latent_node_distance - latent_sep / 2),
+            color=[np.exp(logvar[i] / 2) if logvar is not None else 0.0][0],
+        )
+        G.add_node(
+            f"sam_{i}",
+            pos=(
+                0.5 * layer_distance,
+                i * latent_node_distance - num_latent * latent_node_distance / 2,
+            ),
+            color=0.0,
+        )
+
+    # Adding weights to the graph #########################
+
+    if plot_edges:
+        for layer, values in model_weights.items():
+            if layer == "encoderlayers.0.weight":
+                for k in range(values.shape[1]):  # input
+                    for j in range(values.shape[0]):  # encoder_hidden
+                        G.add_edge(
+                            f"input_{k}",
+                            f"encoder_hidden_{j}",
+                            weight=values.numpy()[j, k],
+                        )
+
+            elif layer == "encoderlayers.0.bias":
+                for j in range(values.shape[0]):  # encoder_hidden
+                    G.add_edge(
+                        "input_bias", f"encoder_hidden_{j}", weight=values.numpy()[j]
+                    )
+
+            elif layer == "mu.weight":
+                for j in range(values.shape[1]):  # encoder hidden
+                    for i in range(values.shape[0]):  # mu
+                        G.add_edge(
+                            f"encoder_hidden_{j}",
+                            f"mu_{i}",
+                            weight=values.numpy()[i, j],
+                        )
+
+            elif layer == "mu.bias":
+                for i in range(values.shape[0]):  # encoder_hidden
+                    G.add_edge("mu_bias", f"mu_{i}", weight=values.numpy()[i])
+
+            elif layer == "var.weight":
+                for j in range(values.shape[1]):  # encoder hidden
+                    for i in range(values.shape[0]):  # var
+                        G.add_edge(
+                            f"encoder_hidden_{j}",
+                            f"var_{i}",
+                            weight=values.numpy()[i, j],
+                        )
+
+            elif layer == "var.bias":
+                for i in range(values.shape[0]):  # encoder_hidden
+                    G.add_edge("var_bias", f"var_{i}", weight=values.numpy()[i])
+
+            # Sampled layer from mu and var:
+            elif layer == "decoderlayers.0.weight":
+                for i in range(values.shape[1]):  # sampled latent
+                    for j in range(values.shape[0]):  # decoder_hidden
+                        G.add_edge(
+                            f"sam_{i}",
+                            f"decoder_hidden_{j}",
+                            weight=values.numpy()[j, i],
+                        )
+
+            # Sampled layer from mu and var:
+            elif layer == "decoderlayers.0.bias":
+                for j in range(values.shape[0]):  # decoder_hidden
+                    G.add_edge(
+                        "sam_bias", f"decoder_hidden_{j}", weight=values.numpy()[j]
+                    )
+
+            elif layer == "out.weight":
+                for j in range(values.shape[1]):  # decoder_hidden
+                    for k in range(values.shape[0]):  # output
+                        G.add_edge(
+                            f"output_{k}",
+                            f"decoder_hidden_{j}",
+                            weight=values.numpy()[k, j],
+                        )
+
+            elif layer == "out.bias":
+                for k in range(values.shape[0]):  # output
+                    G.add_edge("out_bias", f"output_{k}", weight=values.numpy()[k])
+
+    fig = plt.figure(figsize=(60, 60))
+    pos = nx.get_node_attributes(G, "pos")
+    color = list(nx.get_node_attributes(G, "color").values())
+    edge_color = list(nx.get_edge_attributes(G, "weight").values())
+    edge_width = list(nx.get_edge_attributes(G, "weight").values())
+
+    edge_cmap = matplotlib.colormaps["seismic"]
+    node_cmap = matplotlib.colormaps["seismic"]
+
+    abs_max = np.max([abs(np.min(color)), abs(np.max(color))])
+    abs_max_edge = np.max([abs(np.min(edge_color)), abs(np.max(edge_color))])
+
+    _ = cm.ScalarMappable(
+        cmap=node_cmap, norm=matplotlib.colors.Normalize(vmin=-abs_max, vmax=abs_max)
+    )
+    sm_edge = cm.ScalarMappable(
+        cmap=edge_cmap,
+        norm=matplotlib.colors.Normalize(vmin=-abs_max_edge, vmax=abs_max_edge),
+    )
+
+    nx.draw(
+        G,
+        pos=pos,
+        with_labels=True,
+        node_size=100,
+        node_color=color,
+        edge_color=edge_color,
+        width=edge_width,
+        font_color="black",
+        font_size=10,
+        edge_cmap=edge_cmap,
+        cmap=node_cmap,
+        vmin=-abs_max,
+        vmax=abs_max,
+    )
+
+    # plt.colorbar(sm_node, label="Node value", shrink = .2)
+    plt.colorbar(sm_edge, label="Edge value", shrink=0.2)
+    plt.tight_layout()
+    fig.savefig(savepath / f"{title}.png", format="png", dpi=200)
+    return fig