[54ded2]: / experiments / simulations / two_dimensional_fixed_view.py

Download this file

56 lines (40 with data), 1.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import sys
from two_dimensional import two_d_gpsa
sys.path.append("../..")
from models.gpsa_vi_lmc import VariationalWarpGP
sys.path.append("../../data")
from simulated.generate_twod_data import generate_twod_data
from plotting.callbacks import callback_twod
from util import ConvergenceChecker
## For PASTE
import scanpy as sc
sys.path.append("../../../paste")
from src.paste import PASTE, visualization
# Compute device: prefer CUDA whenever torch can see a GPU, else use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Font size for figure text; not referenced by the main block below.
LATEX_FONTSIZE = 50

# Simulation shape: spatial dimensionality, number of views, number of outputs.
# Only n_outputs is forwarded to two_d_gpsa by the main block.
n_spatial_dims, n_views, n_outputs = 2, 2, 10

# Formerly configurable here and currently disabled: m_G = 40, m_X_per_view = 40.
# NOTE(review): presumably two_d_gpsa uses its own defaults for these — confirm.

# Optimisation schedule. MAX_EPOCHS is passed to two_d_gpsa as n_epochs;
# PRINT_EVERY is not referenced by the main block.
MAX_EPOCHS = 2000
PRINT_EVERY = 25

# Latent-GP counts keyed by name, passed to two_d_gpsa as n_latent_gps.
N_LATENT_GPS = dict(expression=5)
if __name__ == "__main__":
    # Repeat the 2-D GPSA simulation with view 0 held fixed, then pause for
    # interactive inspection of the outputs.
    n_repeats = 3

    # Keep every repeat's outputs; previously each iteration overwrote the
    # last, so only the final run survived for inspection.
    results = []
    for _ in range(n_repeats):
        X, Y, G_means, model, err_paste = two_d_gpsa(
            n_outputs=n_outputs,
            n_epochs=MAX_EPOCHS,
            plot_intermediate=True,
            warp_kernel_variance=0.5,
            n_latent_gps=N_LATENT_GPS,
            fixed_view_idx=0,
        )
        # NOTE(review): err_paste is presumably the PASTE baseline's error
        # for this repeat — confirm against two_d_gpsa.
        results.append(
            {
                "X": X,
                "Y": Y,
                "G_means": G_means,
                "model": model,
                "err_paste": err_paste,
            }
        )

    # Drop into the debugger once, after all repeats have finished.
    # breakpoint() honours PYTHONBREAKPOINT: set it to "ipdb.set_trace" to
    # get the previous ipdb session, or to "0" to skip the pause in batch runs.
    breakpoint()