|
a |
|
b/experiments/simulations/warp_parameter_demo.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
import matplotlib.pyplot as plt |
|
|
4 |
from sklearn.gaussian_process.kernels import RBF |
|
|
5 |
from scipy.stats import multivariate_normal as mvn |
|
|
6 |
|
|
|
7 |
import seaborn as sns |
|
|
8 |
import sys |
|
|
9 |
|
|
|
10 |
sys.path.append("../../data") |
|
|
11 |
from simulated.generate_twod_data import generate_twod_data |
|
|
12 |
import os |
|
|
13 |
from os.path import join as pjoin |
|
|
14 |
import anndata |
|
|
15 |
|
|
|
16 |
import matplotlib |
|
|
17 |
|
|
|
18 |
# Global matplotlib styling: enlarged font and LaTeX-rendered text.
font = {"size": 25}
matplotlib.rc("font", **font)
matplotlib.rcParams["text.usetex"] = True

# Use the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

LATEX_FONTSIZE = 35

n_spatial_dims = 2
n_views = 2
markers = ["o", "X"]

# Kernel hyperparameters swept by the demo.
lengthscale_list = [10**exponent for exponent in [-1, 0, 1]]
amplitude_list = [0.1, 1.0, 5.0]

# Regular grid_size x grid_size lattice over the plotting window; X holds one
# (x1, x2) coordinate pair per row.
xlimits = [0, 10]
ylimits = [0, 10]
grid_size = 10
x1s = np.linspace(xlimits[0], xlimits[1], num=grid_size)
x2s = np.linspace(ylimits[0], ylimits[1], num=grid_size)
X1, X2 = np.meshgrid(x1s, x2s)
X = np.array([X1.ravel(), X2.ravel()]).T
n = X.shape[0]
|
|
42 |
|
|
|
43 |
## Lengthscales
# One panel per (lengthscale, amplitude) pair. Rows follow lengthscale_list
# (outer loop) and columns follow amplitude_list (inner loop), so the subplot
# grid dimensions and the flat panel index are derived from those lists in
# that order.
n_rows = len(lengthscale_list)
n_cols = len(amplitude_list)

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("./out", exist_ok=True)

plt.figure(figsize=(17, 15))
for ii, lengthscale in enumerate(lengthscale_list):
    for jj, amplitude in enumerate(amplitude_list):
        # RBF covariance between all grid points, scaled by the amplitude.
        # The small diagonal term is presumably jitter for numerical
        # stability when sampling.
        kernel = RBF(length_scale=lengthscale)
        K = amplitude * kernel(X) + 1e-8 * np.eye(n)

        # Warp each spatial coordinate independently: a multivariate normal
        # draw whose mean is the original coordinate and whose covariance is K.
        X_warped = np.zeros((n, n_spatial_dims))
        for dd in range(n_spatial_dims):
            X_warped[:, dd] = mvn.rvs(mean=X[:, dd], cov=K)

        plt.subplot(n_rows, n_cols, ii * n_cols + jj + 1)

        # Only the last panel of the first row carries legend labels, so the
        # legend is drawn once, to the right of the grid.
        show_legend = (ii == 0) and (jj == n_cols - 1)

        plt.scatter(
            X[:, 0],
            X[:, 1],
            color="gray",
            marker=markers[0],
            label="Original" if show_legend else None,
            s=50,
        )
        plt.scatter(
            X_warped[:, 0],
            X_warped[:, 1],
            color="red",
            marker=markers[1],
            label="Warped" if show_legend else None,
            s=50,
        )
        # NOTE(review): the title labels the value passed as RBF length_scale
        # as \ell^2 -- confirm whether \ell was intended.
        plt.title(
            r"$\ell^2$ = "
            + str(round(lengthscale, 2))
            + r", $\sigma^2$ = "
            + str(round(amplitude, 2))
        )

        if show_legend:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.xticks([])
        plt.yticks([])
plt.tight_layout()
plt.savefig("./out/warp_parameter_demo.png")
plt.show()