a b/supplementary_files/visualize_vae.py
1
"""
2
This script plots the trained autoencoder as a graph, where the edges and their colors represent the weights.
3
If you don't know the input size, it is reported in logs/identify_associations: e.g., in `Model: VAE (1343 ⇄ 720 ⇄ 50)` the input size is 1343 (720 and 50 are the hidden and latent sizes).
4
The rest of the parameters can also be found in the config file for the identify associations task.
5
6
Args: 
7
-mp or --models_path: path to the directory containing the model weight files, named model_<n_latent>_<refit>.pt
8
-op or --output_path: path where the PNG figure will be saved
9
-in or --n_input: number of input nodes
10
-hi or --n_hidden: number of hidden nodes
11
-la or --n_latent: number of latent nodes
12
-re or --refit: refit number of the model we want to visualize
13
14
Returns:
15
    A PNG figure of the VAE's weights for the given refit of the model.
16
17
Example:
18
19
    python visualize_vae.py -mp /Users/______/Desktop/MOVE/tutorial/interim_data/models \\
20
                            -op /Users/______/Desktop/MOVE/tutorial/results/identify_associations/figures \\
21
                            -in 1343 \\
22
                            -hi 720 \\
23
                            -la 50 \\
24
                            -re 0
25
26
"""
27
28
import argparse
29
from pathlib import Path
30
31
from move.visualization.vae_visualization import plot_vae
32
33
def _bounded_int(value: str, minimum: int) -> int:
    """Parse a command-line string as an integer no smaller than ``minimum``.

    Raises ``argparse.ArgumentTypeError`` (instead of a bare ``ValueError``)
    so argparse prints a readable usage error and exits with status 2.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {value!r}"
        ) from None
    if number < minimum:
        raise argparse.ArgumentTypeError(
            f"expected an integer >= {minimum}, got {number}"
        )
    return number


def _positive_int(value: str) -> int:
    """Argparse ``type`` for layer sizes, which must be at least 1."""
    return _bounded_int(value, 1)


def _non_negative_int(value: str) -> int:
    """Argparse ``type`` for refit numbers, which may be 0 or larger."""
    return _bounded_int(value, 0)


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Returns:
        A parser whose namespace has ``models_path`` and ``output_path``
        (``Path``), ``n_input``, ``n_hidden``, ``n_latent`` (positive ints)
        and ``refit`` (non-negative int).
    """
    arg_parser = argparse.ArgumentParser(
        description="Plot weights and nodes for a trained autoencoder"
    )
    arg_parser.add_argument(
        "-mp",
        "--models_path",
        metavar="mp",
        type=Path,
        required=True,
        help="path to model weight files of the format model___.pt",
    )
    arg_parser.add_argument(
        "-op",
        "--output_path",
        metavar="op",
        type=Path,
        required=True,
        help="path where the png figure will be saved",
    )
    arg_parser.add_argument(
        "-in",
        "--n_input",
        metavar="i",
        type=_positive_int,
        required=True,
        help="number of input nodes",
    )
    arg_parser.add_argument(
        "-hi",
        "--n_hidden",
        metavar="h",
        type=_positive_int,
        required=True,
        help="number of hidden nodes",
    )
    arg_parser.add_argument(
        "-la",
        "--n_latent",
        metavar="l",
        type=_positive_int,
        required=True,
        help="number of latent nodes",
    )
    # Parsed as an int (not a raw string) so typos such as "-re x" are
    # rejected up front instead of producing a non-existent weight file name.
    arg_parser.add_argument(
        "-re",
        "--refit",
        metavar="r",
        type=_non_negative_int,
        required=True,
        help="refit number of the model we want to visualize",
    )
    return arg_parser


def main(argv: "list[str] | None" = None) -> None:
    """Parse the command line and plot the weights of the requested refit.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default).
    """
    args = build_parser().parse_args(argv)
    # Weight files are looked up as model_<n_latent>_<refit>.pt in models_path.
    model_filename = f"model_{args.n_latent}_{args.refit}.pt"
    plot_vae(
        args.models_path,
        args.output_path,
        model_filename,
        f"Vae's weights for refit {args.refit}",
        num_input=args.n_input,
        num_hidden=args.n_hidden,
        num_latent=args.n_latent,
        plot_edges=True,
    )


# Module-level parser kept for backward compatibility with the original
# script, which exposed ``parser`` as a global.
parser = build_parser()

if __name__ == "__main__":
    main()