Switch to unified view

a b/tests/data/regression_tests_spatial.py
1
import pickle
2
from math import cos, sin
3
from typing import List, Tuple
4
5
import numpy as np
6
from scipy.sparse import csr_matrix
7
8
import anndata as ad
9
import scanpy as sc
10
from anndata import AnnData
11
12
from moscot._types import ArrayLike
13
from moscot.problems.space import AlignmentProblem, MappingProblem
14
15
ANGLES = [0, 30, 60]
16
17
18
def adata_space_rotate() -> AnnData:
19
    grid = _make_grid(10)
20
    adatas = _make_adata(grid, n=3)
21
    for adata, angle in zip(adatas, ANGLES):
22
        theta = np.deg2rad(angle)
23
        rot = np.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]])
24
        adata.obsm["spatial"] = np.dot(adata.obsm["spatial"], rot)
25
26
    adata = ad.concat(adatas, label="batch", index_unique="-")
27
    adata.uns["spatial"] = {}
28
    return adata
29
30
31
def adata_mapping() -> AnnData:
32
    grid = _make_grid(10)
33
    adataref, adata1, adata2 = _make_adata(grid, n=3)
34
    sc.pp.pca(adataref)
35
36
    return ad.concat([adataref, adata1, adata2], label="batch", join="outer", index_unique="-")
37
38
39
def _make_grid(grid_size: int) -> ArrayLike:
40
    xlimits = ylimits = [0, 10]
41
    x1s = np.linspace(*xlimits, num=grid_size)  # type: ignore [call-overload]
42
    x2s = np.linspace(*ylimits, num=grid_size)  # type: ignore [call-overload]
43
    X1, X2 = np.meshgrid(x1s, x2s)
44
    return np.vstack([X1.ravel(), X2.ravel()]).T
45
46
47
def _make_adata(grid: ArrayLike, n: int) -> List[AnnData]:
48
    rng = np.random.default_rng(42)
49
    X = rng.normal(size=(100, 60))
50
    return [AnnData(X=csr_matrix(X), obsm={"spatial": grid.copy()}) for _ in range(3)]
51
52
53
def _adata_split(adata: AnnData) -> Tuple[AnnData, AnnData]:
54
    adataref = adata[adata.obs["batch"] == "0"].copy()
55
    adataref.obsm.pop("spatial")
56
    adatasp = adata[adata.obs["batch"] != "0"].copy()
57
    return adataref, adatasp
58
59
60
def generate_alignment_data() -> None:
61
    adata = adata_space_rotate()
62
    ap = AlignmentProblem(adata=adata)
63
    ap = ap.prepare(batch_key="batch")
64
    ap = ap.solve(alpha=0.5, epsilon=1)
65
66
    with open("alignment_solutions.pkl", "wb") as fname:
67
        pickle.dump(ap.solutions, fname)
68
69
70
def generate_mapping_data() -> None:
71
    adata = adata_mapping()
72
    adataref, adatasp = _adata_split(adata)
73
    mp = MappingProblem(adataref, adatasp)
74
    mp = mp.prepare(batch_key="batch", sc_attr={"attr": "X"})
75
    mp = mp.solve()
76
    with open("mapping_solutions.pkl", "wb") as fname:
77
        pickle.dump(mp.solutions, fname)
78
79
80
if __name__ == "__main__":
81
    generate_alignment_data()
82
    generate_mapping_data()