a b/tests/utils/test_collections.py
1
from edsnlp.utils.collections import (
2
    batch_compress_dict,
3
    batchify,
4
    decompress_dict,
5
    dl_to_ld,
6
    flatten,
7
    flatten_once,
8
    get_deep_attr,
9
    ld_to_dl,
10
    multi_tee,
11
    set_deep_attr,
12
)
13
14
15
def test_multi_tee():
    """Check that each iteration over a ``multi_tee`` restarts from the beginning.

    ``zip(tee, range(5))`` yields ``(tee_item, range_item)`` pairs, so the value
    coming from the tee is the *first* element of each pair; unpacking it the
    other way round would only ever read the ``range(5)`` counter and the test
    would pass regardless of what the tee yields.
    """
    gen = (i for i in range(10))
    tee = multi_tee(gen)
    items1 = [value for value, _ in zip(tee, range(5))]
    items2 = [value for value, _ in zip(tee, range(5))]
    assert items1 == items2 == [0, 1, 2, 3, 4]

    # zip() pulls a sixth item from the tee before noticing that range(5) is
    # exhausted, so gen has already produced 0..5. Continuing from 5 would be
    # nicer, but at least the underlying generator is not exhausted.
    assert next(gen) == 6

    # A list can already be iterated several times, so it is returned as-is.
    assert multi_tee(items1) is items1
27
28
29
def test_flatten():
    """``flatten_once`` removes one nesting level; ``flatten`` removes all of them."""
    nested = [1, [2, 3], [[4, 5], 6], [[[7, 8], 9], 10]]

    one_level = list(flatten_once(nested))
    assert one_level == [1, 2, 3, [4, 5], 6, [[7, 8], 9], 10]

    fully_flat = list(flatten(nested))
    assert fully_flat == list(range(1, 11))
33
34
35
def test_dict_of_lists():
    """``ld_to_dl`` and ``dl_to_ld`` convert between row and column layouts."""
    rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

    # list-of-dicts -> dict-of-lists
    assert ld_to_dl(rows) == {"a": [1, 3], "b": [2, 4]}

    # dict-of-lists -> list-of-dicts (dl_to_ld may be lazy, hence list())
    columns = {"a": [1, 3], "b": [2, 4]}
    assert list(dl_to_ld(columns)) == rows
39
40
41
def test_dict_compression():
    """Shared values are merged under a ``|``-joined key and re-shared on decompression.

    Nested keys are flattened with ``/`` (e.g. ``"b/d"``). When several paths of
    a sample point to the same object, they are expected to collapse into a
    single ``|``-joined key. Decompressing a batch yields a nested dict of
    per-sample lists in which those paths again share one object.
    """
    first_shared = [1, 2, 3]
    second_shared = [1, 2, 3]
    # Within each sample, "a" and "b"/"c" reference the very same list object.
    samples = [
        {"a": first_shared, "b": {"c": first_shared, "d": 4}},
        {"a": second_shared, "b": {"c": second_shared, "d": 4}},
    ]

    expected_compressed = [
        {"a|b/c": [1, 2, 3], "b/d": 4},
        {"a|b/c": [1, 2, 3], "b/d": 4},
    ]
    assert list(batch_compress_dict(samples)) == expected_compressed

    decompressed = decompress_dict(
        [
            {"a|b/c": [1, 2, 3], "b/d": 4},
            {"a|b/c": [1, 2, 3], "b/d": 4},
        ]
    )
    expected_decompressed = {
        "a": [[1, 2, 3], [1, 2, 3]],
        "b": {"c": [[1, 2, 3], [1, 2, 3]], "d": [4, 4]},
    }
    assert decompressed == expected_decompressed
    # The "|" key must expand back into one object shared by both paths.
    assert decompressed["a"] is decompressed["b"]["c"]
64
65
66
def test_batchify():
    """Check that ``batchify`` splits items into consecutive lists of ``batch_size``.

    The leftover items form a final, shorter batch.
    """
    items = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    # list() materialises the batches once; no second conversion is needed.
    batches = list(batchify(items, 3))
    assert len(batches) == 4
    assert batches == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
72
73
74
def test_deep_path():
    """Check dotted-path access and assignment.

    The paths traverse dict keys, sequence indices and object attributes.
    """

    class Custom:
        """Minimal object exposing a ``values`` attribute to traverse."""

        def __init__(self, values):
            self.values = values

    item = {"a": {"b": (0, 1), "other": Custom((1, 2))}}
    assert get_deep_attr(item, "a.b.0") == 0
    assert get_deep_attr(item, "a.other.values.0") == 1

    # Tuples are immutable, so assigning an index inside one can only succeed
    # by replacing the tuple held by its parent (dict key or attribute).
    set_deep_attr(item, "a.b.1", 2)
    set_deep_attr(item, "a.other.values.0", 1000)
    assert item["a"]["b"] == (0, 2)
    assert item["a"]["other"].values == (1000, 2)