Switch to side-by-side view

--- a
+++ b/experiments/expression/visium/visium_alignment.py
@@ -0,0 +1,345 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import pandas as pd
+
+import seaborn as sns
+import sys
+from os.path import join as pjoin
+import scanpy as sc
+import anndata
+import time
+
+# sys.path.append("../../..")
+# sys.path.append("../../../data")
+# from warps import apply_gp_warp
+
from gpsa import VariationalGPSA, matern12_kernel, rbf_kernel
from gpsa.plotting import callback_twod, callback_twod_aligned_only
+
+from sklearn.gaussian_process import GaussianProcessRegressor
+from sklearn.gaussian_process.kernels import WhiteKernel, RBF
+
+## For PASTE
+import scanpy as sc
+import anndata
+import matplotlib.patches as mpatches
+
+
+sys.path.append("../../../../paste")
+from src.paste import PASTE, visualization
+
+from sklearn.neighbors import NearestNeighbors, KNeighborsRegressor
+from sklearn.metrics import r2_score
+
+
def scale_spatial_coords(X, max_val=10.0):
    """Rescale spatial coordinates column-wise into the range [0, max_val].

    Args:
        X: array of shape (n_spots, n_dims) of spatial coordinates.
        max_val: upper bound of the rescaled range (default 10.0).

    Returns:
        Array of the same shape with each column min-max scaled to
        [0, max_val].
    """
    X = X - X.min(0)
    # Guard against zero-range (constant) columns, which would otherwise
    # divide by zero and fill the column with NaN/inf.
    ranges = X.max(0)
    ranges = np.where(ranges == 0, 1.0, ranges)
    X = X / ranges
    return X * max_val
+
+
# Location of the two Visium mouse-brain samples.
DATA_DIR = "../../../data/visium/mouse_brain"
# Max number of genes to model (may be reduced by the R^2 filter below).
N_GENES = 10
# Optional per-slice subsample size; None keeps all spots.
N_SAMPLES = None

n_spatial_dims = 2  # spots live in 2D
n_views = 2  # two slices to align
# Numbers of inducing points: m_G for the data-level GP, m_X_per_view for
# each view's warp GP.
m_G = 200
m_X_per_view = 200

# None => no low-rank latent-GP factorization of the expression outputs.
N_LATENT_GPS = {"expression": None}

N_EPOCHS = 5000
PRINT_EVERY = 50
+
+
def process_data(adata, n_top_genes=2000):
    """Standard Visium QC + normalization pipeline.

    Filters low/high-count spots and rarely expressed genes, normalizes
    total counts, log-transforms, and subsets to highly variable genes.

    Args:
        adata: AnnData as returned by ``sc.read_visium``.
        n_top_genes: number of highly variable genes to keep.

    Returns:
        The processed AnnData, subset to the highly variable genes.
    """
    adata.var_names_make_unique()
    # Mouse mitochondrial genes are prefixed "mt-" (lowercase), human "MT-";
    # match case-insensitively so QC metrics are correct for mouse data
    # (the script later plots the mouse gene "mt-Co1").
    adata.var["mt"] = adata.var_names.str.upper().str.startswith("MT-")
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

    # Drop spots with too few or too many total counts.
    sc.pp.filter_cells(adata, min_counts=5000)
    sc.pp.filter_cells(adata, max_counts=35000)
    # adata = adata[adata.obs["pct_counts_mt"] < 20]
    sc.pp.filter_genes(adata, min_cells=10)

    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(
        adata, flavor="seurat", n_top_genes=n_top_genes, subset=True
    )
    return adata
+
+
data_slice1 = sc.read_visium(pjoin(DATA_DIR, "sample1"))
data_slice1 = process_data(data_slice1, n_top_genes=6000)

