--- a +++ b/experiments/simulations/warp_parameter_demo.py @@ -0,0 +1,90 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt +from sklearn.gaussian_process.kernels import RBF +from scipy.stats import multivariate_normal as mvn + +import seaborn as sns +import sys + +sys.path.append("../../data") +from simulated.generate_twod_data import generate_twod_data +import os +from os.path import join as pjoin +import anndata + +import matplotlib + +font = {"size": 25} +matplotlib.rc("font", **font) +matplotlib.rcParams["text.usetex"] = True + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +LATEX_FONTSIZE = 35 + +n_spatial_dims = 2 +n_views = 2 +markers = ["o", "X"] + +lengthscale_list = [10**x for x in [-1, 0, 1]] +amplitude_list = [0.1, 1.0, 5.0] + +xlimits = [0, 10] +ylimits = [0, 10] +grid_size = 10 +x1s = np.linspace(*xlimits, num=grid_size) +x2s = np.linspace(*ylimits, num=grid_size) +X1, X2 = np.meshgrid(x1s, x2s) +X = np.vstack([X1.ravel(), X2.ravel()]).T +n = len(X) + +## Lengthscales +plt.figure(figsize=(17, 15)) +for ii, lengthscale in enumerate(lengthscale_list): + for jj, amplitude in enumerate(amplitude_list): + kernel = RBF(length_scale=lengthscale) + K = amplitude * kernel(X) + 1e-8 * np.eye(n) + X_warped = np.zeros((n, n_spatial_dims)) + for dd in range(n_spatial_dims): + X_warped[:, dd] = mvn.rvs(mean=X[:, dd], cov=K) + + plt.subplot( + len(amplitude_list), + len(lengthscale_list), + ii * len(lengthscale_list) + jj + 1, + ) + plt.scatter( + X[:, 0], + X[:, 1], + color="gray", + marker=markers[0], + label="Original" if (jj == len(amplitude_list) - 1) and (ii == 0) else None, + s=50, + ) + plt.scatter( + X_warped[:, 0], + X_warped[:, 1], + color="red", + marker=markers[1], + label="Warped" if (jj == len(amplitude_list) - 1) and (ii == 0) else None, + s=50, + ) + plt.title( + r"$\ell^2$ = " + + str(round(lengthscale, 2)) + + ", $\sigma^2$ = " + + str(round(amplitude, 2)) + ) + + if (jj == len(amplitude_list) - 1) and (ii == 0): + plt.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + plt.xticks([]) + plt.yticks([]) +plt.tight_layout() +plt.savefig("./out/warp_parameter_demo.png") +plt.show() +import ipdb + +ipdb.set_trace()