|
a |
|
b/src/move/visualization/vae_visualization.py |
|
|
1 |
__all__ = ["plot_vae"] |
|
|
2 |
|
|
|
3 |
from pathlib import Path |
|
|
4 |
from typing import Optional |
|
|
5 |
|
|
|
6 |
import matplotlib |
|
|
7 |
import matplotlib.cm as cm |
|
|
8 |
import matplotlib.figure |
|
|
9 |
import matplotlib.pyplot as plt |
|
|
10 |
import networkx as nx |
|
|
11 |
import numpy as np |
|
|
12 |
import torch |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def plot_vae(
    path: Path,
    savepath: Path,
    filename: str,
    title: str,
    num_input: int,
    num_hidden: int,
    num_latent: int,
    plot_edges=True,
    input_sample: Optional[torch.Tensor] = None,
    output_sample: Optional[torch.Tensor] = None,
    mu: Optional[torch.Tensor] = None,
    logvar: Optional[torch.Tensor] = None,
) -> matplotlib.figure.Figure:
    """
    Draw the architecture of a trained MOVE VAE as a layered graph.

    Nodes are laid out in columns (input, encoder hidden, mu/var, sampled
    latent, decoder hidden, output), plus one bias node per connection group.
    Edges carry the weights read from the saved model and are coloured with a
    symmetric diverging colormap.

    Args:
        path: directory where the trained model with defined weights is found
        savepath: directory where the figure is written as ``{title}.png``
        filename: name of the model file inside ``path``
        title: stem of the saved image file (it is not drawn on the figure)
        num_input: number of input nodes (also used for the output layer)
        num_hidden: number of hidden nodes (encoder and decoder)
        num_latent: number of latent nodes
        plot_edges: plot edges, i.e. connections with assigned weights between
            nodes. When False only the nodes are drawn and no edge colorbar
            is added.
        input_sample: values used to colour the input nodes (indexed by k)
        output_sample: values used to colour the output nodes (indexed by k)
        mu: values used to colour the latent mean nodes (indexed by i)
        logvar: log variances (indexed by i); the var nodes are coloured with
            ``exp(logvar / 2)``

    Returns:
        figure

    Notes:
        k: input node index
        j: hidden node index
        i: latent node index
    """
    # Load on CPU so checkpoints saved from a GPU can still be converted to
    # NumPy arrays below.
    model_weights = torch.load(path / filename, map_location="cpu")
    G = nx.Graph()

    # Position of the layers:
    layer_distance = 10
    node_distance = 550
    latent_node_distance = 550
    latent_sep = 5 * latent_node_distance

    # Adding nodes to the graph ##############################
    # Bias nodes
    G.add_node(
        "input_bias",
        pos=(-6 * layer_distance, -3 * node_distance - num_input * node_distance / 2),
        color=0.0,
    )
    G.add_node(
        "mu_bias",
        pos=(
            -3 * layer_distance,
            (num_hidden + 3) * node_distance
            - num_hidden * node_distance / 2
            + latent_sep / 2,
        ),
        color=0.0,
    )
    G.add_node(
        "var_bias",
        pos=(
            -3 * layer_distance,
            -3 * node_distance - num_hidden * node_distance / 2 - latent_sep / 2,
        ),
        color=0.0,
    )
    G.add_node(
        "sam_bias",
        pos=(
            0.5 * layer_distance,
            -3 * latent_node_distance - num_latent * latent_node_distance / 2,
        ),
        color=0.0,
    )
    G.add_node(
        "out_bias",
        pos=(3 * layer_distance, -3 * node_distance - num_hidden * node_distance / 2),
        color=0.0,
    )

    # Actual nodes
    for k in range(num_input):
        y_io = k * node_distance - num_input * node_distance / 2
        G.add_node(
            f"input_{k}",
            pos=(-6 * layer_distance, y_io),
            color=_node_value(input_sample, k),
        )
        G.add_node(
            f"output_{k}",
            pos=(6 * layer_distance, y_io),
            color=_node_value(output_sample, k),
        )
    for j in range(num_hidden):
        y_hidden = j * node_distance - num_hidden * node_distance / 2
        G.add_node(
            f"encoder_hidden_{j}", pos=(-3 * layer_distance, y_hidden), color=0.0
        )
        G.add_node(
            f"decoder_hidden_{j}", pos=(3 * layer_distance, y_hidden), color=0.0
        )
    for i in range(num_latent):
        G.add_node(
            f"mu_{i}",
            pos=(0 * layer_distance, i * latent_node_distance + latent_sep / 2),
            color=_node_value(mu, i),
        )
        G.add_node(
            f"var_{i}",
            pos=(0 * layer_distance, -i * latent_node_distance - latent_sep / 2),
            # exp(logvar / 2); stays 0.0 when no log variance is given.
            color=(
                float(np.exp(_node_value(logvar, i) / 2))
                if logvar is not None
                else 0.0
            ),
        )
        G.add_node(
            f"sam_{i}",
            pos=(
                0.5 * layer_distance,
                i * latent_node_distance - num_latent * latent_node_distance / 2,
            ),
            color=0.0,
        )

    # Adding weights to the graph #########################
    if plot_edges:
        _add_weight_edges(G, model_weights)

    fig = plt.figure(figsize=(60, 60))
    pos = nx.get_node_attributes(G, "pos")
    node_values = list(nx.get_node_attributes(G, "color").values())
    edge_weights = list(nx.get_edge_attributes(G, "weight").values())

    edge_cmap = matplotlib.colormaps["seismic"]
    node_cmap = matplotlib.colormaps["seismic"]

    # Symmetric colour limits so that zero maps to the colormap centre.
    abs_max = float(np.max(np.abs(node_values)))

    draw_kwargs = dict(
        pos=pos,
        with_labels=True,
        node_size=100,
        node_color=node_values,
        font_color="black",
        font_size=10,
        cmap=node_cmap,
        vmin=-abs_max,
        vmax=abs_max,
    )

    # Edge styling is only meaningful (and only computable) when the graph
    # actually has edges; reducing over an empty list would raise.
    abs_max_edge = None
    if edge_weights:
        abs_max_edge = float(np.max(np.abs(edge_weights)))
        draw_kwargs.update(
            edge_color=edge_weights,
            # NOTE(review): line width is the signed weight, as before;
            # consider abs() if negative widths render unexpectedly.
            width=edge_weights,
            edge_cmap=edge_cmap,
            # Use the same limits as the colorbar below so colours match it.
            edge_vmin=-abs_max_edge,
            edge_vmax=abs_max_edge,
        )

    nx.draw(G, **draw_kwargs)

    if abs_max_edge is not None:
        sm_edge = cm.ScalarMappable(
            cmap=edge_cmap,
            norm=matplotlib.colors.Normalize(vmin=-abs_max_edge, vmax=abs_max_edge),
        )
        # The mappable is not attached to any Axes, so the Axes to take space
        # from has to be given explicitly.
        fig.colorbar(sm_edge, ax=fig.gca(), label="Edge value", shrink=0.2)

    fig.tight_layout()
    fig.savefig(savepath / f"{title}.png", format="png", dpi=200)
    return fig


