[c23b31]: / supplementary_files / visualize_vae.py

"""
This code plots the trained autoencoder as a graph, where the edges and their color represent the weights.
If you don't know the input size, it is stated in logs/identify_associations as, e.g., 1343 in Model: VAE (1343 ⇄ 720 ⇄ 50).
The rest of the parameters can also be found in the config file for the identify associations task.

Args:
    -mp or --models_path: path to model weight files of the format model_<n_latent>_<refit>.pt
    -op or --output_path: path where the png figure will be saved
    -in or --n_input: number of input nodes
    -hi or --n_hidden: number of hidden nodes
    -la or --n_latent: number of latent nodes
    -re or --refit: refit number of the model we want to visualize

Returns:
    Png figure of the VAE's weights for a given refit of the model.

Example:
    python visualize_vae.py -mp /Users/______/Desktop/MOVE/tutorial/interim_data/models \\
        -op /Users/______/Desktop/MOVE/tutorial/results/identify_associations/figures \\
        -in 1343 \\
        -hi 720 \\
        -la 50 \\
        -re 0
"""
import argparse
from pathlib import Path

from move.visualization.vae_visualization import plot_vae

# Command-line interface: every argument is required.
parser = argparse.ArgumentParser(
    description="Plot weights and nodes for a trained autoencoder"
)
parser.add_argument(
    "-mp",
    "--models_path",
    metavar="mp",
    type=Path,
    required=True,
    help="path to model weight files of the format model_<n_latent>_<refit>.pt",
)
parser.add_argument(
    "-op",
    "--output_path",
    metavar="op",
    type=Path,
    required=True,
    help="path where the png figure will be saved",
)
parser.add_argument(
    "-in",
    "--n_input",
    metavar="i",
    type=int,
    required=True,
    help="number of input nodes",
)
parser.add_argument(
    "-hi",
    "--n_hidden",
    metavar="h",
    type=int,
    required=True,
    help="number of hidden nodes",
)
parser.add_argument(
    "-la",
    "--n_latent",
    metavar="l",
    type=int,
    required=True,
    help="number of latent nodes",
)
# The refit number is kept as a string because it is only used to build
# the weight file name and the figure title.
parser.add_argument(
    "-re",
    "--refit",
    metavar="r",
    type=str,
    required=True,
    help="refit number of the model we want to visualize",
)
args = parser.parse_args()

# The weight file name is built from the latent size and the refit number,
# e.g. model_50_0.pt for -la 50 -re 0.
plot_vae_base = plot_vae(
    args.models_path,
    args.output_path,
    f"model_{args.n_latent}_{args.refit}.pt",
    f"Vae's weights for refit {args.refit}",
    num_input=args.n_input,
    num_hidden=args.n_hidden,
    num_latent=args.n_latent,
    plot_edges=True,
)
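
The docstring notes that the layer sizes can be read from a log line such as `Model: VAE (1343 ⇄ 720 ⇄ 50)`. The following is a minimal sketch of how those three numbers might be extracted before passing them as `-in`, `-hi` and `-la`. It is not part of MOVE: `VaeShape` and `parse_vae_shape` are hypothetical helpers, and the pattern assumes a single hidden layer and exactly the line layout shown in the docstring example.

```python
import re
from typing import NamedTuple


class VaeShape(NamedTuple):
    """Layer sizes of the VAE (hypothetical helper type, not part of MOVE)."""

    n_input: int
    n_hidden: int
    n_latent: int


# Matches a line such as "Model: VAE (1343 ⇄ 720 ⇄ 50)".
# Assumption: one hidden layer, sizes separated by the "⇄" character.
_MODEL_LINE = re.compile(
    r"Model:\s*VAE\s*\(\s*(\d+)\s*⇄\s*(\d+)\s*⇄\s*(\d+)\s*\)"
)


def parse_vae_shape(log_text: str) -> VaeShape:
    """Return the sizes from the first matching line in the log text."""
    match = _MODEL_LINE.search(log_text)
    if match is None:
        raise ValueError("no 'Model: VAE (... ⇄ ... ⇄ ...)' line found")
    return VaeShape(*(int(group) for group in match.groups()))


# Usage with the example line quoted in the docstring.
shape = parse_vae_shape("Model: VAE (1343 ⇄ 720 ⇄ 50)")
print(shape.n_input, shape.n_hidden, shape.n_latent)
```

If the model was configured with more than one hidden layer, the pattern will not match and the helper raises `ValueError` rather than guessing.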