Switch to side-by-side view

--- a
+++ b/experiments/simulations/warp_parameter_demo.py
@@ -0,0 +1,90 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.gaussian_process.kernels import RBF
+from scipy.stats import multivariate_normal as mvn
+
+import seaborn as sns
+import sys
+
+sys.path.append("../../data")
+from simulated.generate_twod_data import generate_twod_data
+import os
+from os.path import join as pjoin
+import anndata
+
+import matplotlib
+
+font = {"size": 25}
+matplotlib.rc("font", **font)
+matplotlib.rcParams["text.usetex"] = True
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+LATEX_FONTSIZE = 35
+
+n_spatial_dims = 2
+n_views = 2
+markers = ["o", "X"]
+
+lengthscale_list = [10**x for x in [-1, 0, 1]]
+amplitude_list = [0.1, 1.0, 5.0]
+
+xlimits = [0, 10]
+ylimits = [0, 10]
+grid_size = 10
+x1s = np.linspace(*xlimits, num=grid_size)
+x2s = np.linspace(*ylimits, num=grid_size)
+X1, X2 = np.meshgrid(x1s, x2s)
+X = np.vstack([X1.ravel(), X2.ravel()]).T
+n = len(X)
+
+## Lengthscales
+plt.figure(figsize=(17, 15))
+for ii, lengthscale in enumerate(lengthscale_list):
+    for jj, amplitude in enumerate(amplitude_list):
+        kernel = RBF(length_scale=lengthscale)
+        K = amplitude * kernel(X) + 1e-8 * np.eye(n)
+        X_warped = np.zeros((n, n_spatial_dims))
+        for dd in range(n_spatial_dims):
+            X_warped[:, dd] = mvn.rvs(mean=X[:, dd], cov=K)
+
+        plt.subplot(
+            len(amplitude_list),
+            len(lengthscale_list),
+            ii * len(lengthscale_list) + jj + 1,
+        )
+        plt.scatter(
+            X[:, 0],
+            X[:, 1],
+            color="gray",
+            marker=markers[0],
+            label="Original" if (jj == len(amplitude_list) - 1) and (ii == 0) else None,
+            s=50,
+        )
+        plt.scatter(
+            X_warped[:, 0],
+            X_warped[:, 1],
+            color="red",
+            marker=markers[1],
+            label="Warped" if (jj == len(amplitude_list) - 1) and (ii == 0) else None,
+            s=50,
+        )
+        plt.title(
+            r"$\ell^2$ = "
+            + str(round(lengthscale, 2))
+            + ", $\sigma^2$ = "
+            + str(round(amplitude, 2))
+        )
+
+        if (jj == len(amplitude_list) - 1) and (ii == 0):
+            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+        plt.xticks([])
+        plt.yticks([])
+plt.tight_layout()
+plt.savefig("./out/warp_parameter_demo.png")
+plt.show()
+import ipdb
+
+ipdb.set_trace()