a b/supplementary_files/3D_latent_visualization.py
1
"""
2
This script compresses the latent representation of the data contained in latent_location.npy
3
with shape (n_samples,n_features,n_perturbations) to 3 UMAP dimensions.
4
Then, the location of the samples in 3D-UMAP is plotted, color coding by a feature of interest. The movement
5
of the samples when perturbing said feature is shown using the same UMAP projection as the baseline.
6
7
Args:
8
    -lp or --latent_path: path to the folder containing latent_location.npy, a numpy array of shape (n_samples,n_features,n_perturbations), and perturbed_features_list.npy
9
    -dp or --data_path: path to original datasets, in the example is interim_path
10
    -ds or --dataset: name of the perturbed dataset
11
    -foi or --feature_of_interest: feature that we want to perturb and visualize
12
13
Returns:
14
    figures/ folder inside latent_path/ with PNG frames and GIFs depicting the latent space distribution of
15
    the feature of interest (perturbed_feature.gif) and the movement that samples undergo when perturbing 
16
    said feature (arrows.gif)
17
18
Example: 
19
    args.latent_path = Path("/Users/_____/Desktop/MOVE/tutorial/results/identify_associations")
20
    args.data_path = Path("/Users/_____/Desktop/MOVE/tutorial/interim_data")
21
    args.dataset = "ibd.mbx"
22
    args.feature_of_interest = "C20 carnitine"
23
24
How to run:
25
    Example:
26
    1) go to the folder where this file is located:
27
    cd /Users/____/Desktop/MOVE/supplementary_files
28
    2) type the following substituting the fields for your files
29
    python 3D_latent_visualization.py -lp /Users/____/Desktop/MOVE/tutorial/results/identify_associations \\
30
                                      -dp /Users/____/Desktop/MOVE/tutorial/interim_data \\
31
                                      -ds ibd.mbx \\
32
                                      -foi="C20 carnitine"  
33
                                      
34
Note: UMAP must be installed, which can be done by running:
35
    pip install umap-learn
36
"""
37
38
import argparse
39
from pathlib import Path
40
41
import matplotlib.pyplot as plt
42
import numpy as np
43
import pandas as pd
44
import umap
45
from PIL import Image
46
47
from move.visualization.latent_space import plot_3D_latent_and_displacement
48
49
# Command-line interface. All four options are required.
parser = argparse.ArgumentParser(
    description="Read latent space matrix file to plot it in 3D"
)
# NOTE: --latent_path is used as a *folder*: the script reads
# latent_location.npy and perturbed_features_list.npy from it and writes the
# figures/ subfolder into it, so the help text must not describe it as a file.
parser.add_argument(
    "-lp",
    "--latent_path",
    metavar="lp",
    type=Path,
    required=True,
    help=(
        "path to the folder containing latent_location.npy "
        "(n_samples,n_features,n_perturbations) and perturbed_features_list.npy"
    ),
)
# Folder holding the original datasets; <dataset>.npy is read from here.
parser.add_argument(
    "-dp",
    "--data_path",
    metavar="dp",
    type=Path,
    required=True,
    help="path to original datasets, interim_path",
)
# File stem of the perturbed dataset (loaded as <data_path>/<dataset>.npy).
parser.add_argument(
    "-ds",
    "--dataset",
    metavar="ds",
    type=str,
    required=True,
    help="name of the perturbed dataset",
)
# Must match an entry of perturbed_features_list.npy exactly.
parser.add_argument(
    "-foi",
    "--feature_of_interest",
    metavar="foi",
    type=str,
    required=True,
    help="feature that we want to perturb",
)
args = parser.parse_args()
85
86
# All outputs (PNG frames and GIFs) go to <latent_path>/figures.
figure_path = args.latent_path / "figures"
figure_path.mkdir(exist_ok=True, parents=True)

perturbed_dataset = np.load(args.data_path / f"{args.dataset}.npy")
perturbed_features = list(np.load(args.latent_path / "perturbed_features_list.npy"))

# Validate the requested feature BEFORE the UMAP fit, so that a misspelled
# feature name fails immediately instead of after the expensive embedding step.
if args.feature_of_interest not in perturbed_features:
    raise ValueError(" Feature of interest not in perturbed dataset")

# Position of the feature in the perturbed-features list; used below to select
# both the perturbation slice of the latent matrix and the dataset column.
i = perturbed_features.index(args.feature_of_interest)

latent_matrix = np.load(args.latent_path / "latent_location.npy")

# Fit the 3D UMAP on the last slice along the third axis
# (presumably the unperturbed baseline -- confirm against the producer of
# latent_location.npy). random_state is fixed so runs are reproducible.
trans = umap.UMAP(random_state=42, n_components=3).fit(latent_matrix[:, :, -1])
embedding = trans.embedding_

# Project the perturbed latent locations with the SAME fitted transform so that
# both point clouds live in one coordinate system.
new_embedding = trans.transform(latent_matrix[:, :, i])
102
103
# Plot latent space: render n_pictures frames of each plot type while the
# camera sweeps from (azimuth=-45, altitude=15) to (azimuth=45, altitude=45).
n_pictures = 100
azimuths = np.linspace(-45, 45, n_pictures)
altitudes = np.linspace(15, 45, n_pictures)

for pic_num, (azimuth, altitude) in enumerate(zip(azimuths, altitudes)):
    # Frame 1/2: arrows only, showing how samples move under the perturbation.
    fig = plot_3D_latent_and_displacement(
        embedding,
        new_embedding,
        feature_values=perturbed_dataset[:, i],
        feature_name="Sample movement",
        show_baseline=False,
        show_perturbed=False,
        show_arrows=True,
        step=1,
        altitude=altitude,
        azimuth=azimuth,
    )
    fig.savefig(figure_path / f"3D_latent_movement_{pic_num}_arrows.png", dpi=100)
    # Close each figure right away; otherwise 2 * n_pictures figures pile up.
    plt.close(fig)

    # Frame 2/2: baseline samples only, labelled with the feature of interest.
    fig = plot_3D_latent_and_displacement(
        embedding,
        new_embedding,
        feature_values=perturbed_dataset[:, i],
        feature_name=f"Feature {args.feature_of_interest}",
        show_baseline=True,
        show_perturbed=False,
        show_arrows=False,
        altitude=altitude,
        azimuth=azimuth,
    )
    fig.savefig(
        figure_path / f"3D_latent_movement_{pic_num}_perturbed_feature.png", dpi=100
    )
    plt.close(fig)
141
142
# Stitch the saved PNG frames of each plot type into <plot_type>.gif.
for plot_type in ["arrows", "perturbed_feature"]:
    frames = []
    try:
        # Frames are opened in numeric order so the animation follows the
        # camera sweep used when rendering them.
        for frame_num in range(n_pictures):
            frames.append(
                Image.open(
                    figure_path / f"3D_latent_movement_{frame_num}_{plot_type}.png"
                )
            )
        frames[0].save(
            figure_path / f"{plot_type}.gif",
            format="GIF",
            append_images=frames[1:],
            save_all=True,
            duration=75,
            loop=0,
        )
    finally:
        # Release every opened image explicitly, including when opening a later
        # frame or writing the GIF fails part-way through.
        for frame in frames:
            frame.close()