|
a |
|
b/tests/datasets/test_dataset.py |
|
|
1 |
from typing import Mapping, Optional |
|
|
2 |
|
|
|
3 |
import pytest |
|
|
4 |
|
|
|
5 |
import networkx as nx |
|
|
6 |
import numpy as np |
|
|
7 |
|
|
|
8 |
from moscot.datasets import simulate_data |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
class TestSimulateData: |
|
|
12 |
@pytest.mark.fast |
|
|
13 |
@pytest.mark.parametrize("n_distributions", [2, 4]) |
|
|
14 |
@pytest.mark.parametrize("key", ["batch", "day"]) |
|
|
15 |
def test_n_distributions(self, n_distributions: int, key: str): |
|
|
16 |
adata = simulate_data(n_distributions=n_distributions, key=key) |
|
|
17 |
assert key in adata.obs.columns |
|
|
18 |
assert adata.obs[key].nunique() == n_distributions |
|
|
19 |
|
|
|
20 |
@pytest.mark.fast |
|
|
21 |
@pytest.mark.parametrize("obs_to_add", [{"celltype": 2}, {"celltype": 5, "cluster": 4}]) |
|
|
22 |
def test_obs_to_add(self, obs_to_add: Mapping[str, int]): |
|
|
23 |
adata = simulate_data(obs_to_add=obs_to_add) |
|
|
24 |
|
|
|
25 |
for colname, k in obs_to_add.items(): |
|
|
26 |
assert colname in adata.obs.columns |
|
|
27 |
assert adata.obs[colname].nunique() == k |
|
|
28 |
|
|
|
29 |
@pytest.mark.fast |
|
|
30 |
@pytest.mark.parametrize("spatial_dim", [None, 2, 3]) |
|
|
31 |
def test_quad_term_spatial(self, spatial_dim: Optional[int]): |
|
|
32 |
kwargs = {} |
|
|
33 |
if spatial_dim is not None: |
|
|
34 |
kwargs["spatial_dim"] = spatial_dim |
|
|
35 |
adata = simulate_data(quad_term="spatial", **kwargs) |
|
|
36 |
|
|
|
37 |
assert isinstance(adata.obsm["spatial"], np.ndarray) |
|
|
38 |
if spatial_dim is None: |
|
|
39 |
assert adata.obsm["spatial"].shape[1] == 2 |
|
|
40 |
else: |
|
|
41 |
assert adata.obsm["spatial"].shape[1] == spatial_dim |
|
|
42 |
|
|
|
43 |
@pytest.mark.fast |
|
|
44 |
@pytest.mark.parametrize("n_intBCs", [None, 4, 7]) |
|
|
45 |
@pytest.mark.parametrize("barcode_dim", [None, 5, 8]) |
|
|
46 |
def test_quad_term_barcode(self, n_intBCs: Optional[int], barcode_dim: Optional[int]): |
|
|
47 |
kwargs = {} |
|
|
48 |
if n_intBCs is not None: |
|
|
49 |
kwargs["n_intBCs"] = n_intBCs |
|
|
50 |
if barcode_dim is not None: |
|
|
51 |
kwargs["barcode_dim"] = barcode_dim |
|
|
52 |
|
|
|
53 |
adata = simulate_data(quad_term="barcode", **kwargs) |
|
|
54 |
|
|
|
55 |
assert isinstance(adata.obsm["barcode"], np.ndarray) |
|
|
56 |
if barcode_dim is None: |
|
|
57 |
assert adata.obsm["barcode"].shape[1] == 10 |
|
|
58 |
else: |
|
|
59 |
assert adata.obsm["barcode"].shape[1] == barcode_dim |
|
|
60 |
|
|
|
61 |
if n_intBCs is None: |
|
|
62 |
assert len(np.unique(adata.obsm["barcode"])) <= 20 |
|
|
63 |
else: |
|
|
64 |
assert len(np.unique(adata.obsm["barcode"])) <= n_intBCs |
|
|
65 |
|
|
|
66 |
@pytest.mark.fast |
|
|
67 |
@pytest.mark.parametrize("n_initial_nodes", [None, 4, 7]) |
|
|
68 |
@pytest.mark.parametrize("n_distributions", [3, 6]) |
|
|
69 |
def test_quad_term_tree(self, n_initial_nodes: Optional[int], n_distributions: int): |
|
|
70 |
kwargs = {} |
|
|
71 |
if n_initial_nodes is not None: |
|
|
72 |
kwargs["n_initial_nodes"] = n_initial_nodes |
|
|
73 |
adata = simulate_data(quad_term="tree", key="day", n_distributions=n_distributions, **kwargs) |
|
|
74 |
|
|
|
75 |
assert isinstance(adata.uns["trees"], dict) |
|
|
76 |
assert len(adata.uns["trees"]) == n_distributions |
|
|
77 |
|
|
|
78 |
for i in range(len(adata.uns["trees"])): |
|
|
79 |
assert isinstance(adata.uns["trees"][i], nx.DiGraph) |