# experiments/simulations/two_dimensional_animation.py
1
import torch
2
import numpy as np
3
import matplotlib.pyplot as plt
4
5
import seaborn as sns
6
import sys
7
from gpsa import VariationalGPSA, matern12_kernel, rbf_kernel
8
from gpsa.plotting import callback_twod
9
10
sys.path.append("../../data")
11
from simulated.generate_twod_data import generate_twod_data
12
13
import matplotlib.animation as animation
14
import matplotlib.image as mpimg
15
import os
16
from os.path import join as pjoin
17
import anndata
18
19
import matplotlib
20
from matplotlib.lines import Line2D
21
22
23
# ---- Global matplotlib styling: 20pt fonts, text rendered via LaTeX. ----
font = dict(size=20)
matplotlib.rc("font", **font)
matplotlib.rcParams.update({"text.usetex": True})

# ---- Compute device: prefer a GPU when torch can see one. ----
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Font size constant (not referenced in the visible part of this file).
LATEX_FONTSIZE = 35

# ---- Problem / model sizes. ----
n_spatial_dims = 2  # spatial coordinates are 2D
n_views = 2  # number of simulated views ("slices")
m_G = 50  # passed to VariationalGPSA as m_G
m_X_per_view = 50  # passed to VariationalGPSA as m_X_per_view

# ---- Training / logging schedule. ----
N_EPOCHS = 2_000
PRINT_EVERY = 100  # log + snapshot a frame every this many iterations
ONE_SAMPLE_FIXED = True  # controls the generator's fixed view and the output GIF name
40
41
42
# Directory holding the intermediate PNG frames that are stitched into the GIF.
_FRAME_DIR = "./tmp"


def _frame_path(idx):
    """Return the path of animation frame number ``idx`` (always a ``.png``)."""
    return pjoin(_FRAME_DIR, "tmp{}.png".format(idx))


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it does not exist yet."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)


def _plot_synthetic_data(X, Y, view_idx, path):
    """Scatter the simulated views (colored by the first output) and save to ``path``."""
    plt.figure(figsize=(7, 5))
    markers = ["o", "X"]
    for vv in range(n_views):
        plt.scatter(
            X[view_idx[vv]][:, 0],
            X[view_idx[vv]][:, 1],
            c=Y[view_idx[vv]][:, 0],
            s=400,
            marker=markers[vv],
            label="View {}".format(vv + 1),
            edgecolor="black",
            linewidth=2,
        )
    plt.xlabel("X1")
    plt.ylabel("X2")
    plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
    _ensure_parent_dir(path)
    plt.savefig(path)
    plt.close()


def _save_anndata(X, Y, n_samples_list, path):
    """Write outputs ``Y`` with coordinates in ``obsm["spatial"]`` and a per-view ``batch`` label."""
    data_obj = anndata.AnnData(Y)
    data_obj.obsm["spatial"] = X
    # One batch id per row: view 0 repeated n_samples_list[0] times, then view 1, ...
    data_obj.obs["batch"] = np.concatenate(
        [[xx] * n_samples_list[xx] for xx in range(n_views)]
    )
    _ensure_parent_dir(path)
    data_obj.write(path)


def _slice_legend_handles():
    """Return black legend markers: "o" for Slice 1 and "X" for Slice 2."""
    return [
        Line2D(
            [0],
            [0],
            marker=marker,
            color="w",
            label="Slice {}".format(ii + 1),
            markerfacecolor="black",
            markersize=20,
        )
        for ii, marker in enumerate(["o", "X"])
    ]


def _frames_to_animation(n_frames, out_path):
    """Stitch frames ``0..n_frames-1`` into an animation at ``out_path``.

    Each frame file is deleted as soon as it has been loaded into memory.
    """
    fig = plt.figure()
    ims = []
    for ii in range(n_frames):
        fname = _frame_path(ii)
        img = mpimg.imread(fname)
        im = plt.imshow(img)
        ax = plt.gca()
        ax.set_yticks([])
        ax.set_xticks([])
        ims.append([im])
        os.remove(fname)

    writervideo = animation.FFMpegWriter(fps=5)
    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=500)
    _ensure_parent_dir(out_path)
    ani.save(out_path, writer=writervideo, dpi=1000)
    plt.close(fig)