# Sanity-check plot of two marker genes on the first slice.
plt.figure(figsize=(10, 5))
plt.subplot(121)
sc.pl.spatial(
    data_slice1, color=["mt-Co1"], spot_size=150, img_key=None, ax=plt.gca(), show=False
)
plt.subplot(122)
sc.pl.spatial(
    data_slice1, color=["Camk2a"], spot_size=150, img_key=None, ax=plt.gca(), show=False
)
plt.savefig("./out/visium_dophin_genes.png")
plt.show()
# Removed leftover `import ipdb; ipdb.set_trace()` debugger breakpoint:
# it unconditionally halted every (non-interactive) run at this point.
+
data_slice2 = sc.read_visium(pjoin(DATA_DIR, "sample2"))
data_slice2 = process_data(data_slice2, n_top_genes=6000)

# Concatenation keeps only the genes present in both slices.
data = data_slice1.concatenate(data_slice2)


# Gene names shared by both slices after the join.
shared_gene_names = data.var.gene_ids.index.values
data_knn = data_slice1[:, shared_gene_names]
X_knn = data_knn.obsm["spatial"]
Y_knn = np.array(data_knn.X.todense())  # [:, :1000]
# Query set == training set here, so for each spot neighbor 0 is the spot
# itself and neighbor 1 is its nearest *other* spot.
nbrs = NearestNeighbors(n_neighbors=2).fit(X_knn)
distances, indices = nbrs.kneighbors(X_knn)

# 1-NN "prediction" of each spot's expression from its nearest neighbor;
# per-gene R^2 scores spatial smoothness of each gene.
preds = Y_knn[indices[:, 1]]
r2_vals = r2_score(Y_knn, preds, multioutput="raw_values")


# gene_idx_to_keep = np.argsort(-r2_vals)[:N_GENES]
# r2_vals_to_keep =
# Keep spatially coherent genes (R^2 > 0.1), ranked by R^2, capped at N_GENES.
gene_idx_to_keep = np.where(r2_vals > 0.1)[0]
N_GENES = min(N_GENES, len(gene_idx_to_keep))
gene_names_to_keep = data_knn.var.gene_ids.index.values[gene_idx_to_keep]
gene_names_to_keep = gene_names_to_keep[np.argsort(-r2_vals[gene_idx_to_keep])]
if N_GENES < len(gene_names_to_keep):
    gene_names_to_keep = gene_names_to_keep[:N_GENES]
data = data[:, gene_names_to_keep]
+
+
+# for idx in gene_idx_to_keep:
+#     print(r2_vals[idx], flush=True)
+#     sc.pl.spatial(data_knn, img_key=None, color=[data_knn.var.gene_ids.index.values[idx]], spot_size=150)
+
+
+# fig = plt.figure(figsize=(7, 7), facecolor="white", constrained_layout=True)
+# ax1 = fig.add_subplot(111, frameon=False)
+# sc.pl.spatial(
+#     adata=data[data.obs["batch"] == "0"],
+#     img_key=None,
+#     color="total_counts",
+#     spot_size=150,
+#     ax=ax1,
+#     show=False,
+#     alpha=0.3,
+# )
+# sc.pl.spatial(
+#     adata=data[data.obs["batch"] == "1"],
+#     img_key=None,
+#     color="total_counts",
+#     spot_size=150,
+#     ax=ax1,
+#     show=False,
+#     alpha=0.3,
+# )
+# plt.show()
+# import ipdb; ipdb.set_trace()
+
+
# Optionally subsample spots from each slice.
# NOTE(review): this subsamples data_slice1/data_slice2 *after* `data` was
# built by concatenation above, so X/Y below (taken from `data`) would NOT
# be subsampled while n_samples_list/view_idx would be — inconsistent if
# N_SAMPLES is ever set. Harmless while N_SAMPLES stays None; verify
# before enabling.
if N_SAMPLES is not None:
    rand_idx = np.random.choice(
        np.arange(data_slice1.shape[0]), size=N_SAMPLES, replace=False
    )
    data_slice1 = data_slice1[rand_idx]
    rand_idx = np.random.choice(
        np.arange(data_slice2.shape[0]), size=N_SAMPLES, replace=False
    )
    data_slice2 = data_slice2[rand_idx]

