# experiments/simulations/visualize_oned_warp.py
"""Visualize a single 1-D spatial warp drawn from a Gaussian process.

A function is sampled from a GP whose mean is the identity map and whose
covariance is an RBF kernel, so the draw is a smooth random perturbation of
the identity "observed -> warped" coordinate mapping. The figure is saved as
``mean_function_example.png`` in the ``plots`` directory two levels above
this script, and then shown.
"""
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import multivariate_normal as mvn
from sklearn.gaussian_process.kernels import RBF

# Large fonts and LaTeX text rendering (requires a working LaTeX install).
font = {"size": 30}
matplotlib.rc("font", **font)
matplotlib.rcParams["text.usetex"] = True


## GP hyperparameters
lengthscale = 1.0  # RBF length scale
amplitude = 1.0  # multiplies the kernel; shown as sigma^2 in the title
# NOTE: this value is added directly to the covariance diagonal, so it acts
# as a variance-scale jitter rather than a standard deviation. Shrinking it
# much further may make scipy reject the covariance as singular.
noise_stddev = 1e-6
# Seed for the GP draw so the saved figure is reproducible; set to None to
# get a fresh random draw on every run.
seed = 0

## Input grid: n evenly spaced points on [xlims[0], xlims[1]], shaped (n, 1).
xlims = [-5, 5]
n = 100
X = np.linspace(xlims[0], xlims[1], n)
X = np.expand_dims(X, 1)

## Draw function
K_XX = amplitude * RBF(length_scale=lengthscale)(X, X) + noise_stddev * np.eye(n)
# Identity mean, so the draw is a perturbation of the identity warp
# (use np.zeros(n) instead for a zero-mean GP).
mean = X.squeeze()
Y = mvn(mean, K_XX).rvs(random_state=seed)

## Plot
plt.figure(figsize=(7, 6))
plt.plot(X, Y, linewidth=5)
plt.xlabel("Observed spatial coordinate")
plt.ylabel("Warped spatial coordinate")
plt.title(rf"$\sigma^2 = {amplitude}, \ell = {lengthscale}$")
plt.tight_layout()

# Resolve the output directory relative to this script rather than the
# current working directory, so the figure lands in the same place no matter
# where the script is launched from.
try:
    script_dir = Path(__file__).resolve().parent
except NameError:  # no __file__, e.g. code pasted into an interactive session
    script_dir = Path.cwd()
plots_dir = script_dir.parent.parent / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(plots_dir / "mean_function_example.png")
plt.show()