a b/gpsa/plotting/callbacks.py
1
import torch
2
import numpy as np
3
import matplotlib.pyplot as plt
4
from torch.utils.data import Dataset, DataLoader
5
import time
6
import pandas as pd
7
from scipy.stats import pearsonr
8
9
from matplotlib.lines import Line2D
10
11
import seaborn as sns
12
13
# Marker size (matplotlib ``s``) shared by the 1-D callback's scatter calls.
SCATTER_POINT_SIZE = 50

# Torch device name: use the GPU whenever PyTorch reports one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
16
17
def callback_oned(
    model,
    X,
    Y,
    X_aligned,
    data_expression_ax,
    latent_expression_ax,
    prediction_ax=None,
    X_test=None,
    Y_pred=None,
    Y_test_true=None,
    X_test_aligned=None,
    F_samples=None,
):
    """Redraw observed vs. aligned data for a model with 1-D spatial coordinates.

    Clears ``data_expression_ax`` and ``latent_expression_ax`` and, for every
    view, scatters the outcome against the observed coordinate (left axes) and
    against the aligned coordinate (right axes). Output column 0 is drawn in
    blue and, when present, column 1 in orange. When ``prediction_ax`` is
    given, test predictions are overlaid on the aligned axes and plotted
    against the true test outcomes on ``prediction_ax``. Finishes with
    ``plt.draw()`` and a short ``plt.pause`` so interactive figures refresh.

    Args:
        model: Alignment model; ``fixed_view_idx``, ``view_idx["expression"]``
            and ``n_views`` are read and ``model.eval()`` is called.
        X: Observed coordinates, indexed as ``X[rows, 0]`` (NumPy array).
        Y: Observed outcomes, indexed as ``Y[rows, column]``.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates. NOTE: when ``model.fixed_view_idx`` is set, the
            fixed view's rows are overwritten IN PLACE with their observed ``X``.
        data_expression_ax: Axes for the observed data.
        latent_expression_ax: Axes for the aligned data.
        prediction_ax: Optional axes for the true-vs-predicted scatter.
        X_test: Unused; accepted for interface compatibility.
        Y_pred: Tensor of predicted test outcomes; required with ``prediction_ax``.
        Y_test_true: True test outcomes indexed ``[rows, column]``; required
            with ``prediction_ax``.
        X_test_aligned: Dict with an ``"expression"`` tensor of aligned test
            coordinates; required with ``prediction_ax``.
        F_samples: Optional tensor of function samples, drawn in red (column 0)
            and green (column 1) on the aligned axes.
    """
    model.eval()
    markers = list(Line2D.markers.keys())
    view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None:
        # Pin the fixed view to its observed coordinates. Build the replacement
        # on the target's device/dtype so the index assignment also works when
        # the aligned coordinates live on a GPU or use a non-float32 dtype.
        target = X_aligned["expression"]
        curr_idx = view_idx[model.fixed_view_idx]
        target[curr_idx] = torch.as_tensor(
            X[curr_idx], dtype=target.dtype, device=target.device
        )

    for ax in (data_expression_ax, latent_expression_ax):
        ax.cla()
        ax.set_xlabel("Spatial coordinate")
        ax.set_ylabel("Outcome")
        ax.set_xlim([X.min(), X.max()])
    data_expression_ax.set_title("Observed data")
    latent_expression_ax.set_title("Aligned data")

    # Convert tensors to NumPy once rather than once per scatter call.
    aligned = X_aligned["expression"].detach().cpu().numpy()
    f_samples = None if F_samples is None else F_samples.detach().cpu().numpy()
    has_second_output = Y.shape[1] > 1

    for vv in range(model.n_views):
        idx = view_idx[vv]
        label = "View {}".format(vv + 1)
        marker = markers[vv]

        data_expression_ax.scatter(
            X[idx, 0], Y[idx, 0], label=label, marker=marker,
            s=SCATTER_POINT_SIZE, c="blue",
        )
        if has_second_output:
            data_expression_ax.scatter(
                X[idx, 0], Y[idx, 1], label=label, marker=marker,
                s=SCATTER_POINT_SIZE, c="orange",
            )

        latent_expression_ax.scatter(
            aligned[idx, 0], Y[idx, 0], c="blue", label=label, marker=marker,
            s=SCATTER_POINT_SIZE,
        )
        if has_second_output:
            latent_expression_ax.scatter(
                aligned[idx, 0], Y[idx, 1], c="orange", label=label,
                marker=marker, s=SCATTER_POINT_SIZE,
            )

        if f_samples is not None:
            latent_expression_ax.scatter(
                aligned[idx, 0], f_samples[idx, 0], c="red", marker=marker,
                s=SCATTER_POINT_SIZE,
            )
            if has_second_output:
                latent_expression_ax.scatter(
                    aligned[idx, 0], f_samples[idx, 1], c="green",
                    marker=marker, s=SCATTER_POINT_SIZE,
                )

    if prediction_ax is not None:
        prediction_ax.cla()
        prediction_ax.set_title("Predictions")
        prediction_ax.set_xlabel("True outcome")
        prediction_ax.set_ylabel("Predicted outcome")

        pred = Y_pred.detach().cpu().numpy()
        test_aligned = X_test_aligned["expression"].detach().cpu().numpy()
        # Only touch column 1 when it exists; indexing it unconditionally
        # raised IndexError for single-output data.
        pred_has_second_output = pred.shape[1] > 1

        latent_expression_ax.scatter(
            test_aligned[:, 0], pred[:, 0], c="blue", label="Prediction",
            marker="^", s=SCATTER_POINT_SIZE,
        )
        if pred_has_second_output:
            latent_expression_ax.scatter(
                test_aligned[:, 0], pred[:, 1], c="orange", label="Prediction",
                marker="^", s=SCATTER_POINT_SIZE,
            )

        prediction_ax.scatter(
            Y_test_true[:, 0], pred[:, 0], c="black", s=SCATTER_POINT_SIZE,
        )
        if pred_has_second_output:
            prediction_ax.scatter(
                Y_test_true[:, 1], pred[:, 1], c="black", s=SCATTER_POINT_SIZE,
                marker="^",
            )

    data_expression_ax.legend()
    plt.draw()
    plt.pause(1 / 60.0)
