|
a |
|
b/tests/costs/test_utils.py |
|
|
1 |
import pytest |
|
|
2 |
|
|
|
3 |
from moscot.costs._utils import ( |
|
|
4 |
_get_available_backends_and_costs, |
|
|
5 |
get_available_costs, |
|
|
6 |
get_cost, |
|
|
7 |
) |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
class TestCostUtils: |
|
|
11 |
ALL_BACKENDS_N_COSTS = { |
|
|
12 |
"moscot": ("barcode_distance", "leaf_distance"), |
|
|
13 |
"ott": ( |
|
|
14 |
"euclidean", |
|
|
15 |
"sq_euclidean", |
|
|
16 |
"cosine", |
|
|
17 |
"pnorm_p", |
|
|
18 |
"sq_pnorm", |
|
|
19 |
), |
|
|
20 |
} |
|
|
21 |
|
|
|
22 |
@staticmethod |
|
|
23 |
def test_get_available_backends_n_costs(): |
|
|
24 |
assert dict(_get_available_backends_and_costs()) == { |
|
|
25 |
k: list(v) for k, v in _get_available_backends_and_costs().items() |
|
|
26 |
} |
|
|
27 |
|
|
|
28 |
@staticmethod |
|
|
29 |
def test_get_available_costs(): |
|
|
30 |
assert get_available_costs() == TestCostUtils.ALL_BACKENDS_N_COSTS |
|
|
31 |
assert get_available_costs("moscot") == {"moscot": (TestCostUtils.ALL_BACKENDS_N_COSTS["moscot"])} |
|
|
32 |
assert get_available_costs("ott") == {"ott": TestCostUtils.ALL_BACKENDS_N_COSTS["ott"]} |
|
|
33 |
with pytest.raises(KeyError): |
|
|
34 |
get_available_costs("foo") |
|
|
35 |
|
|
|
36 |
@staticmethod |
|
|
37 |
def test_get_cost_fails(): |
|
|
38 |
invalid_cost = "foo" |
|
|
39 |
invalid_backend = "bar" |
|
|
40 |
with pytest.raises( |
|
|
41 |
ValueError, match=f"Cost `{invalid_cost!r}` is not available for backend `{invalid_backend!r}`." |
|
|
42 |
): |
|
|
43 |
get_cost(invalid_cost, backend=invalid_backend) |
|
|
44 |
for backend in TestCostUtils.ALL_BACKENDS_N_COSTS: |
|
|
45 |
with pytest.raises( |
|
|
46 |
ValueError, match=f"Cost `{invalid_cost!r}` is not available for backend `{backend!r}`." |
|
|
47 |
): |
|
|
48 |
get_cost(invalid_cost, backend=backend) |