# all_slices = anndata.concat([data_slice1, data_slice2])
n_samples_list = [data_slice1.shape[0], data_slice2.shape[0]]
# Row ranges of the two views within the stacked X/Y arrays below.
view_idx = [
    np.arange(data_slice1.shape[0]),
    np.arange(data_slice1.shape[0], data_slice1.shape[0] + data_slice2.shape[0]),
]

# Split the concatenated AnnData back into per-view coordinates/expression.
X1 = data[data.obs.batch == "0"].obsm["spatial"]
X2 = data[data.obs.batch == "1"].obsm["spatial"]
Y1 = np.array(data[data.obs.batch == "0"].X.todense())
Y2 = np.array(data[data.obs.batch == "1"].X.todense())

# Rescale each view's spatial coordinates to [0, 10].
X1 = scale_spatial_coords(X1)
X2 = scale_spatial_coords(X2)

# Standardize expression per gene within each view.
Y1 = (Y1 - Y1.mean(0)) / Y1.std(0)
Y2 = (Y2 - Y2.mean(0)) / Y2.std(0)

X = np.concatenate([X1, X2])
Y = np.concatenate([Y1, Y2])

device = "cuda" if torch.cuda.is_available() else "cpu"

n_outputs = Y.shape[1]

# NOTE(review): x/y stay on CPU while the model is moved to `device`; if
# device == "cuda" the forward pass would mix devices — confirm intended.
x = torch.from_numpy(X).float().clone()
y = torch.from_numpy(Y).float().clone()


# Input structure expected by VariationalGPSA: one entry per modality.
data_dict = {
    "expression": {
        "spatial_coords": x,
        "outputs": y,
        "n_samples_list": n_samples_list,
    }
}
+
# GPSA model: view 0 is held fixed (template) and the other view is warped
# onto it ("fixed_view_idx=0"); RBF kernels for both warp and data GPs.
model = VariationalGPSA(
    data_dict,
    n_spatial_dims=n_spatial_dims,
    m_X_per_view=m_X_per_view,
    m_G=m_G,
    data_init=True,
    minmax_init=False,
    grid_init=False,
    n_latent_gps=N_LATENT_GPS,
    mean_function="identity_fixed",
    kernel_func_warp=rbf_kernel,
    kernel_func_data=rbf_kernel,
    # fixed_warp_kernel_variances=np.ones(n_views) * 1.,
    # fixed_warp_kernel_lengthscales=np.ones(n_views) * 10,
    fixed_view_idx=0,
).to(device)

# Per-modality sample indices and counts in the format the model expects.
view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+
+
def train(model, loss_fn, optimizer):
    """Run a single optimization step.

    Uses the module-level `x`, `view_idx`, `Ns`, and `data_dict`.

    Returns:
        Tuple of (scalar loss value, dict of aligned coordinate means).
    """
    model.train()
    optimizer.zero_grad()

    # Forward pass; only the aligned means and output samples are needed.
    G_means, _, _, F_samples = model.forward(
        X_spatial={"expression": x}, view_idx=view_idx, Ns=Ns, S=5
    )

    # Backpropagate the variational loss and update parameters.
    loss = loss_fn(data_dict, F_samples)
    loss.backward()
    optimizer.step()

    return loss.item(), G_means
+
+
# Set up figure: observed data | aligned latent coords | per-spot difference.
fig = plt.figure(figsize=(15, 5), facecolor="white", constrained_layout=True)
data_expression_ax = fig.add_subplot(131, frameon=False)
latent_expression_ax = fig.add_subplot(132, frameon=False)
diff_expression_ax = fig.add_subplot(133, frameon=False)
plt.show(block=False)  # non-blocking so the figure updates during training


