# src/move/visualization/vae_visualization.py
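"""Visualization of a trained MOVE VAE: the network is drawn as a graph with
one node per neuron (input, encoder/decoder hidden, latent mu/var, sampled
latent, output, plus bias nodes) and, optionally, one edge per weight."""
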
__all__ = ["plot_vae"]

from pathlib import Path
from typing import Optional

import matplotlib
import matplotlib.cm as cm
import matplotlib.colors
import matplotlib.figure
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch


def plot_vae(
    path: Path,
    savepath: Path,
    filename: str,
    title: str,
    num_input: int,
    num_hidden: int,
    num_latent: int,
    plot_edges: bool = True,
    input_sample: Optional[torch.Tensor] = None,
    output_sample: Optional[torch.Tensor] = None,
    mu: Optional[torch.Tensor] = None,
    logvar: Optional[torch.Tensor] = None,
) -> matplotlib.figure.Figure:
    """
    Visualize the architecture of a trained MOVE VAE model.

    Args:
        path: directory where the trained model with defined weights is found
        savepath: directory where the resulting figure is saved
        filename: name of the model file
        title: title of the figure; also used as the output file name
        num_input: number of input nodes
        num_hidden: number of hidden nodes
        num_latent: number of latent nodes
        plot_edges: whether to plot edges, i.e. connections with assigned
            weights between nodes
        input_sample: array with input values to fill with a mapped color value
        output_sample: array with output values to fill with a mapped color value
        mu: array with latent mean values to fill with a mapped color value
        logvar: array with latent log-variance values to fill with a mapped
            color value

    Returns:
        The matplotlib figure with the network plot.

    Notes:
        k: input node index
        j: hidden node index
        i: latent node index
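
    Example (a sketch; the directories, file name, and layer sizes below are
    hypothetical and must match the saved model):

        >>> fig = plot_vae(
        ...     path=Path("models"),
        ...     savepath=Path("figures"),
        ...     filename="vae_weights.pt",
        ...     title="move_vae",
        ...     num_input=100,
        ...     num_hidden=200,
        ...     num_latent=20,
        ... )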
    """
    # Load the saved weights; map to CPU so plotting works without a GPU.
    model_weights = torch.load(path / filename, map_location="cpu")
    G = nx.Graph()

    # Position of the layers:
    layer_distance = 10
    node_distance = 550
    latent_node_distance = 550
    latent_sep = 5 * latent_node_distance

    # Adding nodes to the graph ##############################
    # Bias nodes
    G.add_node(
        "input_bias",
        pos=(-6 * layer_distance, -3 * node_distance - num_input * node_distance / 2),
        color=0.0,
    )
    G.add_node(
        "mu_bias",
        pos=(
            -3 * layer_distance,
            (num_hidden + 3) * node_distance
            - num_hidden * node_distance / 2
            + latent_sep / 2,
        ),
        color=0.0,
    )
    G.add_node(
        "var_bias",
        pos=(
            -3 * layer_distance,
            -3 * node_distance - num_hidden * node_distance / 2 - latent_sep / 2,
        ),
        color=0.0,
    )
    G.add_node(
        "sam_bias",
        pos=(
            0.5 * layer_distance,
            -3 * latent_node_distance - num_latent * latent_node_distance / 2,
        ),
        color=0.0,
    )
    G.add_node(
        "out_bias",
        pos=(3 * layer_distance, -3 * node_distance - num_hidden * node_distance / 2),
        color=0.0,
    )

    # Actual nodes
    for k in range(num_input):
        G.add_node(
            f"input_{k}",
            pos=(
                -6 * layer_distance,
                k * node_distance - num_input * node_distance / 2,
            ),
            color=input_sample[k] if input_sample is not None else 0.0,
        )
        G.add_node(
            f"output_{k}",
            pos=(6 * layer_distance, k * node_distance - num_input * node_distance / 2),
            color=output_sample[k] if output_sample is not None else 0.0,
        )
    for j in range(num_hidden):
        G.add_node(
            f"encoder_hidden_{j}",
            pos=(
                -3 * layer_distance,
                j * node_distance - num_hidden * node_distance / 2,
            ),
            color=0.0,
        )
        G.add_node(
            f"decoder_hidden_{j}",
            pos=(
                3 * layer_distance,
                j * node_distance - num_hidden * node_distance / 2,
            ),
            color=0.0,
        )
    for i in range(num_latent):
        G.add_node(
            f"mu_{i}",
            pos=(0 * layer_distance, i * latent_node_distance + latent_sep / 2),
            color=mu[i] if mu is not None else 0.0,
        )
        G.add_node(
            f"var_{i}",
            pos=(0 * layer_distance, -i * latent_node_distance - latent_sep / 2),
            # Color by standard deviation, recovered from the log-variance.
            color=np.exp(logvar[i] / 2) if logvar is not None else 0.0,
        )
        G.add_node(
            f"sam_{i}",
            pos=(
                0.5 * layer_distance,
                i * latent_node_distance - num_latent * latent_node_distance / 2,
            ),
            color=0.0,
        )

    # Adding weights to the graph #########################
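    # Note: the layer names matched below ("encoderlayers.0.weight",
    # "mu.weight", "var.weight", "decoderlayers.0.weight", "out.weight", and
    # their biases) are assumed to follow MOVE's VAE state-dict naming; only
    # the first encoder and decoder layers are drawn.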

    if plot_edges:
        for layer, values in model_weights.items():
            if layer == "encoderlayers.0.weight":
                for k in range(values.shape[1]):  # input
                    for j in range(values.shape[0]):  # encoder_hidden
                        G.add_edge(
                            f"input_{k}",
                            f"encoder_hidden_{j}",
                            weight=values.numpy()[j, k],
                        )

            elif layer == "encoderlayers.0.bias":
                for j in range(values.shape[0]):  # encoder_hidden
                    G.add_edge(
                        "input_bias", f"encoder_hidden_{j}", weight=values.numpy()[j]
                    )

            elif layer == "mu.weight":
                for j in range(values.shape[1]):  # encoder_hidden
                    for i in range(values.shape[0]):  # mu
                        G.add_edge(
                            f"encoder_hidden_{j}",
                            f"mu_{i}",
                            weight=values.numpy()[i, j],
                        )

            elif layer == "mu.bias":
                for i in range(values.shape[0]):  # mu
                    G.add_edge("mu_bias", f"mu_{i}", weight=values.numpy()[i])

            elif layer == "var.weight":
                for j in range(values.shape[1]):  # encoder_hidden
                    for i in range(values.shape[0]):  # var
                        G.add_edge(
                            f"encoder_hidden_{j}",
                            f"var_{i}",
                            weight=values.numpy()[i, j],
                        )

            elif layer == "var.bias":
                for i in range(values.shape[0]):  # var
                    G.add_edge("var_bias", f"var_{i}", weight=values.numpy()[i])

            # Decoder input is the latent sample drawn from mu and var:
            elif layer == "decoderlayers.0.weight":
                for i in range(values.shape[1]):  # sampled latent
                    for j in range(values.shape[0]):  # decoder_hidden
                        G.add_edge(
                            f"sam_{i}",
                            f"decoder_hidden_{j}",
                            weight=values.numpy()[j, i],
                        )

            elif layer == "decoderlayers.0.bias":
                for j in range(values.shape[0]):  # decoder_hidden
                    G.add_edge(
                        "sam_bias", f"decoder_hidden_{j}", weight=values.numpy()[j]
                    )

            elif layer == "out.weight":
                for j in range(values.shape[1]):  # decoder_hidden
                    for k in range(values.shape[0]):  # output
                        G.add_edge(
                            f"output_{k}",
                            f"decoder_hidden_{j}",
                            weight=values.numpy()[k, j],
                        )

            elif layer == "out.bias":
                for k in range(values.shape[0]):  # output
                    G.add_edge("out_bias", f"output_{k}", weight=values.numpy()[k])

    fig = plt.figure(figsize=(60, 60))
    pos = nx.get_node_attributes(G, "pos")
    color = list(nx.get_node_attributes(G, "color").values())
    edge_color = list(nx.get_edge_attributes(G, "weight").values())
    edge_width = list(nx.get_edge_attributes(G, "weight").values())

    edge_cmap = matplotlib.colormaps["seismic"]
    node_cmap = matplotlib.colormaps["seismic"]

    # Symmetric color limits so zero maps to the middle of the diverging map.
    abs_max = np.max([abs(np.min(color)), abs(np.max(color))])
    abs_max_edge = np.max([abs(np.min(edge_color)), abs(np.max(edge_color))])

    sm_node = cm.ScalarMappable(
        cmap=node_cmap, norm=matplotlib.colors.Normalize(vmin=-abs_max, vmax=abs_max)
    )
    sm_edge = cm.ScalarMappable(
        cmap=edge_cmap,
        norm=matplotlib.colors.Normalize(vmin=-abs_max_edge, vmax=abs_max_edge),
    )

    nx.draw(
        G,
        pos=pos,
        with_labels=True,
        node_size=100,
        node_color=color,
        edge_color=edge_color,
        width=edge_width,
        font_color="black",
        font_size=10,
        edge_cmap=edge_cmap,
        cmap=node_cmap,
        vmin=-abs_max,
        vmax=abs_max,
    )
    plt.title(title)

    # plt.colorbar(sm_node, ax=plt.gca(), label="Node value", shrink=0.2)
    plt.colorbar(sm_edge, ax=plt.gca(), label="Edge value", shrink=0.2)
    plt.tight_layout()
    fig.savefig(savepath / f"{title}.png", format="png", dpi=200)
    return fig