def two_d_gpsa(
    n_outputs,
    n_epochs,
    n_latent_gps,
    warp_kernel_variance=0.1,
    noise_variance=0.0,
    plot_intermediate=True,
    fixed_view_idx=None,
):
    """Fit VariationalGPSA on simulated 2D data and save an alignment animation.

    Steps:
      1. Simulate ``n_views`` views with ``generate_twod_data``.
         Save a scatter plot and an ``.h5ad`` copy of the data under ``../../examples``.
      2. Train a ``VariationalGPSA`` model with Adam (lr=1e-2) for ``n_epochs`` steps.
      3. Every ``PRINT_EVERY`` steps (only when ``plot_intermediate``):
         print the negated loss as "LL" and save a PNG snapshot of the current
         alignment to ``./tmp``.
      4. Stitch the snapshots into a GIF in ``./out``.
         The GIF name depends on ``ONE_SAMPLE_FIXED``.

    Args:
        n_outputs: Number of output features simulated per location.
        n_epochs: Number of optimizer steps.
        n_latent_gps: Dict of latent-GP counts passed to the model.
            Its ``"expression"`` entry is also used to simulate the data.
        warp_kernel_variance: Passed to ``generate_twod_data`` as ``kernel_variance``.
        noise_variance: Passed to ``generate_twod_data`` as ``noise_variance``.
        plot_intermediate: If True, log progress and record animation frames.
            If False, no frames are recorded and no GIF is written.
        fixed_view_idx: Passed to ``VariationalGPSA`` as ``fixed_view_idx``.
    """
    X, Y, n_samples_list, view_idx = generate_twod_data(
        n_views,
        n_outputs,
        grid_size=10,
        n_latent_gps=n_latent_gps["expression"],
        kernel_lengthscale=5.0,
        kernel_variance=warp_kernel_variance,
        noise_variance=noise_variance,
        # NOTE(review): the generator follows the module-level ONE_SAMPLE_FIXED
        # flag, not the ``fixed_view_idx`` argument. The two agree for the
        # __main__ configuration; confirm this is intended.
        fixed_view_idx=0 if ONE_SAMPLE_FIXED else None,
    )

    _plot_synthetic_data(X, Y, view_idx, "./../../examples/synthetic_data_example.png")
    _save_anndata(X, Y, n_samples_list, "../../examples/synthetic_data.h5ad")

    # The model is moved to ``device`` below, so its inputs must live there too;
    # otherwise every forward pass / loss evaluation fails on a CUDA machine.
    # NOTE(review): callback_twod may expect CPU tensors for X_aligned; verify
    # on a GPU run.
    x = torch.from_numpy(X).float().clone().to(device)
    y = torch.from_numpy(Y).float().clone().to(device)

    data_dict = {
        "expression": {
            "spatial_coords": x,
            "outputs": y,
            "n_samples_list": n_samples_list,
        }
    }

    model = VariationalGPSA(
        data_dict,
        n_spatial_dims=n_spatial_dims,
        m_X_per_view=m_X_per_view,
        m_G=m_G,
        data_init=True,
        minmax_init=False,
        grid_init=False,
        n_latent_gps=n_latent_gps,
        mean_function="identity_fixed",
        kernel_func_warp=rbf_kernel,
        kernel_func_data=rbf_kernel,
        fixed_view_idx=fixed_view_idx,
    ).to(device)

    view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    def train(model, loss_fn, optimizer):
        """Run one optimizer step on the full data set and return the scalar loss."""
        model.train()
        _, _, _, F_samples = model.forward(
            {"expression": x}, view_idx=view_idx, Ns=Ns, S=5
        )
        loss = loss_fn(data_dict, F_samples)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    # Figure that every animation frame is drawn into.
    fig = plt.figure(figsize=(12.14, 5), facecolor="white", constrained_layout=True)
    data_expression_ax = fig.add_subplot(121, frameon=False)
    latent_expression_ax = fig.add_subplot(122, frameon=False)
    plt.show(block=False)

    os.makedirs(_FRAME_DIR, exist_ok=True)
    n_frames = 0
    for t in range(n_epochs):
        loss = train(model, model.loss_fn, optimizer)

        if not (plot_intermediate and t % PRINT_EVERY == 0):
            continue

        print("Iter: {0:<10} LL {1:1.3e}".format(t, -loss))
        # Snapshot for plotting only -- no gradient graph is needed.
        with torch.no_grad():
            G_means, _, _, _ = model.forward(
                {"expression": x}, view_idx=view_idx, Ns=Ns
            )

        callback_twod(
            model,
            X,
            Y,
            data_expression_ax=data_expression_ax,
            latent_expression_ax=latent_expression_ax,
            X_aligned=G_means,
            s=600,
        )
        plt.legend(
            handles=_slice_legend_handles(), loc="center left", bbox_to_anchor=(1, 0.5)
        )
        plt.tight_layout()
        # Explicit ".png" so the file name matches what _frames_to_animation reads,
        # independent of rcParams["savefig.format"].
        plt.savefig(_frame_path(n_frames))
        n_frames += 1

    print("Done!")
    plt.close(fig)

    if n_frames == 0:
        # Nothing recorded (plot_intermediate=False or n_epochs == 0); an empty
        # ArtistAnimation cannot be saved.
        print("No frames recorded; skipping animation.")
        return

    if ONE_SAMPLE_FIXED:
        save_name = "alignment_animation_template.gif"
    else:
        save_name = "alignment_animation.gif"
    _frames_to_animation(n_frames, pjoin("out", save_name))
231
232
233
if __name__ == "__main__":
    # Two simulated views with 30 outputs each; view 0 is held fixed by the model.
    # (The stray ``ipdb.set_trace()`` breakpoint that used to follow this call
    # has been removed: it halted every run and required the third-party ipdb.)
    n_outputs = 30
    two_d_gpsa(
        n_epochs=N_EPOCHS,
        n_outputs=n_outputs,
        warp_kernel_variance=0.5,
        noise_variance=0.001,
        n_latent_gps={"expression": 5},
        fixed_view_idx=0,
    )