# gene_idx = np.where(data.var.gene_ids.index.values == "Ptgds")[0]
# Column of Y visualized in the difference panel.
gene_idx = 0

# Persist inputs so downstream analysis scripts can reload them.
pd.DataFrame(view_idx["expression"]).to_csv("./out/view_idx_visium.csv")
pd.DataFrame(X).to_csv("./out/X_visium.csv")
pd.DataFrame(Y).to_csv("./out/Y_visium.csv")
data.write("./out/data_visium.h5")
+
# Main optimization loop with periodic plotting and cross-view evaluation.
for t in range(N_EPOCHS):
    loss, G_means = train(model, model.loss_fn, optimizer)

    if t % PRINT_EVERY == 0:
        print("Iter: {0:<10} LL {1:1.3e}".format(t, -loss), flush=True)
        diff_expression_ax.cla()

        callback_twod_aligned_only(
            model,
            X,
            Y,
            latent_expression_ax1=data_expression_ax,
            latent_expression_ax2=latent_expression_ax,
            X_aligned=G_means,
            gene_idx=gene_idx,
        )

        # Move to CPU before converting: Tensor.numpy() raises on CUDA
        # tensors; .cpu() is a no-op when already on CPU.
        curr_aligned_coords = G_means["expression"].detach().cpu().numpy()

        # Cross-view 1-NN evaluation: predict each view-2 spot's expression
        # from its nearest view-1 spot in the aligned coordinate system.
        X_knn = curr_aligned_coords[view_idx["expression"][0]]
        Y_knn = Y[view_idx["expression"][0]]
        # The view-2 queries are NOT in the view-1 training set, so there is
        # no self-match to skip: the true nearest neighbor is column 0.
        # (Previously indices[:, 1] was used, copying the self-query idiom
        # from the gene-selection step and taking the second-nearest spot.)
        nbrs = NearestNeighbors(n_neighbors=1).fit(X_knn)
        distances, indices = nbrs.kneighbors(
            curr_aligned_coords[view_idx["expression"][1]]
        )

        Y2_smoothed = Y_knn[indices[:, 0]]
        r2_val = r2_score(Y[view_idx["expression"][1]], Y2_smoothed)
        print(r2_val, flush=True)

        # Per-spot residual between observed view-2 expression and its
        # nearest-neighbor prediction.
        Y_diffs = Y[view_idx["expression"][1]] - Y2_smoothed

        diff_expression_ax.scatter(
            curr_aligned_coords[view_idx["expression"][1]][:, 0],
            curr_aligned_coords[view_idx["expression"][1]][:, 1],
            c=Y_diffs[:, gene_idx].ravel(),
            cmap="bwr",
            s=24,
            marker="H",
        )

        plt.draw()
        plt.savefig("./out/visium_aligned_difference_one_gene.png")
        plt.pause(1 / 60.0)

        pd.DataFrame(curr_aligned_coords).to_csv("./out/aligned_coords_visium.csv")
+
+
plt.close()

import matplotlib

# Publication-style final figure with large fonts.
font = {"size": 30}
matplotlib.rc("font", **font)
# NOTE(review): usetex requires a LaTeX installation at render time — confirm.
matplotlib.rcParams["text.usetex"] = True

fig = plt.figure(figsize=(14, 7))
data_expression_ax = fig.add_subplot(121, frameon=False)
latent_expression_ax = fig.add_subplot(122, frameon=False)
# Observed data next to the GPSA-aligned coordinates from the final
# training iteration (G_means from the last loop pass).
callback_twod(
    model,
    X,
    Y,
    data_expression_ax=data_expression_ax,
    latent_expression_ax=latent_expression_ax,
    X_aligned=G_means,
)
latent_expression_ax.set_title("Aligned data, GPSA")
latent_expression_ax.set_axis_off()
data_expression_ax.set_axis_off()
# plt.axis("off")

plt.tight_layout()
plt.savefig("./out/visium_alignment.png")
# plt.show()
plt.close()