[54ded2]: / experiments / simulations / warp_parameter_demo.py

Download this file

91 lines (76 with data), 2.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process.kernels import RBF
from scipy.stats import multivariate_normal as mvn
import seaborn as sns
import sys
sys.path.append("../../data")
from simulated.generate_twod_data import generate_twod_data
import os
from os.path import join as pjoin
import anndata
import matplotlib
# Global matplotlib styling: a 25pt base font, with all text rendered through
# LaTeX (requires a working LaTeX installation on the machine).
font = dict(size=25)
matplotlib.rc("font", size=font["size"])
matplotlib.rcParams.update({"text.usetex": True})
# Use the GPU when torch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

LATEX_FONTSIZE = 35

# Problem dimensions and the scatter marker used for each point set.
n_spatial_dims = 2
n_views = 2
markers = ["o", "X"]

# Kernel hyperparameters swept by the demo: lengthscales 10^-1, 10^0, 10^1
# and three amplitudes.
lengthscale_list = [10**exponent for exponent in (-1, 0, 1)]
amplitude_list = [0.1, 1.0, 5.0]

# Regular grid_size x grid_size lattice over [0, 10] x [0, 10].
xlimits = [0, 10]
ylimits = [0, 10]
grid_size = 10
x1s = np.linspace(xlimits[0], xlimits[1], num=grid_size)
x2s = np.linspace(ylimits[0], ylimits[1], num=grid_size)
X1, X2 = np.meshgrid(x1s, x2s)

# Flatten the lattice into an (n, 2) array holding one (x1, x2) point per row.
X = np.vstack((X1.reshape(-1), X2.reshape(-1))).T
n = X.shape[0]
## Lengthscales
# For every (lengthscale, amplitude) pair, draw one random warp of the grid X
# and plot it over the original grid. Panels are laid out with one row per
# entry of lengthscale_list and one column per entry of amplitude_list.
# The draws are unseeded, so the figure differs from run to run.
n_rows = len(lengthscale_list)
n_cols = len(amplitude_list)

# Small diagonal term added to every covariance matrix (presumably to keep it
# numerically positive definite for sampling).
jitter = 1e-8 * np.eye(n)

plt.figure(figsize=(17, 15))
for ii, lengthscale in enumerate(lengthscale_list):
    # The unit-amplitude kernel matrix depends only on the lengthscale, so it
    # is evaluated once per row instead of once per panel.
    # NOTE(review): the value is passed as RBF's `length_scale`, while the
    # panel title labels it as \ell^2 -- confirm which one is intended.
    unit_K = RBF(length_scale=lengthscale)(X)

    for jj, amplitude in enumerate(amplitude_list):
        K = amplitude * unit_K + jitter

        # Each spatial coordinate is sampled independently from a
        # multivariate normal centred on the original coordinate values.
        X_warped = np.zeros((n, n_spatial_dims))
        for dd in range(n_spatial_dims):
            X_warped[:, dd] = mvn.rvs(mean=X[:, dd], cov=K)

        # Row-major panel index. The grid shape must match the loop nesting
        # (rows = lengthscales, columns = amplitudes) or the index breaks as
        # soon as the two lists differ in length.
        plt.subplot(n_rows, n_cols, ii * n_cols + jj + 1)

        # Only the top-right panel carries the labels and the legend, so the
        # legend appears once for the whole figure.
        is_legend_panel = (ii == 0) and (jj == n_cols - 1)

        plt.scatter(
            X[:, 0],
            X[:, 1],
            color="gray",
            marker=markers[0],
            label="Original" if is_legend_panel else None,
            s=50,
        )
        plt.scatter(
            X_warped[:, 0],
            X_warped[:, 1],
            color="red",
            marker=markers[1],
            label="Warped" if is_legend_panel else None,
            s=50,
        )

        # Raw f-string so the LaTeX backslashes are not treated as Python
        # escape sequences.
        plt.title(
            rf"$\ell^2$ = {round(lengthscale, 2)}, $\sigma^2$ = {round(amplitude, 2)}"
        )

        if is_legend_panel:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))

        plt.xticks([])
        plt.yticks([])

plt.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
out_dir = "./out"
os.makedirs(out_dir, exist_ok=True)
plt.savefig(pjoin(out_dir, "warp_parameter_demo.png"))
plt.show()