|
a |
|
b/supplementary_files/visualize_vae.py |
|
|
1 |
""" |
|
|
2 |
This code plots the trained autoencoder as a graph, where the edges and their color represent the weights. |
|
|
3 |
If you don't know the input size, it is stated in logs/identify_associations as, e.g., 1343 in Model: VAE (1343 ⇄ 720 ⇄ 50). |
|
|
4 |
The rest of the parameters can also be found in the config file for the identify associations task. |
|
|
5 |
|
|
|
6 |
Args: |
|
|
7 |
-mp or --models_path: path to the model weight files, which are named model_{n_latent}_{refit}.pt |
|
|
8 |
-op or --output_path: path where the png figure will be saved |
|
|
9 |
-in or --n_input: number of input nodes |
|
|
10 |
-hi or --n_hidden: number of hidden nodes |
|
|
11 |
-la or --n_latent: number of latent nodes |
|
|
12 |
-re or --refit: refit number of the model we want to visualize |
|
|
13 |
|
|
|
14 |
Returns: |
|
|
15 |
Png figure of the VAE's weights for a given refit of the model. |
|
|
16 |
|
|
|
17 |
Example: |
|
|
18 |
|
|
|
19 |
python visualize_vae.py -mp /Users/______/Desktop/MOVE/tutorial/interim_data/models \\ |
|
|
20 |
-op /Users/______/Desktop/MOVE/tutorial/results/identify_associations/figures \\ |
|
|
21 |
-in 1343 \\ |
|
|
22 |
-hi 720 \\ |
|
|
23 |
-la 50 \\ |
|
|
24 |
-re 0 |
|
|
25 |
|
|
|
26 |
""" |
|
|
27 |
|
|
|
28 |
import argparse |
|
|
29 |
from pathlib import Path |
|
|
30 |
|
|
|
31 |
from move.visualization.vae_visualization import plot_vae |
|
|
32 |
|
|
|
33 |
# Command-line interface. Every option is declared with required=True, so
# argparse stops with a usage error if any of them is missing.
parser = argparse.ArgumentParser(
    description="Plot weights and nodes for a trained autoencoder"
)

# One row per option: (short flag, long flag, metavar, type, help text).
# The rows are registered in this order, which is also the order in which
# the options are listed by --help.
_CLI_OPTIONS = (
    (
        "-mp",
        "--models_path",
        "mp",
        Path,
        "path to model weight files of the format model___.pt",
    ),
    (
        "-op",
        "--output_path",
        "op",
        Path,
        "path where the png figure will be saved",
    ),
    ("-in", "--n_input", "i", int, "number of input nodes"),
    ("-hi", "--n_hidden", "h", int, "number of hidden nodes"),
    ("-la", "--n_latent", "l", int, "number of latent nodes"),
    (
        "-re",
        "--refit",
        "r",
        str,
        "refit number of the model we want to visualize",
    ),
)

for _short, _long, _metavar, _kind, _help in _CLI_OPTIONS:
    parser.add_argument(
        _short,
        _long,
        metavar=_metavar,
        type=_kind,
        required=True,
        help=_help,
    )

args = parser.parse_args()

# The weight file name is built from the latent size and the refit id; the
# refit id also appears in the figure title.
weights_file = f"model_{args.n_latent}_{args.refit}.pt"
figure_title = f"Vae's weights for refit {args.refit}"

plot_vae_base = plot_vae(
    args.models_path,
    args.output_path,
    weights_file,
    figure_title,
    num_input=args.n_input,
    num_hidden=args.n_hidden,
    num_latent=args.n_latent,
    plot_edges=True,
)