|
a |
|
b/tests/_utils.py |
|
|
1 |
from typing import Any, List, Optional, Tuple, Type, Union |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import pandas as pd |
|
|
5 |
from scipy.sparse import csr_matrix |
|
|
6 |
|
|
|
7 |
from anndata import AnnData |
|
|
8 |
|
|
|
9 |
from moscot._types import ArrayLike |
|
|
10 |
from moscot.base.output import MatrixSolverOutput |
|
|
11 |
from moscot.base.problems import AnalysisMixin, CompoundProblem, OTProblem |
|
|
12 |
from moscot.base.problems.compound_problem import B |
|
|
13 |
|
|
|
14 |
Geom_t = Union[np.ndarray, Tuple[np.ndarray, np.ndarray]] |
|
|
15 |
RTOL = 1e-6 |
|
|
16 |
ATOL = 1e-6 |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
class CompoundProblemWithMixin(CompoundProblem, AnalysisMixin): |
|
|
20 |
@property |
|
|
21 |
def _base_problem_type(self) -> Type[B]: |
|
|
22 |
return OTProblem |
|
|
23 |
|
|
|
24 |
@property |
|
|
25 |
def _valid_policies(self) -> Tuple[str, ...]: |
|
|
26 |
return () |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
class MockSolverOutput(MatrixSolverOutput): |
|
|
30 |
@property |
|
|
31 |
def cost(self) -> float: |
|
|
32 |
return 0.5 |
|
|
33 |
|
|
|
34 |
@property |
|
|
35 |
def converged(self) -> bool: |
|
|
36 |
return True |
|
|
37 |
|
|
|
38 |
@property |
|
|
39 |
def is_linear(self) -> bool: |
|
|
40 |
return True |
|
|
41 |
|
|
|
42 |
@property |
|
|
43 |
def potentials(self) -> Tuple[Optional[ArrayLike], Optional[ArrayLike]]: |
|
|
44 |
return None, None |
|
|
45 |
|
|
|
46 |
def _ones(self, n: int) -> ArrayLike: |
|
|
47 |
return np.ones(n) |
|
|
48 |
|
|
|
49 |
|
|
|
50 |
def _make_adata(grid: ArrayLike, n: int, seed, cat_key: str = "covariate", num_categories: int = 3) -> List[AnnData]: |
|
|
51 |
rng = np.random.RandomState(seed) |
|
|
52 |
n_cells = 100 |
|
|
53 |
X = rng.normal(size=(n_cells, 60)) |
|
|
54 |
|
|
|
55 |
# generate a categorical variable |
|
|
56 |
categories = [f"cat_{i+1}" for i in range(num_categories)] |
|
|
57 |
categorical_data = rng.choice(categories, size=n_cells) |
|
|
58 |
|
|
|
59 |
adatas = [] |
|
|
60 |
for _ in range(n): |
|
|
61 |
obs_df = pd.DataFrame({cat_key: pd.Categorical(categorical_data)}) |
|
|
62 |
adatas.append(AnnData(X=csr_matrix(X), obs=obs_df, obsm={"spatial": grid.copy()})) |
|
|
63 |
|
|
|
64 |
return adatas |
|
|
65 |
|
|
|
66 |
|
|
|
67 |
def _adata_spatial_split(adata: AnnData) -> Tuple[AnnData, AnnData]: |
|
|
68 |
adata_ref = adata[adata.obs.batch == "0"].copy() |
|
|
69 |
adata_ref.obsm.pop("spatial") |
|
|
70 |
adata_sp = adata[adata.obs.batch != "0"].copy() |
|
|
71 |
return adata_ref, adata_sp |
|
|
72 |
|
|
|
73 |
|
|
|
74 |
def _make_grid(grid_size: int) -> ArrayLike: |
|
|
75 |
xlimits = ylimits = [0, 10] |
|
|
76 |
x1s = np.linspace(*xlimits, num=grid_size) |
|
|
77 |
x2s = np.linspace(*ylimits, num=grid_size) |
|
|
78 |
X1, X2 = np.meshgrid(x1s, x2s) |
|
|
79 |
return np.vstack([X1.ravel(), X2.ravel()]).T |
|
|
80 |
|
|
|
81 |
|
|
|
82 |
def _assert_marginals_set(adata_time, problem, key, marginal_keys): |
|
|
83 |
"""Helper function to check if marginals are set correctly""" |
|
|
84 |
adata_time0 = adata_time[key[0] == adata_time.obs["time"]] |
|
|
85 |
adata_time1 = adata_time[key[1] == adata_time.obs["time"]] |
|
|
86 |
if marginal_keys[0] is not None: # check if marginal keys are set |
|
|
87 |
a = adata_time0.obs[marginal_keys[0]].values |
|
|
88 |
b = adata_time1.obs[marginal_keys[1]].values |
|
|
89 |
assert np.allclose(problem[key].a, a) |
|
|
90 |
assert np.allclose(problem[key].b, b) |
|
|
91 |
else: # otherwise check if marginals are uniform |
|
|
92 |
assert np.allclose(problem[key].a, 1.0 / adata_time0.shape[0]) |
|
|
93 |
assert np.allclose(problem[key].b, 1.0 / adata_time1.shape[0]) |
|
|
94 |
|
|
|
95 |
|
|
|
96 |
class Problem(CompoundProblem[Any, OTProblem]): |
|
|
97 |
@property |
|
|
98 |
def _base_problem_type(self) -> Type[B]: |
|
|
99 |
return OTProblem |
|
|
100 |
|
|
|
101 |
@property |
|
|
102 |
def _valid_policies(self) -> Tuple[str, ...]: |
|
|
103 |
return () |