|
a |
|
b/tests/costs/test_cost.py |
|
|
1 |
import pytest |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import pandas as pd |
|
|
5 |
|
|
|
6 |
import anndata as ad |
|
|
7 |
|
|
|
8 |
from moscot.costs._costs import BarcodeDistance, _scaled_hamming_dist |
|
|
9 |
from moscot.costs._utils import get_cost |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
class TestBarcodeDistance:
    """Tests for the ``barcode_distance`` cost and the ``_scaled_hamming_dist`` helper."""

    # Seeded generator used by the tests below to draw random 3x3 matrices.
    RNG = np.random.RandomState(0)

    @staticmethod
    def test_barcode_distance_init():
        """Check which ``get_cost`` argument combinations fail and that a valid one succeeds."""
        rng = TestBarcodeDistance.RNG
        # Draw the data matrix first and the barcodes second, as two separate draws.
        counts = rng.rand(3, 3)
        random_barcodes = rng.rand(3, 3)
        adata = ad.AnnData(counts, obsm={"barcodes": random_barcodes})
        shared = {"backend": "moscot", "adata": adata}

        # no adata at all -> TypeError expected
        with pytest.raises(TypeError):
            get_cost("barcode_distance", backend="moscot")

        # key that is not present -> KeyError expected
        with pytest.raises(KeyError):
            get_cost("barcode_distance", key="invalid_key", attr="obsm", **shared)

        # attribute that does not exist -> AttributeError expected
        with pytest.raises(AttributeError):
            get_cost("barcode_distance", key="barcodes", attr="invalid_attr", **shared)

        # a fully valid combination must yield an object
        cost_fn: BarcodeDistance = get_cost("barcode_distance", key="barcodes", attr="obsm", **shared)
        assert cost_fn is not None

    @staticmethod
    def test_scaled_hamming_dist_with_sample_inputs():
        """Compare ``_scaled_hamming_dist`` on a fixed pair of vectors with the value 2/3."""
        first = np.array([1, -1, 0, 1])
        second = np.array([0, 1, 1, 1])

        # NOTE(review): presumably -1 marks a missing entry, leaving three comparable
        # positions of which two differ -- confirm against `_scaled_hamming_dist`.
        expected_distance = 2.0 / 3

        np.testing.assert_almost_equal(_scaled_hamming_dist(first, second), expected_distance, decimal=4)

    @staticmethod
    def test_scaled_hamming_dist_if_nan():
        """Expect a ``ValueError`` when the two vectors have no shared indices."""
        # At every position at least one of the two vectors holds -1.
        first = np.array([-1, -1, 0, 1])
        second = np.array([0, 1, -1, -1])

        with pytest.raises(ValueError, match="No shared indices."):
            _scaled_hamming_dist(first, second)

    @staticmethod
    def test_barcode_distance_with_sample_input():
        """Evaluate the cost on three hand-written barcodes and compare with the known matrix."""
        # Any two distinct rows differ in exactly two of their three positions.
        sample_barcodes = np.array([[1, 0, 1], [1, 1, 0], [0, 1, 1]])

        # Dummy AnnData whose data matrix is random; only the barcodes matter here.
        adata = ad.AnnData(TestBarcodeDistance.RNG.rand(3, 3))
        adata.obsm["barcodes"] = sample_barcodes

        cost_fn: BarcodeDistance = get_cost(
            "barcode_distance",
            backend="moscot",
            adata=adata,
            key="barcodes",
            attr="obsm",
        )

        # Zero on the diagonal, 2/3 everywhere else.
        off_diagonal = np.ones((3, 3)) - np.eye(3)
        expected_distances = off_diagonal * 2.0 / 3.0

        np.testing.assert_almost_equal(cost_fn(), expected_distances, decimal=4)
|
|
79 |
|
|
|
80 |
|
|
|
81 |
class TestLeafDistance:
    """Tests for the ``leaf_distance`` cost obtained through ``get_cost``."""

    @staticmethod
    def create_dummy_adata_leaf():
        """Build a 10x10 AnnData with a ``day`` column and a star-shaped tree in ``uns["tree"][0]``.

        The tree has a ``"root"`` node with directed edges to the leaves
        ``"0"``, ``"1"`` and ``"2"``.
        """
        import networkx as nx

        day_labels = [0] * 3 + [1] * 3 + [2] * 4
        adata = ad.AnnData(
            X=np.ones((10, 10)),
            obs=pd.DataFrame(data={"day": day_labels}),
        )

        leaves = [str(i) for i in range(3)]
        tree = nx.DiGraph()
        tree.add_nodes_from(leaves + ["root"])
        tree.add_edges_from([("root", leaf) for leaf in leaves])

        adata.uns["tree"] = {0: tree}
        return adata

    @staticmethod
    def test_leaf_distance_init():
        """Walk through the failure modes of ``leaf_distance`` and one valid evaluation."""
        adata = TestLeafDistance.create_dummy_adata_leaf()
        common = {"backend": "moscot", "attr": "uns"}

        # no adata at all -> TypeError expected
        with pytest.raises(TypeError):
            get_cost("leaf_distance", backend="moscot")

        # key that is not present in `uns`
        with pytest.raises(KeyError, match="Unable to find tree in"):
            get_cost("leaf_distance", adata=adata, key="invalid_key", dist_key=0, **common)

        # `dist_key` left out entirely
        with pytest.raises(KeyError, match="Unable to find tree in"):
            get_cost("leaf_distance", adata=adata, key="tree", **common)

        # full adata has 10 observations but the tree only three leaves; the error
        # is raised when the cost is evaluated, not when it is created
        with pytest.raises(ValueError, match="Leaves do not match"):
            leaf_cost = get_cost("leaf_distance", adata=adata, key="tree", dist_key=0, **common)
            leaf_cost()

        # restricting to day 0 leaves three observations, matching the tree
        day0_mask = adata.obs.day == 0
        adata0 = adata[day0_mask]
        cost_fn = get_cost("leaf_distance", adata=adata0, key="tree", dist_key=0, **common)
        expected = np.array([[0, 2, 2], [2, 0, 2], [2, 2, 0]])
        np.testing.assert_equal(cost_fn(), expected)

        # an entry that is not a networkx graph must be rejected
        adata0.uns["tree"] = {0: 1}
        with pytest.raises(TypeError, match="networkx.Graph"):
            get_cost("leaf_distance", adata=adata0, key="tree", dist_key=0, **common)