--- /dev/null
+++ b/tests/utils/test_batching.py
@@ -0,0 +1,332 @@
+import pytest
+
+from edsnlp.utils.batching import (
+    DATASET_END_SENTINEL,
+    BatchSizeArg,
+    FragmentEndSentinel,
+    StreamSentinel,
+    batchify,
+    batchify_by_dataset,
+    batchify_by_fragment,
+    batchify_by_length_sum,
+    batchify_by_padded,
+    stat_batchify,
+)
+
+
+class MockStreamSentinel(StreamSentinel):
+    pass
+
+
+# Tests for BatchSizeArg
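+# BatchSizeArg.validate parses a batch size spec into a (num, unit) tuple:
+# "10 samples" -> (10, "samples"), a bare number defaults to the "docs" unit,
+# and a bare unit name leaves the number as None.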
+def test_batch_size_arg_validate():
+    # Valid inputs
+    assert BatchSizeArg.validate("10 samples") == (10, "samples")
+    assert BatchSizeArg.validate("20 words") == (20, "words")
+    assert BatchSizeArg.validate(15) == (15, "docs")
+    assert BatchSizeArg.validate("docs") == (None, "docs")
+    assert BatchSizeArg.validate("tokens") == (None, "tokens")
+    assert BatchSizeArg.validate("25") == (25, "docs")
+
+    # Invalid inputs
+    with pytest.raises(Exception):
+        BatchSizeArg.validate("invalid input")
+    with pytest.raises(Exception):
+        BatchSizeArg.validate("10 invalid input")
+    with pytest.raises(Exception):
+        BatchSizeArg.validate("invalid input 10")
+
+
+# Tests for batchify function
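+# batchify groups items into fixed-size batches of `batch_size` elements;
+# the final, possibly smaller batch is kept unless drop_last=True.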
+def test_batchify_simple():
+    data = [1, 2, 3, 4, 5]
+    batches = list(batchify(data, batch_size=2))
+    assert batches == [[1, 2], [3, 4], [5]]
+
+
+def test_batchify_drop_last():
+    data = [1, 2, 3, 4, 5]
+    batches = list(batchify(data, batch_size=2, drop_last=True))
+    assert batches == [[1, 2], [3, 4]]
+
+
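+# StreamSentinel items are handled according to sentinel_mode: "drop"
+# discards them, "keep" appends them to the current batch without counting
+# towards the batch size, and "split" closes the current batch and yields
+# the sentinel on its own.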
+def test_batchify_sentinel_drop():
+    data = [1, 2, MockStreamSentinel(), 3, 4]
+    batches = list(batchify(data, batch_size=2, sentinel_mode="drop"))
+    assert batches == [[1, 2], [3, 4]]
+
+
+def test_batchify_sentinel_keep():
+    sentinel = MockStreamSentinel()
+    data = [1, 2, sentinel, 3, 4]
+    batches = list(batchify(data, batch_size=2, sentinel_mode="keep"))
+    assert batches == [[1, 2, sentinel], [3, 4]]
+
+
+def test_batchify_sentinel_split():
+    sentinel = MockStreamSentinel()
+    data = [1, 2, sentinel, 3, 4]
+    batches = list(batchify(data, batch_size=2, sentinel_mode="split"))
+    assert batches == [[1, 2], sentinel, [3, 4]]
+
+
+# Tests for batchify_by_length_sum
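+# batchify_by_length_sum accumulates items as long as the sum of their
+# lengths stays within batch_size; an item that would overflow the budget
+# starts a new batch.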
+def test_batchify_by_length_sum_simple():
+    data = ["a", "bb", "ccc", "dddd", "eeeee"]
+    batches = list(batchify_by_length_sum(data, batch_size=5))
+    assert batches == [["a", "bb"], ["ccc"], ["dddd"], ["eeeee"]]
+
+
+def test_batchify_by_length_sum_drop_last():
+    data = ["a", "bb", "ccc", "dddd", "eeeee"]
+    batches = list(batchify_by_length_sum(data, batch_size=5, drop_last=True))
+    assert batches == [["a", "bb"], ["ccc"], ["dddd"]]
+
+
+def test_batchify_by_length_sum_split():
+    sentinel = MockStreamSentinel()
+    data = ["aa", "bb", sentinel, "ccc", "dddd", "eeeee"]
+    batches = list(batchify_by_length_sum(data, batch_size=7, sentinel_mode="split"))
+    assert batches == [["aa", "bb"], sentinel, ["ccc", "dddd"], ["eeeee"]]
+
+
+def test_batchify_by_length_sum_keep():
+    sentinel = MockStreamSentinel()
+    data = ["aa", "bb", sentinel, "ccc", "dddd", "eeeee"]
+    batches = list(batchify_by_length_sum(data, batch_size=7, sentinel_mode="keep"))
+    assert batches == [["aa", "bb", sentinel, "ccc"], ["dddd"], ["eeeee"]]
+
+
+# Tests for batchify_by_padded
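+# batchify_by_padded budgets the padded size of a batch, i.e. the number of
+# items times the length of the longest item, as exercised below
+# (["a", "bb"] costs 2 * 2 = 4 <= 6, while adding "ccc" would cost 3 * 3 = 9).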
+def test_batchify_by_padded_simple():
+    data = ["a", "bb", "ccc", "dddd"]
+    batches = list(batchify_by_padded(data, batch_size=6))
+    assert batches == [["a", "bb"], ["ccc"], ["dddd"]]
+
+
+def test_batchify_by_padded_drop_last():
+    data = ["a", "bb", "ccc", "dddd"]
+    batches = list(batchify_by_padded(data, batch_size=6, drop_last=True))
+    assert batches == [["a", "bb"], ["ccc"]]
+
+
+def test_batchify_by_padded_sentinel_keep():
+    sentinel = MockStreamSentinel()
+    data = ["a", sentinel, "bb", "ccc"]
+    batches = list(batchify_by_padded(data, batch_size=6, sentinel_mode="keep"))
+    assert batches == [["a", sentinel, "bb"], ["ccc"]]
+
+
+def test_batchify_by_padded_sentinel_split():
+    sentinel = MockStreamSentinel()
+    data = ["a", sentinel, "bb", "ccc"]
+    batches = list(batchify_by_padded(data, batch_size=5, sentinel_mode="split"))
+    assert batches == [["a"], sentinel, ["bb"], ["ccc"]]
+
+
+# Tests for batchify_by_dataset
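+# batchify_by_dataset groups items that belong to the same dataset and
+# re-emits DATASET_END_SENTINEL between the resulting batches.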
+def test_batchify_by_dataset_simple():
+    data = [
+        "item1",
+        "item2",
+        DATASET_END_SENTINEL,
+        "item3",
+        DATASET_END_SENTINEL,
+        "item4",
+        "item5",
+    ]
+    batches = list(batchify_by_dataset(data))
+    assert batches == [
+        ["item1", "item2"],
+        DATASET_END_SENTINEL,
+        ["item3"],
+        DATASET_END_SENTINEL,
+        ["item4", "item5"],
+    ]
+
+
+def test_batchify_by_dataset_sentinel_split():
+    sentinel = MockStreamSentinel()
+    data = ["item1", sentinel, "item2", DATASET_END_SENTINEL, "item3"]
+    batches = list(batchify_by_dataset(data, sentinel_mode="split"))
+    assert batches == [["item1"], sentinel, ["item2"], DATASET_END_SENTINEL, ["item3"]]
+
+
+def test_batchify_by_dataset_sentinel_keep():
+    sentinel = MockStreamSentinel()
+    data = ["item1", sentinel, "item2", DATASET_END_SENTINEL, "item3"]
+    batches = list(batchify_by_dataset(data, sentinel_mode="keep"))
+    assert batches == [["item1", sentinel, "item2"], DATASET_END_SENTINEL, ["item3"]]
+
+
+def test_batchify_by_dataset_sentinel_drop():
+    sentinel = MockStreamSentinel()
+    data = ["item1", sentinel, "item2", DATASET_END_SENTINEL, "item3"]
+    batches = list(batchify_by_dataset(data, sentinel_mode="drop"))
+    assert batches == [["item1", "item2"], DATASET_END_SENTINEL, ["item3"]]
+
+
+def test_batchify_by_dataset_drop_last():
+    data = ["item1", "item2", DATASET_END_SENTINEL, "item3"]
+    batches = list(batchify_by_dataset(data, drop_last=True))
+    assert batches == [["item1", "item2"], DATASET_END_SENTINEL]
+
+
+# Tests for batchify_by_fragment
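+# batchify_by_fragment behaves like batchify_by_dataset but splits on
+# FragmentEndSentinel markers, re-emitting each marker after its batch.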
+def test_batchify_by_fragment_simple():
+    fragment_end_1 = FragmentEndSentinel("fragment1")
+    fragment_end_2 = FragmentEndSentinel("fragment2")
+    data = ["item1", "item2", fragment_end_1, "item3", fragment_end_2, "item4"]
+    batches = list(batchify_by_fragment(data))
+    assert batches == [
+        ["item1", "item2"],
+        fragment_end_1,
+        ["item3"],
+        fragment_end_2,
+        ["item4"],
+    ]
+
+
+def test_batchify_by_fragment_sentinel_split():
+    sentinel = MockStreamSentinel()
+    fragment_end = FragmentEndSentinel("fragment")
+    data = ["item1", sentinel, "item2", fragment_end]
+    batches = list(batchify_by_fragment(data, sentinel_mode="split"))
+    assert batches == [["item1"], sentinel, ["item2"], fragment_end]
+
+
+def test_batchify_by_fragment_sentinel_keep():
+    sentinel = MockStreamSentinel()
+    fragment_end = FragmentEndSentinel("fragment")
+    data = ["item1", sentinel, "item2", fragment_end]
+    batches = list(batchify_by_fragment(data, sentinel_mode="keep"))
+    assert batches == [["item1", sentinel, "item2"], fragment_end]
+
+
+def test_batchify_by_fragment_sentinel_drop():
+    sentinel = MockStreamSentinel()
+    fragment_end = FragmentEndSentinel("fragment")
+    data = ["item1", sentinel, "item2", fragment_end]
+    batches = list(batchify_by_fragment(data, sentinel_mode="drop"))
+    assert batches == [["item1", "item2"], fragment_end]
+
+
+def test_batchify_by_fragment_drop_last():
+    fragment_end = FragmentEndSentinel("fragment")
+    data = ["item1", "item2", fragment_end]
+    batches = list(batchify_by_fragment(data, sentinel_mode="split", drop_last=True))
+    assert batches == [["item1", "item2"], fragment_end]
+
+
+# Tests for stat_batchify
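+# stat_batchify(key) returns a batching function that sums the item stat
+# matching `key` (here the "/stats/length" entry) and closes a batch before
+# that sum would exceed batch_size; it raises ValueError when no matching
+# stat key is found.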
+def test_stat_batchify_simple():
+    data = [
+        {"/stats/length": 2, "text": "aa"},
+        {"/stats/length": 3, "text": "bbb"},
+        {"/stats/length": 4, "text": "cccc"},
+        {"/stats/length": 2, "text": "dd"},
+    ]
+    batch_fn = stat_batchify("length")
+    batches = list(batch_fn(data, batch_size=5))
+    assert batches == [
+        [data[0], data[1]],  # Total length: 5
+        [data[2]],  # Total length: 4
+        [data[3]],  # Total length: 2
+    ]
+
+
+def test_stat_batchify_invalid_key():
+    data = [{"text": "aaa"}]
+    batch_fn = stat_batchify("length")
+    with pytest.raises(ValueError):
+        list(batch_fn(data, batch_size=5))
+
+
+def test_stat_batchify_sentinel_split():
+    sentinel = MockStreamSentinel()
+    data = [
+        {"/stats/length": 2, "text": "aa"},
+        sentinel,
+        {"/stats/length": 3, "text": "bbb"},
+    ]
+    batch_fn = stat_batchify("length")
+    batches = list(batch_fn(data, batch_size=5, sentinel_mode="split"))
+    assert batches == [
+        [data[0]],
+        sentinel,
+        [data[2]],
+    ]
+
+
+def test_stat_batchify_sentinel_keep():
+    sentinel = MockStreamSentinel()
+    data = [
+        {"/stats/length": 2, "text": "aa"},
+        sentinel,
+        {"/stats/length": 4, "text": "bbbb"},
+    ]
+    batch_fn = stat_batchify("length")
+    batches = list(batch_fn(data, batch_size=5, sentinel_mode="keep"))
+    assert batches == [
+        [data[0], sentinel],
+        [data[2]],
+    ]
+
+
+def test_stat_batchify_drop_last():
+    data = [
+        {"/stats/length": 2, "text": "aa"},
+        {"/stats/length": 3, "text": "bbb"},
+        {"/stats/length": 4, "text": "cccc"},
+    ]
+    batch_fn = stat_batchify("length")
+    batches = list(batch_fn(data, batch_size=6, drop_last=True))
+    assert batches == [
+        [data[0], data[1]],  # Total length: 5
+    ]  # The trailing batch [data[2]] (total length 4) is dropped by drop_last=True
+
+
+# Additional tests to ensure full coverage
+def test_batchify_empty_iterable():
+    data = []
+    batches = list(batchify(data, batch_size=2))
+    assert batches == []
+
+
+def test_batchify_by_length_sum_empty_iterable():
+    data = []
+    batches = list(batchify_by_length_sum(data, batch_size=5))
+    assert batches == []
+
+
+def test_batchify_by_padded_empty_iterable():
+    data = []
+    batches = list(batchify_by_padded(data, batch_size=6))
+    assert batches == []
+
+
+def test_batchify_by_dataset_empty_iterable():
+    data = []
+    batches = list(batchify_by_dataset(data))
+    assert batches == []
+
+
+def test_batchify_by_fragment_empty_iterable():
+    data = []
+    batches = list(batchify_by_fragment(data))
+    assert batches == []
+
+
+def test_stat_batchify_empty_iterable():
+    data = []
+    batch_fn = stat_batchify("length")
+    batches = list(batch_fn(data, batch_size=5))
+    assert batches == []
+
+
+def test_batchify_invalid_sentinel_mode():
+    data = [1, 2, 3]
+    with pytest.raises(AssertionError):
+        list(batchify(data, batch_size=2, sentinel_mode="invalid_mode"))