[e7f7dd]: / tests / costs / test_utils.py

Download this file

49 lines (42 with data), 1.6 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import re

import pytest

from moscot.costs._utils import (
    _get_available_backends_and_costs,
    get_available_costs,
    get_cost,
)
class TestCostUtils:
    """Tests for the cost lookup helpers imported from ``moscot.costs._utils``."""

    # Expected backends mapped to the cost names available for each of them.
    # Values are tuples and their order matters: the assertions compare with ``==``.
    ALL_BACKENDS_N_COSTS = {
        "moscot": ("barcode_distance", "leaf_distance"),
        "ott": (
            "euclidean",
            "sq_euclidean",
            "cosine",
            "pnorm_p",
            "sq_pnorm",
        ),
    }

    @staticmethod
    def test_get_available_backends_n_costs():
        """``_get_available_backends_and_costs`` groups the available costs by backend."""
        # Check against the independent expectation (not the helper's own output) and
        # normalise values to tuples so the container type used by the helper is irrelevant.
        actual = {backend: tuple(costs) for backend, costs in _get_available_backends_and_costs().items()}
        assert actual == TestCostUtils.ALL_BACKENDS_N_COSTS

    @staticmethod
    def test_get_available_costs():
        """``get_available_costs`` returns all backends, one backend, or raises ``KeyError``."""
        expected = TestCostUtils.ALL_BACKENDS_N_COSTS
        assert get_available_costs() == expected
        for backend, costs in expected.items():
            assert get_available_costs(backend) == {backend: costs}
        with pytest.raises(KeyError):
            get_available_costs("foo")

    @staticmethod
    def test_get_cost_fails():
        """``get_cost`` raises ``ValueError`` for a cost that is unavailable for the backend."""
        invalid_cost = "foo"
        invalid_backend = "bar"
        # Both an unknown backend and every known backend must reject the unknown cost.
        for backend in (invalid_backend, *TestCostUtils.ALL_BACKENDS_N_COSTS):
            expected_msg = f"Cost `{invalid_cost!r}` is not available for backend `{backend!r}`."
            # ``match`` is applied with ``re.search``; escape so the message is matched literally.
            with pytest.raises(ValueError, match=re.escape(expected_msg)):
                get_cost(invalid_cost, backend=backend)