"""Tests for ``edsnlp.utils.batching``.

Covers:

- ``BatchSizeArg`` validation of "<count> <unit>" strings,
- the ``batchify*`` strategies (fixed count, length sum, padded size,
  dataset boundaries, fragment boundaries) and ``stat_batchify``,
- ``drop_last`` handling and the three ``sentinel_mode`` behaviors
  (``"drop"`` / ``"keep"`` / ``"split"``),
- empty-iterable edge cases for every strategy.
"""

import pytest

from edsnlp.utils.batching import (
    DATASET_END_SENTINEL,
    BatchSizeArg,
    FragmentEndSentinel,
    StreamSentinel,
    batchify,
    batchify_by_dataset,
    batchify_by_fragment,
    batchify_by_length_sum,
    batchify_by_padded,
    stat_batchify,
)


class MockStreamSentinel(StreamSentinel):
    # Minimal concrete sentinel; used only to exercise sentinel_mode handling.
    pass


# Tests for BatchSizeArg
def test_batch_size_arg_validate():
    """Valid specs parse to (count, unit); a bare unit yields (None, unit),
    a bare number defaults the unit to "docs"; anything else raises."""
    # Valid inputs
    assert BatchSizeArg.validate("10 samples") == (10, "samples")
    assert BatchSizeArg.validate("20 words") == (20, "words")
    assert BatchSizeArg.validate(15) == (15, "docs")
    assert BatchSizeArg.validate("docs") == (None, "docs")
    assert BatchSizeArg.validate("tokens") == (None, "tokens")
    assert BatchSizeArg.validate("25") == (25, "docs")

    # Invalid inputs
    with pytest.raises(Exception):
        BatchSizeArg.validate("invalid input")
    with pytest.raises(Exception):
        BatchSizeArg.validate("10 invalid input")
    with pytest.raises(Exception):
        BatchSizeArg.validate("invalid input 10")


# Tests for batchify function
def test_batchify_simple():
    """Fixed-size batching; a short trailing batch is kept by default."""
    data = [1, 2, 3, 4, 5]
    batches = list(batchify(data, batch_size=2))
    assert batches == [[1, 2], [3, 4], [5]]


def test_batchify_drop_last():
    """drop_last=True discards the final incomplete batch."""
    data = [1, 2, 3, 4, 5]
    batches = list(batchify(data, batch_size=2, drop_last=True))
    assert batches == [[1, 2], [3, 4]]


def test_batchify_sentinel_drop():
    """sentinel_mode="drop" removes sentinels from the stream entirely."""
    data = [1, 2, MockStreamSentinel(), 3, 4]
    batches = list(batchify(data, batch_size=2, sentinel_mode="drop"))
    assert batches == [[1, 2], [3, 4]]


def test_batchify_sentinel_keep():
    """sentinel_mode="keep" leaves the sentinel inside the current batch;
    it does not count toward the batch size."""
    sentinel = MockStreamSentinel()
    data = [1, 2, sentinel, 3, 4]
    batches = list(batchify(data, batch_size=2, sentinel_mode="keep"))
    assert batches == [[1, 2, sentinel], [3, 4]]


def test_batchify_sentinel_split():
    """sentinel_mode="split" closes the current batch and yields the
    sentinel itself between batches."""
    sentinel = MockStreamSentinel()
    data = [1, 2, sentinel, 3, 4]
    batches = list(batchify(data, batch_size=2, sentinel_mode="split"))
    assert batches == [[1, 2], sentinel, [3, 4]]


# Tests for batchify_by_length_sum
def test_batchify_by_length_sum_simple():
    """Batches are closed once the summed item lengths would exceed batch_size."""
    data = ["a", "bb", "ccc", "dddd", "eeeee"]
    batches = list(batchify_by_length_sum(data, batch_size=5))
    assert batches == [["a", "bb"], ["ccc"], ["dddd"], ["eeeee"]]


def test_batchify_by_length_sum_drop_last():
    """drop_last=True discards the final under-filled batch."""
    data = ["a", "bb", "ccc", "dddd", "eeeee"]
    batches = list(batchify_by_length_sum(data, batch_size=5, drop_last=True))
    assert batches == [["a", "bb"], ["ccc"], ["dddd"]]


def test_batchify_by_length_sum_split():
    """sentinel_mode="split" flushes the running batch and yields the sentinel."""
    sentinel = MockStreamSentinel()
    data = ["aa", "bb", sentinel, "ccc", "dddd", "eeeee"]
    batches = list(batchify_by_length_sum(data, batch_size=7, sentinel_mode="split"))
    assert batches == [["aa", "bb"], sentinel, ["ccc", "dddd"], ["eeeee"]]


def test_batchify_by_length_sum_keep():
    """sentinel_mode="keep" embeds the sentinel without affecting the length sum."""
    sentinel = MockStreamSentinel()
    data = ["aa", "bb", sentinel, "ccc", "dddd", "eeeee"]
    batches = list(batchify_by_length_sum(data, batch_size=7, sentinel_mode="keep"))
    assert batches == [["aa", "bb", sentinel, "ccc"], ["dddd"], ["eeeee"]]


