# experiments/simulations/two_dimensional_experiments.py
1
import torch
2
import numpy as np
3
import matplotlib.pyplot as plt
4
import pandas as pd
5
6
import seaborn as sns
7
import sys
8
from two_dimensional import two_d_gpsa
9
10
sys.path.append("../..")
11
from models.gpsa_vi_lmc import VariationalWarpGP
12
13
sys.path.append("../../data")
14
from simulated.generate_twod_data import generate_twod_data
15
from plotting.callbacks import callback_twod
16
from util import ConvergenceChecker
17
18
# Imports for the PASTE baseline alignment method
19
import scanpy as sc
20
21
sys.path.append("../../../paste")
22
from src.paste import PASTE, visualization
23
24
25
# Preferred torch device string: "cuda" when a GPU is available, else "cpu".
# NOTE(review): not referenced in the code visible here; presumably consumed
# elsewhere or kept for parity with sibling experiment scripts.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Font size for the per-fit alignment figures (rendered with LaTeX).
LATEX_FONTSIZE = 50

n_spatial_dims = 2
# Number of views in each simulation; used to split the stacked aligned
# coordinates into equally sized per-view chunks.
n_views = 2
# NOTE(review): m_G / m_X_per_view look like inducing-point counts, but they
# are not referenced in the code visible here; verify before changing.
m_G = 25
m_X_per_view = 25

# Number of optimisation epochs passed to ``two_d_gpsa``.
MAX_EPOCHS = 2000
# NOTE(review): not referenced in the code visible here.
PRINT_EVERY = 25
# Number of latent GPs per data modality, passed to ``two_d_gpsa``.
N_LATENT_GPS = {"expression": 3}
38
39
40
def tensor_to_numpy(tensor):
    """Return ``tensor`` as a squeezed NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU first;
    ``.cpu()`` is a no-op for CPU tensors and is required when the model ran
    on CUDA (``.numpy()`` refuses CUDA tensors).
    """
    return tensor.detach().cpu().numpy().squeeze()


def compute_alignment_error(aligned_coords, n_samples_per_view):
    """Return the mean squared distance between paired points of two views.

    ``aligned_coords`` stacks the aligned coordinates of view 1 (the first
    ``n_samples_per_view`` rows) on top of those of view 2 (the remaining
    rows). Row ``i`` of view 1 is compared with row ``i`` of view 2, and the
    squared Euclidean distances are averaged.

    NOTE(review): this assumes the simulation emits the same locations in the
    same order for both views — confirm against ``generate_twod_data``.

    Raises:
        ValueError: if the two view slices differ in shape. Without this check
            the subtraction would either fail or silently broadcast.
    """
    coords = np.asarray(aligned_coords)
    view1_coords = coords[:n_samples_per_view]
    view2_coords = coords[n_samples_per_view:]
    if view1_coords.shape != view2_coords.shape:
        raise ValueError(
            "Expected two equally sized views, got shapes {} and {}".format(
                view1_coords.shape, view2_coords.shape
            )
        )
    return np.mean(np.sum((view1_coords - view2_coords) ** 2, axis=1))


def ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it does not exist."""
    import os

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def plot_alignment(model, X, Y, G_means, n_outputs):
    """Plot observed and aligned expression for one fit and save it as a PNG."""
    plt.rc("font", size=LATEX_FONTSIZE)
    plt.rcParams["text.usetex"] = True

    fig = plt.figure(figsize=(10, 10))
    data_expression_ax = fig.add_subplot(211, frameon=False)
    latent_expression_ax = fig.add_subplot(212, frameon=False)
    callback_twod(
        model,
        X,
        Y,
        data_expression_ax=data_expression_ax,
        latent_expression_ax=latent_expression_ax,
        X_aligned=G_means,
    )
    plt.tight_layout()

    path = "../../plots/two_d_experiments/two_d_simulation_noutputs={}.png".format(
        n_outputs
    )
    ensure_parent_dir(path)
    plt.savefig(path)
    plt.close(fig)


def save_error_summary(error_mat, error_mat_paste, n_completed, n_outputs_list):
    """Write the GPSA/PASTE errors of the first ``n_completed`` repeats to CSV and plot them.

    Each error matrix has one row per repeat and one column per entry of
    ``n_outputs_list``. The matrices are melted to long format (columns
    ``variable`` = number of outputs, ``value`` = error) and tagged with a
    ``method`` column.
    """
    frames = []
    for method, errors in (("GPSA", error_mat), ("PASTE", error_mat_paste)):
        long_df = pd.melt(
            pd.DataFrame(errors[:n_completed, :], columns=n_outputs_list)
        )
        long_df["method"] = method
        frames.append(long_df)
    error_df = pd.concat(frames, axis=0)

    csv_path = "./out/error_vary_n_outputs.csv"
    ensure_parent_dir(csv_path)
    error_df.to_csv(csv_path)

    plt.rc("font", size=30)
    plt.rcParams["text.usetex"] = True
    plt.figure(figsize=(7, 5))
    sns.lineplot(
        data=error_df, x="variable", y="value", hue="method", err_style="bars"
    )
    plt.xlabel("Number of outputs")
    plt.ylabel("Alignment error")
    plt.tight_layout()

    plot_path = "../../plots/two_d_experiments/error_plot_n_outputs.png"
    ensure_parent_dir(plot_path)
    plt.savefig(plot_path)
    plt.close()


def main():
    """Sweep the number of outputs and compare GPSA with PASTE alignment error.

    For each repeat and each output count, fit a simulation with
    ``two_d_gpsa`` and record two errors: the PASTE error it returns, and the
    GPSA error computed from the aligned coordinates. Fits from the first
    repeat are also plotted.

    Returns:
        ``(error_mat, error_mat_paste)``: arrays of shape
        ``(n_repeats, len(n_outputs_list))``.
    """
    n_outputs_list = [10, 25, 50]
    n_repeats = 3

    error_mat = np.zeros((n_repeats, len(n_outputs_list)))
    error_mat_paste = np.zeros((n_repeats, len(n_outputs_list)))

    for ii in range(n_repeats):
        for jj, n_outputs in enumerate(n_outputs_list):
            X, Y, G_means, model, err_paste = two_d_gpsa(
                n_outputs=n_outputs,
                n_epochs=MAX_EPOCHS,
                plot_intermediate=False,
                warp_kernel_variance=0.5,
                n_latent_gps=N_LATENT_GPS,
            )
            error_mat_paste[ii, jj] = err_paste

            aligned_coords = tensor_to_numpy(G_means["expression"])
            n_samples_per_view = X.shape[0] // n_views
            error_mat[ii, jj] = compute_alignment_error(
                aligned_coords, n_samples_per_view
            )

            # Only the first repeat is plotted; later repeats only feed the
            # error summary.
            if ii == 0:
                plot_alignment(model, X, Y, G_means, n_outputs)

        # Rewrite the summary after every repeat so partial results are on
        # disk even if a later repeat is interrupted.
        save_error_summary(error_mat, error_mat_paste, ii + 1, n_outputs_list)

    return error_mat, error_mat_paste


if __name__ == "__main__":
    main()