# Weight matrices are indexed [target, source]; each entry maps a layer name
# in the saved model to the (source, target) node-name prefixes it connects.
_WEIGHT_LAYERS = {
    "encoderlayers.0.weight": ("input", "encoder_hidden"),
    "mu.weight": ("encoder_hidden", "mu"),
    "var.weight": ("encoder_hidden", "var"),
    "decoderlayers.0.weight": ("sam", "decoder_hidden"),
    "out.weight": ("decoder_hidden", "output"),
}

# Bias vectors: layer name -> (bias node name, target node-name prefix).
_BIAS_LAYERS = {
    "encoderlayers.0.bias": ("input_bias", "encoder_hidden"),
    "mu.bias": ("mu_bias", "mu"),
    "var.bias": ("var_bias", "var"),
    "decoderlayers.0.bias": ("sam_bias", "decoder_hidden"),
    "out.bias": ("out_bias", "output"),
}


def _node_value(sample, index):
    """Return ``sample[index]`` as a plain float, or 0.0 when no sample is given."""
    if sample is None:
        return 0.0
    return float(sample[index])


def _add_weight_edges(graph, model_weights):
    """Add one weighted edge per weight/bias entry of the recognised layers."""
    for layer, values in model_weights.items():
        if layer in _WEIGHT_LAYERS:
            source, target = _WEIGHT_LAYERS[layer]
            # Convert once per layer instead of once per edge.
            weights = values.detach().cpu().numpy()
            num_target, num_source = weights.shape
            for col in range(num_source):
                for row in range(num_target):
                    graph.add_edge(
                        f"{source}_{col}",
                        f"{target}_{row}",
                        weight=weights[row, col],
                    )
        elif layer in _BIAS_LAYERS:
            bias_node, target = _BIAS_LAYERS[layer]
            biases = values.detach().cpu().numpy()
            for row in range(biases.shape[0]):
                graph.add_edge(bias_node, f"{target}_{row}", weight=biases[row])