177
178
179
def callback_twod(
    model,
    X,
    Y,
    X_aligned,
    data_expression_ax,
    latent_expression_ax,
    is_mle=False,
    gene_idx=0,
    s=200,
    include_legend=False,
):
    """Redraw observed and aligned 2-D coordinates, coloured by one output column.

    Both axes are cleared and filled with a seaborn scatterplot. Colour (hue)
    shows ``Y[:, gene_idx]`` and marker style distinguishes views. After the
    call, ``latent_expression_ax`` is the pyplot current axes (``plt.sca``).

    Args:
        model: Alignment model; ``fixed_view_idx``, ``view_idx["expression"]``
            and ``n_views`` are read and ``model.eval()`` is called.
        X: Observed coordinates, indexed ``X[rows]`` with at least 2 columns.
        Y: Observed outputs, indexed ``Y[rows, gene_idx]``.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates. NOTE: unless ``is_mle`` is True, the fixed
            view's rows are overwritten IN PLACE with their observed ``X``.
        data_expression_ax: Axes for the observed coordinates.
        latent_expression_ax: Axes for the aligned coordinates.
        is_mle: When True, skip pinning the fixed view.
        gene_idx: Column of ``Y`` used for the colour.
        s: Marker size passed to seaborn.
        include_legend: Keep seaborn's legend when True; remove it otherwise.
    """
    if model.fixed_view_idx is not None and not is_mle:
        # Build the replacement on the target's own device/dtype. The
        # module-level ``device`` can disagree with where X_aligned lives.
        target = X_aligned["expression"]
        curr_idx = model.view_idx["expression"][model.fixed_view_idx]
        target[curr_idx] = torch.as_tensor(
            X[curr_idx], dtype=target.dtype, device=target.device
        )

    model.eval()

    data_expression_ax.cla()
    latent_expression_ax.cla()
    data_expression_ax.set_title("Observed data")
    latent_expression_ax.set_title("Aligned data")

    curr_view_idx = model.view_idx["expression"]
    # Convert once instead of once per view.
    aligned = X_aligned["expression"].detach().cpu().numpy()

    Xs, latent_Xs, Ys, viewname_list = [], [], [], []
    for vv in range(model.n_views):
        idx = curr_view_idx[vv]
        curr_latent_Xs = aligned[idx]
        Xs.append(X[idx])
        latent_Xs.append(curr_latent_Xs)
        Ys.append(Y[idx, gene_idx])
        viewname_list.append(
            ["Observation {}".format(vv + 1)] * curr_latent_Xs.shape[0]
        )

    Xs = np.concatenate(Xs, axis=0)
    latent_Xs = np.concatenate(latent_Xs, axis=0)
    Ys = np.concatenate(Ys)
    viewname_list = np.concatenate(viewname_list)

    # Observed axes first, aligned axes second. The original order is kept so
    # the latent axes remain the pyplot "current" axes afterwards.
    for ax, coords in (
        (data_expression_ax, Xs),
        (latent_expression_ax, latent_Xs),
    ):
        df = pd.DataFrame(
            {
                "X1": coords[:, 0],
                "X2": coords[:, 1],
                "Y": Ys,
                "view": viewname_list,
            }
        )
        plt.sca(ax)
        g = sns.scatterplot(
            data=df,
            x="X1",
            y="X2",
            hue="Y",
            style="view",
            ax=ax,
            s=s,
            linewidth=1.8,
            edgecolor="black",
            palette="viridis",
        )
        # seaborn may not build a legend; guard before removing it.
        if not include_legend and g.legend_ is not None:
            g.legend_.remove()
