|
a |
|
b/tests/data/regression_tests_spatial.py |
|
|
1 |
import pickle |
|
|
2 |
from math import cos, sin |
|
|
3 |
from typing import List, Tuple |
|
|
4 |
|
|
|
5 |
import numpy as np |
|
|
6 |
from scipy.sparse import csr_matrix |
|
|
7 |
|
|
|
8 |
import anndata as ad |
|
|
9 |
import scanpy as sc |
|
|
10 |
from anndata import AnnData |
|
|
11 |
|
|
|
12 |
from moscot._types import ArrayLike |
|
|
13 |
from moscot.problems.space import AlignmentProblem, MappingProblem |
|
|
14 |
|
|
|
15 |
ANGLES = [0, 30, 60] |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def adata_space_rotate() -> AnnData: |
|
|
19 |
grid = _make_grid(10) |
|
|
20 |
adatas = _make_adata(grid, n=3) |
|
|
21 |
for adata, angle in zip(adatas, ANGLES): |
|
|
22 |
theta = np.deg2rad(angle) |
|
|
23 |
rot = np.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]) |
|
|
24 |
adata.obsm["spatial"] = np.dot(adata.obsm["spatial"], rot) |
|
|
25 |
|
|
|
26 |
adata = ad.concat(adatas, label="batch", index_unique="-") |
|
|
27 |
adata.uns["spatial"] = {} |
|
|
28 |
return adata |
|
|
29 |
|
|
|
30 |
|
|
|
31 |
def adata_mapping() -> AnnData: |
|
|
32 |
grid = _make_grid(10) |
|
|
33 |
adataref, adata1, adata2 = _make_adata(grid, n=3) |
|
|
34 |
sc.pp.pca(adataref) |
|
|
35 |
|
|
|
36 |
return ad.concat([adataref, adata1, adata2], label="batch", join="outer", index_unique="-") |
|
|
37 |
|
|
|
38 |
|
|
|
39 |
def _make_grid(grid_size: int) -> ArrayLike: |
|
|
40 |
xlimits = ylimits = [0, 10] |
|
|
41 |
x1s = np.linspace(*xlimits, num=grid_size) # type: ignore [call-overload] |
|
|
42 |
x2s = np.linspace(*ylimits, num=grid_size) # type: ignore [call-overload] |
|
|
43 |
X1, X2 = np.meshgrid(x1s, x2s) |
|
|
44 |
return np.vstack([X1.ravel(), X2.ravel()]).T |
|
|
45 |
|
|
|
46 |
|
|
|
47 |
def _make_adata(grid: ArrayLike, n: int) -> List[AnnData]: |
|
|
48 |
rng = np.random.default_rng(42) |
|
|
49 |
X = rng.normal(size=(100, 60)) |
|
|
50 |
return [AnnData(X=csr_matrix(X), obsm={"spatial": grid.copy()}) for _ in range(3)] |
|
|
51 |
|
|
|
52 |
|
|
|
53 |
def _adata_split(adata: AnnData) -> Tuple[AnnData, AnnData]: |
|
|
54 |
adataref = adata[adata.obs["batch"] == "0"].copy() |
|
|
55 |
adataref.obsm.pop("spatial") |
|
|
56 |
adatasp = adata[adata.obs["batch"] != "0"].copy() |
|
|
57 |
return adataref, adatasp |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
def generate_alignment_data() -> None: |
|
|
61 |
adata = adata_space_rotate() |
|
|
62 |
ap = AlignmentProblem(adata=adata) |
|
|
63 |
ap = ap.prepare(batch_key="batch") |
|
|
64 |
ap = ap.solve(alpha=0.5, epsilon=1) |
|
|
65 |
|
|
|
66 |
with open("alignment_solutions.pkl", "wb") as fname: |
|
|
67 |
pickle.dump(ap.solutions, fname) |
|
|
68 |
|
|
|
69 |
|
|
|
70 |
def generate_mapping_data() -> None: |
|
|
71 |
adata = adata_mapping() |
|
|
72 |
adataref, adatasp = _adata_split(adata) |
|
|
73 |
mp = MappingProblem(adataref, adatasp) |
|
|
74 |
mp = mp.prepare(batch_key="batch", sc_attr={"attr": "X"}) |
|
|
75 |
mp = mp.solve() |
|
|
76 |
with open("mapping_solutions.pkl", "wb") as fname: |
|
|
77 |
pickle.dump(mp.solutions, fname) |
|
|
78 |
|
|
|
79 |
|
|
|
80 |
if __name__ == "__main__": |
|
|
81 |
generate_alignment_data() |
|
|
82 |
generate_mapping_data() |