[54ded2]: / experiments / simulations / plot_parameter_range_results.py

Download this file

74 lines (58 with data), 2.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import anndata
import pandas as pd
from gpsa import VariationalGPSA
from gpsa import matern12_kernel, rbf_kernel
from gpsa.plotting import callback_twod
import sys
sys.path.append("../../data")
from simulated.generate_twod_data import generate_twod_data
import matplotlib
# Global figure styling applied to every plot below: 30pt fonts, and all text
# rendered through LaTeX (text.usetex=True needs a working LaTeX installation).
font = {"size": 30}
matplotlib.rcParams.update({f"font.{key}": val for key, val in font.items()})
matplotlib.rcParams["text.usetex"] = True

# Torch device selection: GPU when available, otherwise CPU.
# NOTE(review): `device` is not referenced anywhere else in this script.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Plot error vs. tested parameter value for two sweeps (spatial variance and
# length scale) side by side, and save the figure to ./out/.
# Each CSV is read with its first column as the index; the code below relies
# on a "variable" column (tested parameter value) and a "value" column (error).
spatial_variance_errors_df = pd.read_csv(
    "./out/error_experiment_parameter_range_spatial_variance.csv", index_col=0
)
lengthscale_errors_df = pd.read_csv(
    "./out/error_experiment_parameter_range_lengthscale.csv", index_col=0
)

# Reference line position for each sweep. The script treats the median of the
# tested grid as the data-generating value. It is computed BEFORE any outlier
# filtering, so dropping every run at one grid point cannot shift the line.
true_spatial_variance = np.median(spatial_variance_errors_df.variable.unique())
true_lengthscale = np.median(lengthscale_errors_df.variable.unique())

# Drop runs with error >= 1 (also drops NaN errors) from the spatial-variance
# sweep only.
# NOTE(review): the length-scale sweep is not filtered -- confirm intentional.
spatial_variance_errors_df = spatial_variance_errors_df[
    spatial_variance_errors_df.value < 1
]

plt.figure(figsize=(17, 6))

## Spatial variance
plt.subplot(121)
plt.title("Spatial variance")
# Error band is +/- one standard deviation across runs at each grid value.
# NOTE(review): `ci=` is deprecated since seaborn 0.12 (errorbar="sd" there).
sns.lineplot(data=spatial_variance_errors_df, x="variable", y="value", ci="sd")
plt.axvline(true_spatial_variance, color="black", linestyle="--")
plt.xlabel(r"$\sigma^2$")
plt.ylabel("Error")

## Length scale
plt.subplot(122)
plt.title("Length scale")
# NOTE(review): this panel uses seaborn's default error band, unlike the
# standard-deviation band on the left -- confirm the difference is intended.
sns.lineplot(data=lengthscale_errors_df, x="variable", y="value")
plt.axvline(
    true_lengthscale, color="black", linestyle="--", label="Data-generating\nvalue"
)
# Legend is anchored just outside the right edge of the second panel.
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize=20)
plt.xlabel(r"$\ell$")
plt.ylabel("Error")

plt.tight_layout()
plt.savefig("./out/error_experiment_parameter_range.png")
plt.show()
plt.close()