Diff of /tests/_utils.py [000000] .. [6ff4a8]

Switch to unified view

a b/tests/_utils.py
1
from typing import Any, List, Optional, Tuple, Type, Union
2
3
import numpy as np
4
import pandas as pd
5
from scipy.sparse import csr_matrix
6
7
from anndata import AnnData
8
9
from moscot._types import ArrayLike
10
from moscot.base.output import MatrixSolverOutput
11
from moscot.base.problems import AnalysisMixin, CompoundProblem, OTProblem
12
from moscot.base.problems.compound_problem import B
13
14
Geom_t = Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
15
RTOL = 1e-6
16
ATOL = 1e-6
17
18
19
class CompoundProblemWithMixin(CompoundProblem, AnalysisMixin):
20
    @property
21
    def _base_problem_type(self) -> Type[B]:
22
        return OTProblem
23
24
    @property
25
    def _valid_policies(self) -> Tuple[str, ...]:
26
        return ()
27
28
29
class MockSolverOutput(MatrixSolverOutput):
30
    @property
31
    def cost(self) -> float:
32
        return 0.5
33
34
    @property
35
    def converged(self) -> bool:
36
        return True
37
38
    @property
39
    def is_linear(self) -> bool:
40
        return True
41
42
    @property
43
    def potentials(self) -> Tuple[Optional[ArrayLike], Optional[ArrayLike]]:
44
        return None, None
45
46
    def _ones(self, n: int) -> ArrayLike:
47
        return np.ones(n)
48
49
50
def _make_adata(grid: ArrayLike, n: int, seed, cat_key: str = "covariate", num_categories: int = 3) -> List[AnnData]:
51
    rng = np.random.RandomState(seed)
52
    n_cells = 100
53
    X = rng.normal(size=(n_cells, 60))
54
55
    # generate a categorical variable
56
    categories = [f"cat_{i+1}" for i in range(num_categories)]
57
    categorical_data = rng.choice(categories, size=n_cells)
58
59
    adatas = []
60
    for _ in range(n):
61
        obs_df = pd.DataFrame({cat_key: pd.Categorical(categorical_data)})
62
        adatas.append(AnnData(X=csr_matrix(X), obs=obs_df, obsm={"spatial": grid.copy()}))
63
64
    return adatas
65
66
67
def _adata_spatial_split(adata: AnnData) -> Tuple[AnnData, AnnData]:
68
    adata_ref = adata[adata.obs.batch == "0"].copy()
69
    adata_ref.obsm.pop("spatial")
70
    adata_sp = adata[adata.obs.batch != "0"].copy()
71
    return adata_ref, adata_sp
72
73
74
def _make_grid(grid_size: int) -> ArrayLike:
75
    xlimits = ylimits = [0, 10]
76
    x1s = np.linspace(*xlimits, num=grid_size)
77
    x2s = np.linspace(*ylimits, num=grid_size)
78
    X1, X2 = np.meshgrid(x1s, x2s)
79
    return np.vstack([X1.ravel(), X2.ravel()]).T
80
81
82
def _assert_marginals_set(adata_time, problem, key, marginal_keys):
83
    """Helper function to check if marginals are set correctly"""
84
    adata_time0 = adata_time[key[0] == adata_time.obs["time"]]
85
    adata_time1 = adata_time[key[1] == adata_time.obs["time"]]
86
    if marginal_keys[0] is not None:  # check if marginal keys are set
87
        a = adata_time0.obs[marginal_keys[0]].values
88
        b = adata_time1.obs[marginal_keys[1]].values
89
        assert np.allclose(problem[key].a, a)
90
        assert np.allclose(problem[key].b, b)
91
    else:  # otherwise check if marginals are uniform
92
        assert np.allclose(problem[key].a, 1.0 / adata_time0.shape[0])
93
        assert np.allclose(problem[key].b, 1.0 / adata_time1.shape[0])
94
95
96
class Problem(CompoundProblem[Any, OTProblem]):
97
    @property
98
    def _base_problem_type(self) -> Type[B]:
99
        return OTProblem
100
101
    @property
102
    def _valid_policies(self) -> Tuple[str, ...]:
103
        return ()