# tests/utils/test_batching.py
1
import pytest
2
3
from edsnlp.utils.batching import (
4
    DATASET_END_SENTINEL,
5
    BatchSizeArg,
6
    FragmentEndSentinel,
7
    StreamSentinel,
8
    batchify,
9
    batchify_by_dataset,
10
    batchify_by_fragment,
11
    batchify_by_length_sum,
12
    batchify_by_padded,
13
    stat_batchify,
14
)
15
16
17
class MockStreamSentinel(StreamSentinel):
    """Bare `StreamSentinel` subclass used as a stand-in sentinel in these tests."""
19
20
21
# Tests for BatchSizeArg
22
def test_batch_size_arg_validate():
    """Check what `BatchSizeArg.validate` returns or rejects for sample inputs."""
    # (raw input, expected (size, unit) pair)
    accepted = [
        ("10 samples", (10, "samples")),
        ("20 words", (20, "words")),
        (15, (15, "docs")),
        ("docs", (None, "docs")),
        ("tokens", (None, "tokens")),
        ("25", (25, "docs")),
    ]
    for raw, parsed in accepted:
        assert BatchSizeArg.validate(raw) == parsed

    # Each of these inputs is expected to be refused with an exception.
    rejected = ["invalid input", "10 invalid input", "invalid input 10"]
    for raw in rejected:
        with pytest.raises(Exception):
            BatchSizeArg.validate(raw)
38
39
40
# Tests for batchify function
41
def test_batchify_simple():
    """Five items batched by two: two full batches plus a one-item remainder."""
    expected = [[1, 2], [3, 4], [5]]
    result = list(batchify([1, 2, 3, 4, 5], batch_size=2))
    assert result == expected
45
46
47
def test_batchify_drop_last():
    """With `drop_last=True`, the trailing one-item batch should not be emitted."""
    items = [1, 2, 3, 4, 5]
    expected = [[1, 2], [3, 4]]
    result = list(batchify(items, batch_size=2, drop_last=True))
    assert result == expected
51
52
53
def test_batchify_sentinel_drop():
    """In "drop" mode, the sentinel should not appear anywhere in the output."""
    stream = [1, 2, MockStreamSentinel(), 3, 4]
    result = list(batchify(stream, batch_size=2, sentinel_mode="drop"))
    expected = [[1, 2], [3, 4]]
    assert result == expected
57
58
59
def test_batchify_sentinel_keep():
    """In "keep" mode, the sentinel is expected inside the first batch."""
    marker = MockStreamSentinel()
    stream = [1, 2, marker, 3, 4]
    result = list(batchify(stream, batch_size=2, sentinel_mode="keep"))
    expected = [[1, 2, marker], [3, 4]]
    assert result == expected
64
65
66
def test_batchify_sentinel_split():
    """In "split" mode, the sentinel is expected on its own between two batches."""
    marker = MockStreamSentinel()
    stream = [1, 2, marker, 3, 4]
    result = list(batchify(stream, batch_size=2, sentinel_mode="split"))
    expected = [[1, 2], marker, [3, 4]]
    assert result == expected
71
72
73
# Tests for batchify_by_length_sum
74
def test_batchify_by_length_sum_simple():
    """Strings of lengths 1 to 5 batched with `batch_size=5`."""
    words = ["a", "bb", "ccc", "dddd", "eeeee"]
    expected = [["a", "bb"], ["ccc"], ["dddd"], ["eeeee"]]
    result = list(batchify_by_length_sum(words, batch_size=5))
    assert result == expected
78
79
80
def test_batchify_by_length_sum_drop_last():
    """Same input as the simple case; `drop_last=True` should omit the final batch."""
    words = ["a", "bb", "ccc", "dddd", "eeeee"]
    expected = [["a", "bb"], ["ccc"], ["dddd"]]
    result = list(batchify_by_length_sum(words, batch_size=5, drop_last=True))
    assert result == expected