# Tests for batchify_by_padded
def test_batchify_by_padded_simple():
    """Batch cost is len(batch) * max item length (padded size); a batch is
    closed when adding the next item would push that cost past batch_size."""
    data = ["a", "bb", "ccc", "dddd"]
    batches = list(batchify_by_padded(data, batch_size=6))
    assert batches == [["a", "bb"], ["ccc"], ["dddd"]]


def test_batchify_by_padded_drop_last():
    """drop_last=True discards the final batch."""
    data = ["a", "bb", "ccc", "dddd"]
    batches = list(batchify_by_padded(data, batch_size=6, drop_last=True))
    assert batches == [["a", "bb"], ["ccc"]]


def test_batchify_by_padded_sentinel_keep():
    """sentinel_mode="keep" carries the sentinel inside the padded batch."""
    sentinel = MockStreamSentinel()
    data = ["a", sentinel, "bb", "ccc"]
    batches = list(batchify_by_padded(data, batch_size=6, sentinel_mode="keep"))
    assert batches == [["a", sentinel, "bb"], ["ccc"]]


def test_batchify_by_padded_sentinel_split():
    """sentinel_mode="split" flushes the current batch and yields the sentinel."""
    sentinel = MockStreamSentinel()
    data = ["a", sentinel, "bb", "ccc"]
    batches = list(batchify_by_padded(data, batch_size=5, sentinel_mode="split"))
    assert batches == [["a"], sentinel, ["bb"], ["ccc"]]


# Tests for batchify_by_dataset
def test_batchify_by_dataset_simple():
    """Batches break on DATASET_END_SENTINEL, which is re-emitted between them."""
    data = [
        "item1",
        "item2",
        DATASET_END_SENTINEL,
        "item3",
        DATASET_END_SENTINEL,
        "item4",
        "item5",
    ]
    batches = list(batchify_by_dataset(data))
    assert batches == [
        ["item1", "item2"],
        DATASET_END_SENTINEL,
        ["item3"],
        DATASET_END_SENTINEL,
        ["item4", "item5"],
    ]


def test_batchify_by_dataset_sentinel_split():
    """Other sentinels also split batches when sentinel_mode="split"."""
    sentinel = MockStreamSentinel()
    data = ["item1", sentinel, "item2", DATASET_END_SENTINEL, "item3"]
    batches = list(batchify_by_dataset(data, sentinel_mode="split"))
    assert batches == [["item1"], sentinel, ["item2"], DATASET_END_SENTINEL, ["item3"]]


def test_batchify_by_dataset_sentinel_keep():
    """sentinel_mode="keep" leaves other sentinels inside the dataset batch."""
    sentinel = MockStreamSentinel()
    data = ["item1", sentinel, "item2", DATASET_END_SENTINEL, "item3"]
    batches = list(batchify_by_dataset(data, sentinel_mode="keep"))
    assert batches == [["item1", sentinel, "item2"], DATASET_END_SENTINEL, ["item3"]]


def test_batchify_by_dataset_sentinel_drop():
    """sentinel_mode="drop" removes other sentinels but keeps dataset ends."""
    sentinel = MockStreamSentinel()
    data = ["item1", sentinel, "item2", DATASET_END_SENTINEL, "item3"]
    batches = list(batchify_by_dataset(data, sentinel_mode="drop"))
    assert batches == [["item1", "item2"], DATASET_END_SENTINEL, ["item3"]]


def test_batchify_by_dataset_drop_last():
    """drop_last=True discards the trailing batch after the last dataset end."""
    data = ["item1", "item2", DATASET_END_SENTINEL, "item3"]
    batches = list(batchify_by_dataset(data, drop_last=True))
    assert batches == [["item1", "item2"], DATASET_END_SENTINEL]


# Tests for batchify_by_fragment
def test_batchify_by_fragment_simple():
    """Batches break on FragmentEndSentinel markers, which are re-emitted."""
    fragment_end_1 = FragmentEndSentinel("fragment1")
    fragment_end_2 = FragmentEndSentinel("fragment2")
    data = ["item1", "item2", fragment_end_1, "item3", fragment_end_2, "item4"]
    batches = list(batchify_by_fragment(data))
    assert batches == [
        ["item1", "item2"],
        fragment_end_1,
        ["item3"],
        fragment_end_2,
        ["item4"],
    ]


def test_batchify_by_fragment_sentinel_split():
    """Other sentinels also split batches when sentinel_mode="split"."""
    sentinel = MockStreamSentinel()
    fragment_end = FragmentEndSentinel("fragment")
    data = ["item1", sentinel, "item2", fragment_end]
    batches = list(batchify_by_fragment(data, sentinel_mode="split"))
    assert batches == [["item1"], sentinel, ["item2"], fragment_end]


