--- a +++ b/tests/problems/time/conftest.py @@ -0,0 +1,42 @@ +import pytest + +import numpy as np +import scipy + +from anndata import AnnData + +from moscot.datasets import _get_random_trees + + +@pytest.fixture +def adata_time_trees(adata_time: AnnData) -> AnnData: + trees = _get_random_trees( + n_leaves=96, n_trees=3, leaf_names=[list(adata_time[adata_time.obs.time == i].obs.index) for i in range(3)] + ) + adata_time.uns["trees"] = {0: trees[0], 1: trees[1], 2: trees[2]} + return adata_time + + +@pytest.fixture +def adata_time_custom_cost_xy(adata_time: AnnData) -> AnnData: + rng = np.random.RandomState(42) + cost_m1 = np.abs(rng.randn(96, 96)) + cost_m2 = np.abs(rng.randn(96, 96)) + cost_m3 = np.abs(rng.randn(96, 96)) + adata_time.obsp["cost_matrices"] = scipy.sparse.csr_matrix(scipy.linalg.block_diag(cost_m1, cost_m2, cost_m3)) + return adata_time + + +@pytest.fixture +def adata_time_barcodes(adata_time: AnnData) -> AnnData: + rng = np.random.RandomState(42) + adata_time.obsm["barcodes"] = rng.randn(len(adata_time), 30) + return adata_time + + +@pytest.fixture +def adata_time_marginal_estimations(adata_time: AnnData) -> AnnData: + rng = np.random.RandomState(42) + adata_time.obs["proliferation"] = rng.randn(len(adata_time)) + adata_time.obs["apoptosis"] = rng.randn(len(adata_time)) + return adata_time