84
85
86
# Tests for batchify_by_length_sum with sentinel_mode="split"
87
def test_batchify_by_length_sum_split():
    """In "split" mode, the sentinel is expected alone after the first batch."""
    marker = MockStreamSentinel()
    words = ["aa", "bb", marker, "ccc", "dddd", "eeeee"]
    result = list(
        batchify_by_length_sum(words, batch_size=7, sentinel_mode="split")
    )
    expected = [["aa", "bb"], marker, ["ccc", "dddd"], ["eeeee"]]
    assert result == expected
92
93
94
# Tests for batchify_by_length_sum with sentinel_mode="keep"
95
def test_batchify_by_length_sum_keep():
    """In "keep" mode, the sentinel is expected inside the first batch."""
    marker = MockStreamSentinel()
    words = ["aa", "bb", marker, "ccc", "dddd", "eeeee"]
    result = list(
        batchify_by_length_sum(words, batch_size=7, sentinel_mode="keep")
    )
    expected = [["aa", "bb", marker, "ccc"], ["dddd"], ["eeeee"]]
    assert result == expected
100
101
102
# Tests for batchify_by_padded
103
def test_batchify_by_padded_simple():
    """Strings of lengths 1 to 4 batched with `batch_size=6`."""
    words = ["a", "bb", "ccc", "dddd"]
    expected = [["a", "bb"], ["ccc"], ["dddd"]]
    result = list(batchify_by_padded(words, batch_size=6))
    assert result == expected
107
108
109
def test_batchify_by_padded_drop_last():
    """Same input as the simple case; `drop_last=True` should omit the final batch."""
    words = ["a", "bb", "ccc", "dddd"]
    expected = [["a", "bb"], ["ccc"]]
    result = list(batchify_by_padded(words, batch_size=6, drop_last=True))
    assert result == expected
113
114
115
def test_batchify_by_padded_sentinel_keep():
    """In "keep" mode, the sentinel is expected inside the first batch."""
    marker = MockStreamSentinel()
    words = ["a", marker, "bb", "ccc"]
    result = list(batchify_by_padded(words, batch_size=6, sentinel_mode="keep"))
    expected = [["a", marker, "bb"], ["ccc"]]
    assert result == expected
120
121
122
def test_batchify_by_padded_sentinel_split():
    """In "split" mode, the sentinel is expected alone after the first batch."""
    marker = MockStreamSentinel()
    words = ["a", marker, "bb", "ccc"]
    result = list(batchify_by_padded(words, batch_size=5, sentinel_mode="split"))
    expected = [["a"], marker, ["bb"], ["ccc"]]
    assert result == expected
127
128
129
# Tests for batchify_by_dataset
130
def test_batchify_by_dataset_simple():
    """Items should be grouped between successive `DATASET_END_SENTINEL` markers."""
    stream = ["item1", "item2", DATASET_END_SENTINEL, "item3"]
    stream += [DATASET_END_SENTINEL, "item4", "item5"]
    expected = [
        ["item1", "item2"],
        DATASET_END_SENTINEL,
        ["item3"],
        DATASET_END_SENTINEL,
        ["item4", "item5"],
    ]
    result = list(batchify_by_dataset(stream))
    assert result == expected
148
149
150
def test_batchify_by_dataset_sentinel_split():
    """In "split" mode, a non-dataset sentinel should also cut the batch."""
    marker = MockStreamSentinel()
    stream = ["item1", marker, "item2", DATASET_END_SENTINEL, "item3"]
    result = list(batchify_by_dataset(stream, sentinel_mode="split"))
    expected = [
        ["item1"],
        marker,
        ["item2"],
        DATASET_END_SENTINEL,
        ["item3"],
    ]
    assert result == expected
