Diff of /examples/grid_example.py [000000] .. [5c09f6]

Switch to unified view

a b/examples/grid_example.py
1
import torch
2
import numpy as np
3
import matplotlib.pyplot as plt
4
import seaborn as sns
5
import anndata
6
7
from gpsa import VariationalGPSA
8
from gpsa import matern12_kernel, rbf_kernel
9
from gpsa.plotting import callback_twod
10
11
device = "cuda" if torch.cuda.is_available() else "cpu"
12
13
N_SPATIAL_DIMS = 2
14
N_VIEWS = 2
15
M_G = 25
16
M_X_PER_VIEW = 25
17
N_OUTPUTS = 5
18
FIXED_VIEW_IDX = 0
19
N_LATENT_GPS = {"expression": None}
20
21
N_EPOCHS = 3000
22
PRINT_EVERY = 100
23
24
25
data = anndata.read_h5ad("./synthetic_data.h5ad")
26
X = data.obsm["spatial"]
27
Y = data.X
28
view_idx = [np.where(data.obs.batch.values == ii)[0] for ii in range(2)]
29
n_samples_list = [len(x) for x in view_idx]
30
31
x = torch.from_numpy(X).float().clone().to(device)
32
y = torch.from_numpy(Y).float().clone().to(device)
33
34
data_dict = {
35
    "expression": {
36
        "spatial_coords": x,
37
        "outputs": y,
38
        "n_samples_list": n_samples_list,
39
    }
40
}
41
42
model = VariationalGPSA(
43
    data_dict,
44
    n_spatial_dims=N_SPATIAL_DIMS,
45
    m_X_per_view=M_X_PER_VIEW,
46
    m_G=M_G,
47
    data_init=True,
48
    minmax_init=False,
49
    grid_init=False,
50
    n_latent_gps=N_LATENT_GPS,
51
    mean_function="identity_fixed",
52
    kernel_func_warp=rbf_kernel,
53
    kernel_func_data=rbf_kernel,
54
    fixed_view_idx=FIXED_VIEW_IDX,
55
).to(device)
56
57
view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)
58
59
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
60
61
62
def train(model, loss_fn, optimizer):
63
    model.train()
64
65
    # Forward pass
66
    G_means, G_samples, F_latent_samples, F_samples = model.forward(
67
        {"expression": x}, view_idx=view_idx, Ns=Ns, S=5
68
    )
69
70
    # Compute loss
71
    loss = loss_fn(data_dict, F_samples)
72
73
    # Compute gradients and take optimizer step
74
    optimizer.zero_grad()
75
    loss.backward()
76
    optimizer.step()
77
78
    return loss.item()
79
80
81
# Set up figure.
82
fig = plt.figure(figsize=(14, 7), facecolor="white", constrained_layout=True)
83
data_expression_ax = fig.add_subplot(121, frameon=False)
84
latent_expression_ax = fig.add_subplot(122, frameon=False)
85
plt.show(block=False)
86
87
for t in range(N_EPOCHS):
88
    loss = train(model, model.loss_fn, optimizer)
89
90
    if t % PRINT_EVERY == 0:
91
        print("Iter: {0:<10} LL {1:1.3e}".format(t, -loss))
92
        G_means, _, _, _ = model.forward({"expression": x}, view_idx=view_idx, Ns=Ns)
93
94
        callback_twod(
95
            model,
96
            X,
97
            Y,
98
            data_expression_ax=data_expression_ax,
99
            latent_expression_ax=latent_expression_ax,
100
            X_aligned=G_means,
101
            s=600,
102
        )
103
        plt.draw()
104
        plt.pause(1 / 60.0)
105
106
print("Done!")
107
108
plt.close()