# experiments/simulations/warp_parameter_demo.py
1
import torch
2
import numpy as np
3
import matplotlib.pyplot as plt
4
from sklearn.gaussian_process.kernels import RBF
5
from scipy.stats import multivariate_normal as mvn
6
7
import seaborn as sns
8
import sys
9
10
sys.path.append("../../data")
11
from simulated.generate_twod_data import generate_twod_data
12
import os
13
from os.path import join as pjoin
14
import anndata
15
16
import matplotlib
# Global plotting style: 25pt fonts and LaTeX text rendering (the latter
# requires a LaTeX installation on the machine producing the figure).
font = dict(size=25)
matplotlib.rc("font", **font)
matplotlib.rc("text", usetex=True)
# Torch device selection: prefer the GPU when one is visible.
# NOTE(review): `device`, `LATEX_FONTSIZE` and `n_views` appear unused in
# this script.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

LATEX_FONTSIZE = 35

# Problem dimensions and per-series scatter markers.
n_spatial_dims = 2
n_views = 2
markers = ["o", "X"]

# Warp hyperparameters swept in the figure: RBF lengthscales 0.1, 1, 10 and
# kernel amplitudes (multiplicative variances).
lengthscale_list = [10**exponent for exponent in (-1, 0, 1)]
amplitude_list = [0.1, 1.0, 5.0]

# Regular grid_size x grid_size lattice over [0, 10] x [0, 10]; each row of X
# is one (x, y) point, with x varying fastest.
xlimits = [0, 10]
ylimits = [0, 10]
grid_size = 10
x1s = np.linspace(xlimits[0], xlimits[1], num=grid_size)
x2s = np.linspace(ylimits[0], ylimits[1], num=grid_size)
X1, X2 = np.meshgrid(x1s, x2s)
X = np.column_stack((X1.ravel(), X2.ravel()))
n = X.shape[0]
## Lengthscales
# Sweep over warp hyperparameters: one row per RBF lengthscale, one column per
# kernel amplitude. In each panel every spatial coordinate of the regular grid
# X is perturbed independently by a draw from
#     N(X[:, d], amplitude * RBF_lengthscale(X, X) + 1e-8 * I),
# and the original (gray) and warped (red) points are overlaid.
n_rows = len(lengthscale_list)
n_cols = len(amplitude_list)
# Small diagonal jitter keeps the covariance numerically positive definite.
jitter_eye = 1e-8 * np.eye(n)

plt.figure(figsize=(17, 15))
for ii, lengthscale in enumerate(lengthscale_list):
    # The unscaled kernel matrix depends only on the lengthscale, so build it
    # once per row rather than once per panel.
    kernel = RBF(length_scale=lengthscale)
    K_unit = kernel(X)
    for jj, amplitude in enumerate(amplitude_list):
        K = amplitude * K_unit + jitter_eye
        X_warped = np.zeros((n, n_spatial_dims))
        for dd in range(n_spatial_dims):
            X_warped[:, dd] = mvn.rvs(mean=X[:, dd], cov=K)

        # Rows follow the lengthscale (ii), columns the amplitude (jj); nrows
        # and ncols must match that so the panel index stays correct when the
        # two lists have different lengths.
        plt.subplot(n_rows, n_cols, ii * n_cols + jj + 1)

        # Only the top-right panel is labelled; its legend sits outside the axes.
        show_legend = (ii == 0) and (jj == n_cols - 1)
        plt.scatter(
            X[:, 0],
            X[:, 1],
            color="gray",
            marker=markers[0],
            label="Original" if show_legend else None,
            s=50,
        )
        plt.scatter(
            X_warped[:, 0],
            X_warped[:, 1],
            color="red",
            marker=markers[1],
            label="Warped" if show_legend else None,
            s=50,
        )
        # sklearn's RBF `length_scale` is ell itself (k = exp(-d^2 / (2 ell^2))),
        # so the title reports ell; the raw string keeps "\s" from being parsed
        # as an (invalid) escape sequence.
        plt.title(
            rf"$\ell$ = {round(lengthscale, 2)}, $\sigma^2$ = {round(amplitude, 2)}"
        )

        if show_legend:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.xticks([])
        plt.yticks([])

plt.tight_layout()
# savefig does not create missing directories, so make sure ./out exists.
os.makedirs("./out", exist_ok=True)
plt.savefig("./out/warp_parameter_demo.png")
plt.show()