155
156
157
def test_batchify_by_dataset_sentinel_keep():
    """In "keep" mode, a non-dataset sentinel should remain inside its batch."""
    marker = MockStreamSentinel()
    stream = ["item1", marker, "item2", DATASET_END_SENTINEL, "item3"]
    result = list(batchify_by_dataset(stream, sentinel_mode="keep"))
    expected = [
        ["item1", marker, "item2"],
        DATASET_END_SENTINEL,
        ["item3"],
    ]
    assert result == expected
162
163
164
def test_batchify_by_dataset_sentinel_drop():
    """In "drop" mode, the non-dataset sentinel should be absent from the output."""
    marker = MockStreamSentinel()
    stream = ["item1", marker, "item2", DATASET_END_SENTINEL, "item3"]
    result = list(batchify_by_dataset(stream, sentinel_mode="drop"))
    expected = [
        ["item1", "item2"],
        DATASET_END_SENTINEL,
        ["item3"],
    ]
    assert result == expected
169
170
171
def test_batchify_by_dataset_drop_last():
    """With `drop_last=True`, items after the last dataset marker should be dropped."""
    stream = ["item1", "item2", DATASET_END_SENTINEL, "item3"]
    expected = [["item1", "item2"], DATASET_END_SENTINEL]
    result = list(batchify_by_dataset(stream, drop_last=True))
    assert result == expected
175
176
177
# Tests for batchify_by_fragment
178
def test_batchify_by_fragment_simple():
    """Items should be grouped between successive `FragmentEndSentinel` markers."""
    end_first = FragmentEndSentinel("fragment1")
    end_second = FragmentEndSentinel("fragment2")
    stream = ["item1", "item2", end_first, "item3", end_second, "item4"]
    expected = [
        ["item1", "item2"],
        end_first,
        ["item3"],
        end_second,
        ["item4"],
    ]
    result = list(batchify_by_fragment(stream))
    assert result == expected
190
191
192
def test_batchify_by_fragment_sentinel_split():
    """In "split" mode, a non-fragment sentinel should also cut the batch."""
    marker = MockStreamSentinel()
    end_marker = FragmentEndSentinel("fragment")
    stream = ["item1", marker, "item2", end_marker]
    result = list(batchify_by_fragment(stream, sentinel_mode="split"))
    expected = [["item1"], marker, ["item2"], end_marker]
    assert result == expected
198
199
200
def test_batchify_by_fragment_sentinel_keep():
    """In "keep" mode, a non-fragment sentinel should remain inside its batch."""
    marker = MockStreamSentinel()
    end_marker = FragmentEndSentinel("fragment")
    stream = ["item1", marker, "item2", end_marker]
    result = list(batchify_by_fragment(stream, sentinel_mode="keep"))
    expected = [["item1", marker, "item2"], end_marker]
    assert result == expected
206
207
208
def test_batchify_by_fragment_sentinel_drop():
    """In "drop" mode, the non-fragment sentinel should be absent from the output."""
    marker = MockStreamSentinel()
    end_marker = FragmentEndSentinel("fragment")
    stream = ["item1", marker, "item2", end_marker]
    result = list(batchify_by_fragment(stream, sentinel_mode="drop"))
    expected = [["item1", "item2"], end_marker]
    assert result == expected
214
215
216
def test_batchify_by_fragment_drop_last():
    """A stream ending on its fragment marker should lose nothing to `drop_last`."""
    end_marker = FragmentEndSentinel("fragment")
    stream = ["item1", "item2", end_marker]
    result = list(
        batchify_by_fragment(stream, sentinel_mode="split", drop_last=True)
    )
    expected = [["item1", "item2"], end_marker]
    assert result == expected
221
222
223
# Tests for stat_batchify
224
def test_stat_batchify_simple():
    """Batch dict samples on their "/stats/length" values with `batch_size=5`."""
    first = {"/stats/length": 2, "text": "aa"}
    second = {"/stats/length": 3, "text": "bbb"}
    third = {"/stats/length": 4, "text": "cccc"}
    fourth = {"/stats/length": 2, "text": "dd"}
    by_length = stat_batchify("length")
    result = list(by_length([first, second, third, fourth], batch_size=5))
    # Summed lengths of the expected batches: 2 + 3 = 5, then 4, then 2.
    expected = [[first, second], [third], [fourth]]
    assert result == expected
