Diff of /tests/_utils.py [000000] .. [6ff4a8]

Switch to side-by-side view

--- a
+++ b/tests/_utils.py
@@ -0,0 +1,103 @@
+from typing import Any, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import pandas as pd
+from scipy.sparse import csr_matrix
+
+from anndata import AnnData
+
+from moscot._types import ArrayLike
+from moscot.base.output import MatrixSolverOutput
+from moscot.base.problems import AnalysisMixin, CompoundProblem, OTProblem
+from moscot.base.problems.compound_problem import B
+
+Geom_t = Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
+RTOL = 1e-6
+ATOL = 1e-6
+
+
+class CompoundProblemWithMixin(CompoundProblem, AnalysisMixin):
+    @property
+    def _base_problem_type(self) -> Type[B]:
+        return OTProblem
+
+    @property
+    def _valid_policies(self) -> Tuple[str, ...]:
+        return ()
+
+
+class MockSolverOutput(MatrixSolverOutput):
+    @property
+    def cost(self) -> float:
+        return 0.5
+
+    @property
+    def converged(self) -> bool:
+        return True
+
+    @property
+    def is_linear(self) -> bool:
+        return True
+
+    @property
+    def potentials(self) -> Tuple[Optional[ArrayLike], Optional[ArrayLike]]:
+        return None, None
+
+    def _ones(self, n: int) -> ArrayLike:
+        return np.ones(n)
+
+
+def _make_adata(grid: ArrayLike, n: int, seed, cat_key: str = "covariate", num_categories: int = 3) -> List[AnnData]:
+    rng = np.random.RandomState(seed)
+    n_cells = 100
+    X = rng.normal(size=(n_cells, 60))
+
+    # generate a categorical variable
+    categories = [f"cat_{i+1}" for i in range(num_categories)]
+    categorical_data = rng.choice(categories, size=n_cells)
+
+    adatas = []
+    for _ in range(n):
+        obs_df = pd.DataFrame({cat_key: pd.Categorical(categorical_data)})
+        adatas.append(AnnData(X=csr_matrix(X), obs=obs_df, obsm={"spatial": grid.copy()}))
+
+    return adatas
+
+
+def _adata_spatial_split(adata: AnnData) -> Tuple[AnnData, AnnData]:
+    adata_ref = adata[adata.obs.batch == "0"].copy()
+    adata_ref.obsm.pop("spatial")
+    adata_sp = adata[adata.obs.batch != "0"].copy()
+    return adata_ref, adata_sp
+
+
+def _make_grid(grid_size: int) -> ArrayLike:
+    xlimits = ylimits = [0, 10]
+    x1s = np.linspace(*xlimits, num=grid_size)
+    x2s = np.linspace(*ylimits, num=grid_size)
+    X1, X2 = np.meshgrid(x1s, x2s)
+    return np.vstack([X1.ravel(), X2.ravel()]).T
+
+
+def _assert_marginals_set(adata_time, problem, key, marginal_keys):
+    """Helper function to check if marginals are set correctly"""
+    adata_time0 = adata_time[key[0] == adata_time.obs["time"]]
+    adata_time1 = adata_time[key[1] == adata_time.obs["time"]]
+    if marginal_keys[0] is not None:  # check if marginal keys are set
+        a = adata_time0.obs[marginal_keys[0]].values
+        b = adata_time1.obs[marginal_keys[1]].values
+        assert np.allclose(problem[key].a, a)
+        assert np.allclose(problem[key].b, b)
+    else:  # otherwise check if marginals are uniform
+        assert np.allclose(problem[key].a, 1.0 / adata_time0.shape[0])
+        assert np.allclose(problem[key].b, 1.0 / adata_time1.shape[0])
+
+
+class Problem(CompoundProblem[Any, OTProblem]):
+    @property
+    def _base_problem_type(self) -> Type[B]:
+        return OTProblem
+
+    @property
+    def _valid_policies(self) -> Tuple[str, ...]:
+        return ()