import pickle
from math import cos, sin
from typing import List, Tuple
import numpy as np
from scipy.sparse import csr_matrix
import anndata as ad
import scanpy as sc
from anndata import AnnData
from moscot._types import ArrayLike
from moscot.problems.space import AlignmentProblem, MappingProblem
# Rotation angles in degrees; one toy dataset is generated per entry.
ANGLES = [0, 30, 60]


def adata_space_rotate() -> AnnData:
    """Concatenate toy spatial datasets whose coordinates are rotated by ``ANGLES``.

    The ``i``-th dataset gets its ``obsm["spatial"]`` grid rotated by
    ``ANGLES[i]`` degrees. The coordinates are row vectors right-multiplied by
    the standard counter-clockwise rotation matrix, so the effective rotation
    is clockwise (by ``-angle``).

    Returns
    -------
    AnnData
        The datasets concatenated with ``ad.concat(label="batch",
        index_unique="-")``, plus an empty ``uns["spatial"]`` dict.
    """
    grid = _make_grid(10)
    # Request one dataset per angle instead of a separately hard-coded count,
    # so the number of datasets follows ANGLES when it is edited.
    adatas = _make_adata(grid, n=len(ANGLES))
    for adata, angle in zip(adatas, ANGLES):
        theta = np.deg2rad(angle)
        rot = np.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]])
        # (N, 2) points times (2, 2) matrix: right-multiplication rotates by -theta.
        adata.obsm["spatial"] = np.dot(adata.obsm["spatial"], rot)
    adata = ad.concat(adatas, label="batch", index_unique="-")
    adata.uns["spatial"] = {}
    return adata
def adata_mapping() -> AnnData:
    """Build the toy dataset used for the mapping problem.

    Three datasets are generated on a 10x10 grid. PCA is computed on the
    first one only (the reference). The three are then concatenated with a
    ``"batch"`` label, an outer join, and ``"-"``-suffixed unique obs names.
    """
    datasets = _make_adata(_make_grid(10), n=3)
    reference, first_query, second_query = datasets
    # sc.pp.pca is called without copy=True, so it annotates `reference` in place.
    sc.pp.pca(reference)
    return ad.concat(
        [reference, first_query, second_query],
        label="batch",
        join="outer",
        index_unique="-",
    )
def _make_grid(grid_size: int) -> ArrayLike:
xlimits = ylimits = [0, 10]
x1s = np.linspace(*xlimits, num=grid_size) # type: ignore [call-overload]
x2s = np.linspace(*ylimits, num=grid_size) # type: ignore [call-overload]
X1, X2 = np.meshgrid(x1s, x2s)
return np.vstack([X1.ravel(), X2.ravel()]).T
def _make_adata(grid: ArrayLike, n: int) -> List[AnnData]:
    """Create ``n`` AnnData objects over the same spatial grid.

    Parameters
    ----------
    grid
        Point coordinates, one row per observation. Each returned object
        stores its own copy in ``obsm["spatial"]``.
    n
        Number of AnnData objects to create. (Previously this argument was
        ignored and exactly 3 objects were always returned.)

    Returns
    -------
    list of AnnData
        ``n`` objects with a CSR-sparse ``X`` of shape ``(len(grid), 60)``.
        The values are drawn once from a standard normal with seed 42, so
        every object carries identical expression values.
    """
    grid = np.asarray(grid)
    rng = np.random.default_rng(42)
    # Size the matrix from the grid, so the number of observations always
    # matches the number of spatial coordinates (was hard-coded to 100).
    X = rng.normal(size=(grid.shape[0], 60))
    return [AnnData(X=csr_matrix(X), obsm={"spatial": grid.copy()}) for _ in range(n)]
def _adata_split(adata: AnnData) -> Tuple[AnnData, AnnData]:
    """Split a concatenated dataset into a reference part and a spatial part.

    Returns
    -------
    tuple of AnnData
        ``(reference, spatial)``. ``reference`` is a copy of the batch ``"0"``
        observations with ``obsm["spatial"]`` removed. ``spatial`` is a copy
        of all observations from the other batches.
    """
    is_reference = adata.obs["batch"] == "0"
    reference = adata[is_reference].copy()
    # The reference plays the non-spatial role, so drop its coordinates.
    del reference.obsm["spatial"]
    spatial = adata[~is_reference].copy()
    return reference, spatial
def generate_alignment_data() -> None:
    """Solve an AlignmentProblem on the rotated toy data and pickle the solutions.

    The output is written to ``alignment_solutions.pkl`` in the current
    working directory.
    """
    problem = AlignmentProblem(adata=adata_space_rotate())
    problem = problem.prepare(batch_key="batch").solve(alpha=0.5, epsilon=1)
    with open("alignment_solutions.pkl", "wb") as handle:
        pickle.dump(problem.solutions, handle)
def generate_mapping_data() -> None:
    """Solve a MappingProblem on the toy data and pickle the solutions.

    The reference/spatial split comes from ``_adata_split``. The output is
    written to ``mapping_solutions.pkl`` in the current working directory.
    """
    reference, spatial = _adata_split(adata_mapping())
    problem = MappingProblem(reference, spatial)
    problem = problem.prepare(batch_key="batch", sc_attr={"attr": "X"}).solve()
    with open("mapping_solutions.pkl", "wb") as handle:
        pickle.dump(problem.solutions, handle)
if __name__ == "__main__":
    # Regenerate both pickled fixtures, alignment first, then mapping.
    for _generate in (generate_alignment_data, generate_mapping_data):
        _generate()