[e7f7dd]: / tests / _utils.py

Download this file

104 lines (76 with data), 3.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
from typing import Any, List, Optional, Tuple, Type, Union
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from anndata import AnnData
from moscot._types import ArrayLike
from moscot.base.output import MatrixSolverOutput
from moscot.base.problems import AnalysisMixin, CompoundProblem, OTProblem
from moscot.base.problems.compound_problem import B
# Type alias: either a single array or a pair of arrays.
Geom_t = Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
# NOTE(review): presumably the relative / absolute tolerances used for numeric
# comparisons by the tests; they are not referenced in the visible part of this
# file -- confirm against the callers.
RTOL = 1e-6
ATOL = 1e-6
class CompoundProblemWithMixin(CompoundProblem, AnalysisMixin):
    """Test problem deriving from both ``CompoundProblem`` and ``AnalysisMixin``.

    It reports ``OTProblem`` as its base problem type and an empty tuple of
    valid policies.
    """

    @property
    def _valid_policies(self) -> Tuple[str, ...]:
        """Return an empty tuple of policy names."""
        return tuple()

    @property
    def _base_problem_type(self) -> Type[B]:
        """Return the ``OTProblem`` class."""
        return OTProblem
class MockSolverOutput(MatrixSolverOutput):
    """``MatrixSolverOutput`` subclass whose properties return fixed values."""

    @property
    def converged(self) -> bool:
        """Always ``True``."""
        return True

    @property
    def is_linear(self) -> bool:
        """Always ``True``."""
        return True

    @property
    def cost(self) -> float:
        """Constant cost of ``0.5``."""
        return 0.5

    @property
    def potentials(self) -> Tuple[Optional[ArrayLike], Optional[ArrayLike]]:
        """Pair ``(None, None)``: no potentials are provided."""
        return (None, None)

    def _ones(self, n: int) -> ArrayLike:
        """Return a NumPy array of ``n`` ones."""
        return np.ones(shape=n)
def _make_adata(grid: ArrayLike, n: int, seed, cat_key: str = "covariate", num_categories: int = 3) -> List[AnnData]:
    """Create ``n`` AnnData objects sharing the same random data and spatial grid.

    The expression matrix and the categorical labels are drawn once from a
    ``RandomState`` seeded with ``seed``; every returned object is built from
    those same values (a fresh sparse matrix, a fresh ``obs`` frame and a copy
    of ``grid`` per object).

    Parameters
    ----------
    grid
        Spatial coordinates stored under ``obsm["spatial"]``; its length
        determines the number of observations.
    n
        Number of AnnData objects to create.
    seed
        Seed passed to :class:`numpy.random.RandomState`.
    cat_key
        Name of the categorical column added to ``obs``.
    num_categories
        Number of distinct labels (``"cat_1"`` ... ``"cat_<num_categories>"``)
        the categorical column is sampled from.

    Returns
    -------
    List of ``n`` AnnData objects with a sparse ``X`` of 60 columns.
    """
    rng = np.random.RandomState(seed)
    # One observation per grid point, so ``X`` and ``obs`` always agree with
    # ``obsm["spatial"]`` instead of relying on the grid having exactly 100 rows.
    n_cells = len(grid)
    X = rng.normal(size=(n_cells, 60))
    # generate a categorical variable (drawn after ``X`` to keep the RNG stream order)
    categories = [f"cat_{i + 1}" for i in range(num_categories)]
    categorical_data = rng.choice(categories, size=n_cells)
    return [
        AnnData(
            X=csr_matrix(X),
            obs=pd.DataFrame({cat_key: pd.Categorical(categorical_data)}),
            obsm={"spatial": grid.copy()},
        )
        for _ in range(n)
    ]
def _adata_spatial_split(adata: AnnData) -> Tuple[AnnData, AnnData]:
    """Split ``adata`` by ``obs["batch"]`` into a reference part and a spatial part.

    Observations whose batch equals ``"0"`` form the reference, from which the
    ``"spatial"`` entry of ``obsm`` is removed; all remaining observations form
    the spatial part. Both parts are copies of the input.

    Returns
    -------
    The tuple ``(reference, spatial)``.
    """
    batch = adata.obs.batch
    reference = adata[batch == "0"].copy()
    spatial = adata[batch != "0"].copy()
    # the reference data must not carry spatial coordinates
    reference.obsm.pop("spatial")
    return reference, spatial
def _make_grid(grid_size: int) -> ArrayLike:
    """Return the points of a regular ``grid_size`` x ``grid_size`` grid on [0, 10]^2.

    The result has shape ``(grid_size ** 2, 2)``; the first column (x) varies
    fastest, the second column (y) slowest.
    """
    lower, upper = 0, 10
    # both axes use the same evenly spaced coordinates
    ticks = np.linspace(lower, upper, num=grid_size)
    xs, ys = np.meshgrid(ticks, ticks)
    stacked = np.vstack((xs.ravel(), ys.ravel()))
    return stacked.T
def _assert_marginals_set(adata_time, problem, key, marginal_keys):
    """Assert that the marginals of ``problem[key]`` match the expected values.

    ``key`` is a pair of values of ``adata_time.obs["time"]`` selecting the
    source and target observations. If ``marginal_keys[0]`` is ``None`` the
    marginals ``a`` and ``b`` must be uniform (``1 / number of observations``);
    otherwise they must equal the ``obs`` columns named by ``marginal_keys[0]``
    (source) and ``marginal_keys[1]`` (target).
    """
    time = adata_time.obs["time"]
    source = adata_time[key[0] == time]
    target = adata_time[key[1] == time]

    if marginal_keys[0] is None:
        # no marginal keys given: expect uniform marginals
        assert np.allclose(problem[key].a, 1.0 / source.shape[0])
        assert np.allclose(problem[key].b, 1.0 / target.shape[0])
        return

    # marginal keys given: expect the values stored in ``obs``
    expected_a = source.obs[marginal_keys[0]].values
    expected_b = target.obs[marginal_keys[1]].values
    assert np.allclose(problem[key].a, expected_a)
    assert np.allclose(problem[key].b, expected_b)
class Problem(CompoundProblem[Any, OTProblem]):
    """Minimal ``CompoundProblem`` subclass parametrized with ``OTProblem``.

    It reports ``OTProblem`` as its base problem type and an empty tuple of
    valid policies.
    """

    @property
    def _valid_policies(self) -> Tuple[str, ...]:
        """Return an empty tuple of policy names."""
        return tuple()

    @property
    def _base_problem_type(self) -> Type[B]:
        """Return the ``OTProblem`` class."""
        return OTProblem