Diff of /tests/costs/test_utils.py [000000] .. [6ff4a8]

Switch to unified view

a b/tests/costs/test_utils.py
1
import pytest
2
3
from moscot.costs._utils import (
4
    _get_available_backends_and_costs,
5
    get_available_costs,
6
    get_cost,
7
)
8
9
10
class TestCostUtils:
11
    ALL_BACKENDS_N_COSTS = {
12
        "moscot": ("barcode_distance", "leaf_distance"),
13
        "ott": (
14
            "euclidean",
15
            "sq_euclidean",
16
            "cosine",
17
            "pnorm_p",
18
            "sq_pnorm",
19
        ),
20
    }
21
22
    @staticmethod
23
    def test_get_available_backends_n_costs():
24
        assert dict(_get_available_backends_and_costs()) == {
25
            k: list(v) for k, v in _get_available_backends_and_costs().items()
26
        }
27
28
    @staticmethod
29
    def test_get_available_costs():
30
        assert get_available_costs() == TestCostUtils.ALL_BACKENDS_N_COSTS
31
        assert get_available_costs("moscot") == {"moscot": (TestCostUtils.ALL_BACKENDS_N_COSTS["moscot"])}
32
        assert get_available_costs("ott") == {"ott": TestCostUtils.ALL_BACKENDS_N_COSTS["ott"]}
33
        with pytest.raises(KeyError):
34
            get_available_costs("foo")
35
36
    @staticmethod
37
    def test_get_cost_fails():
38
        invalid_cost = "foo"
39
        invalid_backend = "bar"
40
        with pytest.raises(
41
            ValueError, match=f"Cost `{invalid_cost!r}` is not available for backend `{invalid_backend!r}`."
42
        ):
43
            get_cost(invalid_cost, backend=invalid_backend)
44
        for backend in TestCostUtils.ALL_BACKENDS_N_COSTS:
45
            with pytest.raises(
46
                ValueError, match=f"Cost `{invalid_cost!r}` is not available for backend `{backend!r}`."
47
            ):
48
                get_cost(invalid_cost, backend=backend)