Diff of /examples/grid_example.py [000000] .. [5c09f6]

--- /dev/null
+++ b/examples/grid_example.py
@@ -0,0 +1,125 @@
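+# Example: align two views of synthetic spatial expression data into a shared
+# coordinate system with GPSA (Gaussian process spatial alignment).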
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import anndata
+
+from gpsa import VariationalGPSA, rbf_kernel
+from gpsa.plotting import callback_twod
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+N_SPATIAL_DIMS = 2  # 2-D (x, y) coordinates
+N_VIEWS = 2  # number of views (e.g., tissue slices) to align
+M_G = 25  # inducing points for the shared coordinate system G
+M_X_PER_VIEW = 25  # inducing points per view in observed coordinates
+N_OUTPUTS = 5  # number of output features (e.g., genes)
+FIXED_VIEW_IDX = 0  # view held fixed as the alignment template
+N_LATENT_GPS = {"expression": None}  # None: no low-rank latent-GP factorization
+
+N_EPOCHS = 3000
+PRINT_EVERY = 100
+
+
+# Load the synthetic dataset: two views of the same tissue stored in a single
+# AnnData object, with per-spot view labels in obs["batch"].
+data = anndata.read_h5ad("./synthetic_data.h5ad")
+X = data.obsm["spatial"]  # (n_spots, 2) spatial coordinates
+Y = data.X  # (n_spots, N_OUTPUTS) expression values
+view_idx = [np.where(data.obs.batch.values == ii)[0] for ii in range(N_VIEWS)]
+n_samples_list = [len(x) for x in view_idx]
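+# Sketch of the expected schema (hypothetical variable names; the actual
+# synthetic_data.h5ad may differ in detail):
+#
+#   adata = anndata.AnnData(X=expression)  # (n_spots, N_OUTPUTS)
+#   adata.obsm["spatial"] = coords  # (n_spots, 2)
+#   adata.obs["batch"] = np.repeat([0, 1], [n_view0, n_view1])
+#   adata.write_h5ad("./synthetic_data.h5ad")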
+
+x = torch.from_numpy(X).float().clone().to(device)
+y = torch.from_numpy(Y).float().clone().to(device)
+
+data_dict = {
+    "expression": {
+        "spatial_coords": x,
+        "outputs": y,
+        "n_samples_list": n_samples_list,
+    }
+}
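+# GPSA supports multiple data modalities; this example uses a single
+# "expression" modality whose rows span both views.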
+
+model = VariationalGPSA(
+    data_dict,
+    n_spatial_dims=N_SPATIAL_DIMS,
+    m_X_per_view=M_X_PER_VIEW,
+    m_G=M_G,
+    data_init=True,  # initialize inducing locations from the observed data
+    minmax_init=False,
+    grid_init=False,
+    n_latent_gps=N_LATENT_GPS,
+    mean_function="identity_fixed",  # identity prior mean on the warp
+    kernel_func_warp=rbf_kernel,
+    kernel_func_data=rbf_kernel,
+    fixed_view_idx=FIXED_VIEW_IDX,  # hold view 0 fixed; warp the other onto it
+).to(device)
+
+view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)
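+# Per-view row indices (view_idx) and per-view sample counts (Ns), in the
+# form expected by model.forward below.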
+
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+
+
+def train(model, loss_fn, optimizer):
+    model.train()
+
+    # Forward pass, drawing S=5 Monte Carlo samples for the stochastic
+    # loss estimate.
+    G_means, G_samples, F_latent_samples, F_samples = model.forward(
+        {"expression": x}, view_idx=view_idx, Ns=Ns, S=5
+    )
+
+    # Compute loss on the sampled outputs
+    loss = loss_fn(data_dict, F_samples)
+
+    # Compute gradients and take an optimizer step
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+    return loss.item()
+
+
+# Set up a live figure: observed data on the left, aligned coordinates on the
+# right. block=False keeps training running while the panels update.
+fig = plt.figure(figsize=(14, 7), facecolor="white", constrained_layout=True)
+data_expression_ax = fig.add_subplot(121, frameon=False)
+latent_expression_ax = fig.add_subplot(122, frameon=False)
+plt.show(block=False)
+
+for t in range(N_EPOCHS):
+    loss = train(model, model.loss_fn, optimizer)
+
+    if t % PRINT_EVERY == 0:
+        print("Iter: {0:<10} LL {1:1.3e}".format(t, -loss))
+
+        # Recompute the aligned coordinates G for plotting.
+        G_means, _, _, _ = model.forward({"expression": x}, view_idx=view_idx, Ns=Ns)
+
+        callback_twod(
+            model,
+            X,
+            Y,
+            data_expression_ax=data_expression_ax,
+            latent_expression_ax=latent_expression_ax,
+            X_aligned=G_means,
+            s=600,
+        )
+        plt.draw()
+        plt.pause(1 / 60.0)
+
+print("Done!")
+
+plt.close()