238
239
240
def test_stat_batchify_invalid_key():
    """A sample carrying only a "text" key should raise `ValueError` when batched."""
    samples = [{"text": "aaa"}]
    by_length = stat_batchify("length")
    with pytest.raises(ValueError):
        # The result is consumed with list() inside the `raises` block.
        list(by_length(samples, batch_size=5))
245
246
247
def test_stat_batchify_sentinel_split():
    """In "split" mode, the sentinel is expected alone between the two samples."""
    marker = MockStreamSentinel()
    before = {"/stats/length": 2, "text": "aa"}
    after = {"/stats/length": 3, "text": "bbb"}
    by_length = stat_batchify("length")
    result = list(
        by_length([before, marker, after], batch_size=5, sentinel_mode="split")
    )
    expected = [[before], marker, [after]]
    assert result == expected
261
262
263
def test_stat_batchify_sentinel_keep():
    """In "keep" mode, the sentinel is expected to stay with the first sample."""
    marker = MockStreamSentinel()
    before = {"/stats/length": 2, "text": "aa"}
    after = {"/stats/length": 4, "text": "bbbb"}
    by_length = stat_batchify("length")
    result = list(
        by_length([before, marker, after], batch_size=5, sentinel_mode="keep")
    )
    expected = [[before, marker], [after]]
    assert result == expected
276
277
278
def test_stat_batchify_drop_last():
    """With `drop_last=True`, the trailing batch holding the third sample is dropped."""
    first = {"/stats/length": 2, "text": "aa"}
    second = {"/stats/length": 3, "text": "bbb"}
    third = {"/stats/length": 4, "text": "cccc"}
    by_length = stat_batchify("length")
    result = list(by_length([first, second, third], batch_size=6, drop_last=True))
    # Only the first batch (lengths 2 + 3 = 5) is expected; `third` alone has
    # length 4 and would form the dropped last batch.
    expected = [[first, second]]
    assert result == expected
289
290
291
# Additional tests to ensure full coverage
292
def test_batchify_empty_iterable():
    """An empty input should produce no batches."""
    nothing = []
    result = list(batchify(nothing, batch_size=2))
    assert result == []
296
297
298
def test_batchify_by_length_sum_empty_iterable():
    """An empty input should produce no batches."""
    nothing = []
    result = list(batchify_by_length_sum(nothing, batch_size=5))
    assert result == []
302
303
304
def test_batchify_by_padded_empty_iterable():
    """An empty input should produce no batches."""
    nothing = []
    result = list(batchify_by_padded(nothing, batch_size=6))
    assert result == []
308
309
310
def test_batchify_by_dataset_empty_iterable():
    """An empty input should produce no batches."""
    nothing = []
    result = list(batchify_by_dataset(nothing))
    assert result == []
314
315
316
def test_batchify_by_fragment_empty_iterable():
    """An empty input should produce no batches."""
    nothing = []
    result = list(batchify_by_fragment(nothing))
    assert result == []
320
321
322
def test_stat_batchify_empty_iterable():
    """An empty input should produce no batches from the stat-based batcher."""
    nothing = []
    by_length = stat_batchify("length")
    result = list(by_length(nothing, batch_size=5))
    assert result == []
327
328
329
def test_batchify_invalid_sentinel_mode():
    """An unrecognised `sentinel_mode` should trigger an `AssertionError`."""
    items = [1, 2, 3]
    with pytest.raises(AssertionError):
        # The result is consumed with list() inside the `raises` block.
        list(batchify(items, batch_size=2, sentinel_mode="invalid_mode"))