319
320
321
def callback_twod_aligned_only(
    model,
    X,
    Y,
    X_aligned,
    latent_expression_ax1,
    latent_expression_ax2,
    is_mle=False,
    gene_idx=0,
):
    """Scatter the aligned 2-D coordinates of the first two views on two axes.

    View 1's aligned coordinates go to ``latent_expression_ax1`` and view 2's
    to ``latent_expression_ax2``. Both are coloured by ``Y[:, gene_idx]``.

    Args:
        model: Alignment model; ``fixed_view_idx`` and ``view_idx["expression"]``
            are read (at least two views expected) and ``model.eval()`` is called.
        X: Observed coordinates (NumPy array), used only to pin the fixed view.
        Y: Observed outputs, indexed ``Y[rows][:, gene_idx]``.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates. NOTE: unless ``is_mle`` is True, the fixed
            view's rows are overwritten IN PLACE with their observed ``X``.
        latent_expression_ax1: Axes for view 1.
        latent_expression_ax2: Axes for view 2.
        is_mle: When True, skip pinning the fixed view.
        gene_idx: Column of ``Y`` used for the colour.
    """
    if model.fixed_view_idx is not None and not is_mle:
        # Match the target tensor's device/dtype so the index assignment does
        # not fail when the aligned coordinates live on a GPU.
        target = X_aligned["expression"]
        curr_idx = model.view_idx["expression"][model.fixed_view_idx]
        target[curr_idx] = torch.as_tensor(
            X[curr_idx], dtype=target.dtype, device=target.device
        )

    model.eval()

    latent_expression_ax1.cla()
    latent_expression_ax2.cla()
    # NOTE(review): both axes show *aligned* coordinates. The first title is
    # only literally accurate when view 1 is the fixed (pinned) view.
    latent_expression_ax1.set_title("Observed data")
    latent_expression_ax2.set_title("Aligned data")

    curr_view_idx = model.view_idx["expression"]
    aligned_coords = X_aligned["expression"].detach().cpu().numpy()

    # The previous per-view bookkeeping loop is gone. Its results were never
    # used, and it raised IndexError for more than three views.
    for vv, ax in enumerate((latent_expression_ax1, latent_expression_ax2)):
        idx = curr_view_idx[vv]
        ax.scatter(
            aligned_coords[idx][:, 0],
            aligned_coords[idx][:, 1],
            c=Y[idx][:, gene_idx].squeeze(),
            s=24,
            marker="h",
        )

    # NOTE(review): plt.axis acts on the pyplot *current* axes, which is not
    # necessarily either axes drawn above; kept unchanged for compatibility.
    plt.axis("off")
390
391
392
def callback_twod_multimodal(
    model, data_dict, X_aligned, axes, rgb=False, scatterpoint_size=100
):
    """Plot observed and aligned 2-D coordinates for expression and histology.

    Layout of ``axes``: [0] observed expression, [1] aligned expression,
    [2] observed histology, [3] aligned histology. Each view gets its own
    marker, cycling through ``.``, ``+`` and ``^``. Unlike the other 2-D
    callbacks, this one does not pin the fixed view to its observed
    coordinates.

    Args:
        model: Alignment model; ``view_idx[mod]`` and ``n_views`` are read and
            ``model.eval()`` is called.
        data_dict: ``data_dict[mod]`` provides ``"spatial_coords"`` (indexed
            ``[rows, 0/1]``) and ``"outputs"`` (indexed ``[rows, column]``)
            for ``mod`` in ``("expression", "histology")``.
        X_aligned: Dict of torch tensors of aligned coordinates keyed by modality.
        axes: Sequence of at least four matplotlib axes (layout above).
        rgb: When True, histology points are coloured by all output columns
            (e.g. RGB) instead of column 0 only.
        scatterpoint_size: Marker size passed to ``scatter``.
    """
    model.eval()
    markers = [".", "+", "^"]

    for ax in axes:
        ax.cla()

    axes[0].set_title("Observed expression")
    axes[1].set_title("Aligned expression")
    axes[2].set_title("Observed histology")
    axes[3].set_title("Aligned histology")

    axes_per_mod = 2  # observed + aligned panel per modality
    for mod_pos, mod in enumerate(("expression", "histology")):
        observed_ax = axes[axes_per_mod * mod_pos]
        aligned_ax = axes[axes_per_mod * mod_pos + 1]

        # Loop-invariant per modality: fetch and convert once.
        curr_view_idx = model.view_idx[mod]
        curr_coords = data_dict[mod]["spatial_coords"]
        all_outputs = data_dict[mod]["outputs"]
        aligned = X_aligned[mod].detach().cpu().numpy()

        for vv in range(model.n_views):
            idx = curr_view_idx[vv]
            if mod == "histology" and rgb:
                curr_outputs = all_outputs[idx, :]
            else:
                curr_outputs = all_outputs[idx, 0]
            label = "View {}".format(vv + 1)
            # Cycle markers so more than three views no longer raise IndexError.
            marker = markers[vv % len(markers)]

            observed_ax.scatter(
                curr_coords[idx, 0],
                curr_coords[idx, 1],
                c=curr_outputs,
                label=label,
                marker=marker,
                s=scatterpoint_size,
            )
            aligned_ax.scatter(
                aligned[idx, 0],
                aligned[idx, 1],
                c=curr_outputs,
                label=label,
                marker=marker,
                s=scatterpoint_size,
            )