def test_batchify_by_fragment_sentinel_keep():
    """sentinel_mode="keep" leaves other sentinels inside the fragment batch."""
    sentinel = MockStreamSentinel()
    fragment_end = FragmentEndSentinel("fragment")
    data = ["item1", sentinel, "item2", fragment_end]
    batches = list(batchify_by_fragment(data, sentinel_mode="keep"))
    assert batches == [["item1", sentinel, "item2"], fragment_end]


def test_batchify_by_fragment_sentinel_drop():
    """sentinel_mode="drop" removes other sentinels but keeps fragment ends."""
    sentinel = MockStreamSentinel()
    fragment_end = FragmentEndSentinel("fragment")
    data = ["item1", sentinel, "item2", fragment_end]
    batches = list(batchify_by_fragment(data, sentinel_mode="drop"))
    assert batches == [["item1", "item2"], fragment_end]


def test_batchify_by_fragment_drop_last():
    """A batch closed by a fragment end is complete, so nothing is dropped here."""
    fragment_end = FragmentEndSentinel("fragment")
    data = ["item1", "item2", fragment_end]
    batches = list(batchify_by_fragment(data, sentinel_mode="split", drop_last=True))
    assert batches == [["item1", "item2"], fragment_end]


# Tests for stat_batchify
def test_stat_batchify_simple():
    """stat_batchify(key) batches by summing the "/stats/<key>" field of items."""
    data = [
        {"/stats/length": 2, "text": "aa"},
        {"/stats/length": 3, "text": "bbb"},
        {"/stats/length": 4, "text": "cccc"},
        {"/stats/length": 2, "text": "dd"},
    ]
    batch_fn = stat_batchify("length")
    batches = list(batch_fn(data, batch_size=5))
    assert batches == [
        [data[0], data[1]],  # Total length: 5
        [data[2]],  # Total length: 4
        [data[3]],  # Total length: 2
    ]


def test_stat_batchify_invalid_key():
    """Items missing the requested stat key cause a ValueError."""
    data = [{"text": "aaa"}]
    batch_fn = stat_batchify("length")
    with pytest.raises(ValueError):
        list(batch_fn(data, batch_size=5))


def test_stat_batchify_sentinel_split():
    """sentinel_mode="split" flushes the running batch and yields the sentinel."""
    sentinel = MockStreamSentinel()
    data = [
        {"/stats/length": 2, "text": "aa"},
        sentinel,
        {"/stats/length": 3, "text": "bbb"},
    ]
    batch_fn = stat_batchify("length")
    batches = list(batch_fn(data, batch_size=5, sentinel_mode="split"))
    assert batches == [
        [data[0]],
        sentinel,
        [data[2]],
    ]


def test_stat_batchify_sentinel_keep():
    """sentinel_mode="keep" embeds the sentinel; it contributes no stat value."""
    sentinel = MockStreamSentinel()
    data = [
        {"/stats/length": 2, "text": "aa"},
        sentinel,
        {"/stats/length": 4, "text": "bbbb"},
    ]
    batch_fn = stat_batchify("length")
    batches = list(batch_fn(data, batch_size=5, sentinel_mode="keep"))
    assert batches == [
        [data[0], sentinel],
        [data[2]],
    ]


def test_stat_batchify_drop_last():
    """drop_last=True discards the final under-filled batch."""
    data = [
        {"/stats/length": 2, "text": "aa"},
        {"/stats/length": 3, "text": "bbb"},
        {"/stats/length": 4, "text": "cccc"},
    ]
    batch_fn = stat_batchify("length")
    batches = list(batch_fn(data, batch_size=6, drop_last=True))
    assert batches == [
        [data[0], data[1]],  # Total length: 5
    ]  # Last batch is dropped because total length is 4


# Additional tests to ensure full coverage
def test_batchify_empty_iterable():
    """An empty input stream yields no batches."""
    data = []
    batches = list(batchify(data, batch_size=2))
    assert batches == []


def test_batchify_by_length_sum_empty_iterable():
    """An empty input stream yields no batches."""
    data = []
    batches = list(batchify_by_length_sum(data, batch_size=5))
    assert batches == []


def test_batchify_by_padded_empty_iterable():
    """An empty input stream yields no batches."""
    data = []
    batches = list(batchify_by_padded(data, batch_size=6))
    assert batches == []


def test_batchify_by_dataset_empty_iterable():
    """An empty input stream yields no batches."""
    data = []
    batches = list(batchify_by_dataset(data))
    assert batches == []


def test_batchify_by_fragment_empty_iterable():
    """An empty input stream yields no batches."""
    data = []
    batches = list(batchify_by_fragment(data))
    assert batches == []


def test_stat_batchify_empty_iterable():
    """An empty input stream yields no batches."""
    data = []
    batch_fn = stat_batchify("length")
    batches = list(batch_fn(data, batch_size=5))
    assert batches == []


def test_batchify_invalid_sentinel_mode():
    """An unknown sentinel_mode is rejected (via an assert in the implementation)."""
    data = [1, 2, 3]
    with pytest.raises(AssertionError):
        list(batchify(data, batch_size=2, sentinel_mode="invalid_mode"))