Switch to unified view

a b/tests/datasets/test_dataset.py
1
from typing import Mapping, Optional
2
3
import pytest
4
5
import networkx as nx
6
import numpy as np
7
8
from moscot.datasets import simulate_data
9
10
11
class TestSimulateData:
12
    @pytest.mark.fast
13
    @pytest.mark.parametrize("n_distributions", [2, 4])
14
    @pytest.mark.parametrize("key", ["batch", "day"])
15
    def test_n_distributions(self, n_distributions: int, key: str):
16
        adata = simulate_data(n_distributions=n_distributions, key=key)
17
        assert key in adata.obs.columns
18
        assert adata.obs[key].nunique() == n_distributions
19
20
    @pytest.mark.fast
21
    @pytest.mark.parametrize("obs_to_add", [{"celltype": 2}, {"celltype": 5, "cluster": 4}])
22
    def test_obs_to_add(self, obs_to_add: Mapping[str, int]):
23
        adata = simulate_data(obs_to_add=obs_to_add)
24
25
        for colname, k in obs_to_add.items():
26
            assert colname in adata.obs.columns
27
            assert adata.obs[colname].nunique() == k
28
29
    @pytest.mark.fast
30
    @pytest.mark.parametrize("spatial_dim", [None, 2, 3])
31
    def test_quad_term_spatial(self, spatial_dim: Optional[int]):
32
        kwargs = {}
33
        if spatial_dim is not None:
34
            kwargs["spatial_dim"] = spatial_dim
35
        adata = simulate_data(quad_term="spatial", **kwargs)
36
37
        assert isinstance(adata.obsm["spatial"], np.ndarray)
38
        if spatial_dim is None:
39
            assert adata.obsm["spatial"].shape[1] == 2
40
        else:
41
            assert adata.obsm["spatial"].shape[1] == spatial_dim
42
43
    @pytest.mark.fast
44
    @pytest.mark.parametrize("n_intBCs", [None, 4, 7])
45
    @pytest.mark.parametrize("barcode_dim", [None, 5, 8])
46
    def test_quad_term_barcode(self, n_intBCs: Optional[int], barcode_dim: Optional[int]):
47
        kwargs = {}
48
        if n_intBCs is not None:
49
            kwargs["n_intBCs"] = n_intBCs
50
        if barcode_dim is not None:
51
            kwargs["barcode_dim"] = barcode_dim
52
53
        adata = simulate_data(quad_term="barcode", **kwargs)
54
55
        assert isinstance(adata.obsm["barcode"], np.ndarray)
56
        if barcode_dim is None:
57
            assert adata.obsm["barcode"].shape[1] == 10
58
        else:
59
            assert adata.obsm["barcode"].shape[1] == barcode_dim
60
61
        if n_intBCs is None:
62
            assert len(np.unique(adata.obsm["barcode"])) <= 20
63
        else:
64
            assert len(np.unique(adata.obsm["barcode"])) <= n_intBCs
65
66
    @pytest.mark.fast
67
    @pytest.mark.parametrize("n_initial_nodes", [None, 4, 7])
68
    @pytest.mark.parametrize("n_distributions", [3, 6])
69
    def test_quad_term_tree(self, n_initial_nodes: Optional[int], n_distributions: int):
70
        kwargs = {}
71
        if n_initial_nodes is not None:
72
            kwargs["n_initial_nodes"] = n_initial_nodes
73
        adata = simulate_data(quad_term="tree", key="day", n_distributions=n_distributions, **kwargs)
74
75
        assert isinstance(adata.uns["trees"], dict)
76
        assert len(adata.uns["trees"]) == n_distributions
77
78
        for i in range(len(adata.uns["trees"])):
79
            assert isinstance(adata.uns["trees"][i], nx.DiGraph)