|
a |
|
b/experiments/simulations/plot_largenumspots_results.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
import matplotlib.pyplot as plt |
|
|
4 |
import seaborn as sns |
|
|
5 |
import anndata |
|
|
6 |
import pandas as pd |
|
|
7 |
|
|
|
8 |
from gpsa import VariationalGPSA |
|
|
9 |
from gpsa import matern12_kernel, rbf_kernel |
|
|
10 |
from gpsa.plotting import callback_twod |
|
|
11 |
import sys |
|
|
12 |
|
|
|
13 |
sys.path.append("../../data") |
|
|
14 |
from simulated.generate_twod_data import generate_twod_data |
|
|
15 |
|
|
|
16 |
import matplotlib |
|
|
17 |
|
|
|
18 |
font = {"size": 30} |
|
|
19 |
matplotlib.rc("font", **font) |
|
|
20 |
matplotlib.rcParams["text.usetex"] = True |
|
|
21 |
|
|
|
22 |
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
23 |
|
|
|
24 |
results_df = pd.read_csv("out/error_experiment_large_numspots.csv", index_col=0) |
|
|
25 |
# results_df = results_df[results_df.value < 0.8] |
|
|
26 |
# import ipdb; ipdb.set_trace() |
|
|
27 |
|
|
|
28 |
plt.figure(figsize=(7, 6)) |
|
|
29 |
sns.boxplot(data=results_df, x="method", y="value") |
|
|
30 |
plt.xlabel("") |
|
|
31 |
plt.ylabel("Error") |
|
|
32 |
|
|
|
33 |
plt.tight_layout() |
|
|
34 |
plt.savefig("./out/error_experiment_large_numspots.png") |
|
|
35 |
plt.show() |
|
|
36 |
plt.close() |
|
|
37 |
|
|
|
38 |
import ipdb |
|
|
39 |
|
|
|
40 |
ipdb.set_trace() |