Diff of /tests/costs/test_cost.py [000000] .. [6ff4a8]

Switch to unified view

a b/tests/costs/test_cost.py
1
import pytest
2
3
import numpy as np
4
import pandas as pd
5
6
import anndata as ad
7
8
from moscot.costs._costs import BarcodeDistance, _scaled_hamming_dist
9
from moscot.costs._utils import get_cost
10
11
12
class TestBarcodeDistance:
13
    RNG = np.random.RandomState(0)
14
15
    @staticmethod
16
    def test_barcode_distance_init():
17
        adata = ad.AnnData(TestBarcodeDistance.RNG.rand(3, 3), obsm={"barcodes": TestBarcodeDistance.RNG.rand(3, 3)})
18
        # initialization failure when no adata is provided
19
        with pytest.raises(TypeError):
20
            get_cost("barcode_distance", backend="moscot")
21
        # initialization failure when invalid key is provided
22
        with pytest.raises(KeyError):
23
            get_cost("barcode_distance", backend="moscot", adata=adata, key="invalid_key", attr="obsm")
24
        # initialization failure when invalid attr
25
        with pytest.raises(AttributeError):
26
            get_cost("barcode_distance", backend="moscot", adata=adata, key="barcodes", attr="invalid_attr")
27
        # check if not None
28
        cost_fn: BarcodeDistance = get_cost(
29
            "barcode_distance", backend="moscot", adata=adata, key="barcodes", attr="obsm"
30
        )
31
        assert cost_fn is not None
32
33
    @staticmethod
34
    def test_scaled_hamming_dist_with_sample_inputs():
35
        # Sample input arrays
36
        x = np.array([1, -1, 0, 1])
37
        y = np.array([0, 1, 1, 1])
38
39
        # Expected output
40
        expected_distance = 2.0 / 3
41
42
        # Compute the scaled Hamming distance
43
        computed_distance = _scaled_hamming_dist(x, y)
44
45
        # Check if the computed distance matches the expected distance
46
        np.testing.assert_almost_equal(computed_distance, expected_distance, decimal=4)
47
48
    @staticmethod
49
    def test_scaled_hamming_dist_if_nan():
50
        # Sample input arrays with no shared indices
51
        x = np.array([-1, -1, 0, 1])
52
        y = np.array([0, 1, -1, -1])
53
54
        with pytest.raises(ValueError, match="No shared indices."):
55
            _scaled_hamming_dist(x, y)
56
57
    @staticmethod
58
    def test_barcode_distance_with_sample_input():
59
        # Example barcodes
60
        barcodes = np.array([[1, 0, 1], [1, 1, 0], [0, 1, 1]])
61
62
        # Create a dummy AnnData object with the example barcodes
63
        adata = ad.AnnData(TestBarcodeDistance.RNG.rand(3, 3))
64
        adata.obsm["barcodes"] = barcodes
65
66
        # Initialize BarcodeDistance
67
        cost_fn: BarcodeDistance = get_cost(
68
            "barcode_distance", backend="moscot", adata=adata, key="barcodes", attr="obsm"
69
        )
70
71
        # Compute distances
72
        computed_distances = cost_fn()
73
74
        # Expected distance matrix
75
        expected_distances = np.array([[0.0, 2.0, 2.0], [2.0, 0.0, 2.0], [2.0, 2.0, 0.0]]) / 3.0
76
77
        # Check if the computed distances match the expected distances
78
        np.testing.assert_almost_equal(computed_distances, expected_distances, decimal=4)
79
80
81
class TestLeafDistance:
82
    @staticmethod
83
    def create_dummy_adata_leaf():
84
        import networkx as nx
85
86
        adata: ad.AnnData = ad.AnnData(
87
            X=np.ones((10, 10)),
88
            obs=pd.DataFrame(data={"day": [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]}),
89
        )
90
        g: nx.DiGraph = nx.DiGraph()
91
        g.add_nodes_from([str(i) for i in range(3)] + ["root"])
92
        g.add_edges_from([("root", str(i)) for i in range(3)])
93
        adata.uns["tree"] = {0: g}
94
        return adata
95
96
    @staticmethod
97
    def test_leaf_distance_init():
98
        adata = TestLeafDistance.create_dummy_adata_leaf()
99
        # initialization failure when no adata is provided
100
        with pytest.raises(TypeError):
101
            get_cost("leaf_distance", backend="moscot")
102
        # initialization failure when invalid key is provided
103
        with pytest.raises(KeyError, match="Unable to find tree in"):
104
            get_cost("leaf_distance", backend="moscot", adata=adata, key="invalid_key", attr="uns", dist_key=0)
105
        # initialization failure when invalid dist_key is provided
106
        with pytest.raises(KeyError, match="Unable to find tree in"):
107
            get_cost("leaf_distance", backend="moscot", adata=adata, key="tree", attr="uns")
108
        # when leaves do not match adata.obs_names
109
        with pytest.raises(ValueError, match="Leaves do not match"):
110
            get_cost("leaf_distance", backend="moscot", adata=adata, key="tree", attr="uns", dist_key=0)()
111
        # now giving valid input
112
        adata0 = adata[adata.obs.day == 0]
113
        cost_fn = get_cost("leaf_distance", backend="moscot", adata=adata0, key="tree", attr="uns", dist_key=0)
114
        np.testing.assert_equal(cost_fn(), np.array([[0, 2, 2], [2, 0, 2], [2, 2, 0]]))
115
        # when tree is not a networkx.Graph
116
        adata0.uns["tree"] = {0: 1}
117
        with pytest.raises(TypeError, match="networkx.Graph"):
118
            get_cost("leaf_distance", backend="moscot", adata=adata0, key="tree", attr="uns", dist_key=0)