--- a +++ b/supplementary_files/3D_latent_visualization.py @@ -0,0 +1,154 @@ +""" +This code is meant to 1) compress the latent representation of the data contained in latent_location.npy +with shape (n_samples,n_features,n_perturbations) to 3 UMAP dimensions. +Then, the location of the samples in 3D-UMAP is plotted, color coding by a feature of interest. The movement +of the samples when perturbing said figure is shown using the same UMAP projection as the baseline. + +Args: + -lp or --latent_path: path to latent numpy array of shape (n_samples,n_features,n_perturbations) + -dp or --data_path: path to original datasets, in the example is interim_path + -ds or --dataset: name of the perturbed dataset + -foi or --feature_of_interest: feature that we want to perturb and visualize + +Returns: + figure folder inside latent_path/ with figures and gifs depicting the latent space distribution of + the feature of interest (perturbed_feature.gif) and the movement that samples undergo when perturbing + said feature (arrows.gif) + +Example: + args.latent_path = Path("/Users/_____/Desktop/MOVE/tutorial/results/identify_associations") + args.data_path = Path("/Users/_____/Desktop/MOVE/tutorial/interim_data") + args.dataset = "ibd.mbx" + args.feature_of_interest = "C20 carnitine" + +How to run: + Example: + 1) go to the folder where this file is located: + cd /Users/____/Desktop/MOVE/supplementary_files + 2) type the following substituting the fields for your files + python 3D_latent_visualization.py -lp /Users/____/Desktop/MOVE/tutorial/results/identify_associations \\ + -dp /Users/____/Desktop/MOVE/tutorial/interim_data \\ + -ds ibd.mbx \\ + -foi="C20 carnitine" + +Note: UMAP must be installed, which can be done by running: + pip install umap-learn +""" + +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import umap +from PIL import Image + +from move.visualization.latent_space import plot_3D_latent_and_displacement + +parser = argparse.ArgumentParser( + description="Read latent space matrix file to plot it in 3D" +) +parser.add_argument( + "-lp", + "--latent_path", + metavar="lp", + type=Path, + required=True, + help="path to latent numpy array (n_samples,n_features,n_perturbations)", +) +parser.add_argument( + "-dp", + "--data_path", + metavar="dp", + type=Path, + required=True, + help="path to original datasets, interim_path", +) +parser.add_argument( + "-ds", + "--dataset", + metavar="ds", + type=str, + required=True, + help="name of the perturbed dataset", +) +parser.add_argument( + "-foi", + "--feature_of_interest", + metavar="foi", + type=str, + required=True, + help="feature that we want to perturb", +) +args = parser.parse_args() + +figure_path = Path(args.latent_path / "figures") +figure_path.mkdir(exist_ok=True, parents=True) + +perturbed_dataset = np.load(args.data_path / f"{args.dataset}.npy") +perturbed_features = list(np.load(args.latent_path / "perturbed_features_list.npy")) + +latent_matrix = np.load(args.latent_path / "latent_location.npy") +trans = umap.UMAP(random_state=42, n_components=3).fit(latent_matrix[:, :, -1]) +embedding = trans.embedding_ + +if args.feature_of_interest not in perturbed_features: + raise ValueError(" Feature of interest not in perturbed dataset") + +i = perturbed_features.index(args.feature_of_interest) + +new_embedding = trans.transform(latent_matrix[:, :, i]) + +# # Plot latent space: +pic_num = 0 +n_pictures = 100 +for azimuth, altitude in zip( + np.linspace(-45, 45, n_pictures), np.linspace(15, 45, n_pictures) +): + fig = plot_3D_latent_and_displacement( + embedding, + new_embedding, + feature_values=perturbed_dataset[:, i], + feature_name=f"Sample movement", + show_baseline=False, + show_perturbed=False, + show_arrows=True, + step=1, + altitude=altitude, + azimuth=azimuth, + ) + + fig.savefig(figure_path / f"3D_latent_movement_{pic_num}_arrows.png", dpi=100) + plt.close(fig) + + fig = plot_3D_latent_and_displacement( + embedding, + new_embedding, + feature_values=perturbed_dataset[:, i], + feature_name=f"Feature {args.feature_of_interest}", + show_baseline=True, + show_perturbed=False, + show_arrows=False, + altitude=altitude, + azimuth=azimuth, + ) + fig.savefig( + figure_path / f"3D_latent_movement_{pic_num}_perturbed_feature.png", dpi=100 + ) + plt.close(fig) + pic_num += 1 + +for plot_type in ["arrows", "perturbed_feature"]: + frames = [ + Image.open(figure_path / f"3D_latent_movement_{pic_num}_{plot_type}.png") + for pic_num in range(n_pictures) + ] # sorted(glob.glob("*3D_latent*"))] + frames[0].save( + figure_path / f"{plot_type}.gif", + format="GIF", + append_images=frames[1:], + save_all=True, + duration=75, + loop=0, + )