# tests/data/test_stream.py
1
import pytest
2
3
import edsnlp
4
from edsnlp.utils.collections import ld_to_dl
5
6
try:
7
    import torch.nn
8
except ImportError:
9
    torch = None
10
11
12
def test_map_batches():
    """Check `map_batches` on a stream run with 2 CPU workers and batches of 2."""
    stream = (
        edsnlp.data.from_iterable([1, 2, 3, 4, 5])
        .map(lambda x: x + 1)  # items become 2, 3, 4, 5, 6
        .map_batches(lambda batch: [sum(batch)])
        .set_processing(num_cpu_workers=2, sort_chunks=False, batch_size=2)
    )
    # Expected batch sums: 2+4, 3+5, 6
    assert list(stream) == [6, 8, 6]
24
25
26
@pytest.mark.parametrize("num_cpu_workers", [1, 2])
def test_flat_iterable(num_cpu_workers):
    """Check that `flatten` unrolls the lists produced by the preceding `map`."""
    expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
    stream = (
        edsnlp.data.from_iterable([1, 2, 3, 4])
        .set_processing(num_cpu_workers=num_cpu_workers)
        .map(lambda n: [n] * n)
        .flatten()
    )
    collected = list(stream.to_iterable(converter=lambda x: x))
    # Only the content is compared, not the order, hence the sort.
    assert sorted(collected) == expected
35
36
37
@pytest.mark.parametrize("num_gpu_workers", [0, 1, 2])
@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_map_gpu(num_gpu_workers):
    """Check `map_gpu` with 0, 1 or 2 GPU workers, each assigned the "cpu" device."""
    import torch

    def prepare_batch(batch, device):
        # Build the tensor first, then move it to the requested device.
        tensor = torch.tensor(batch)
        return {"tensor": tensor.to(device)}

    def forward(batch):
        doubled = batch["tensor"] * 2
        return {"outputs": doubled}

    stream = edsnlp.data.from_iterable(range(15))
    # Fused test case: with no GPU worker, also exercise map_gpu placed right
    # after a map_batches that does not specify a batch size.
    if num_gpu_workers == 0:
        stream = stream.map_batches(lambda x: x)
    stream = stream.map_gpu(prepare_batch, forward).set_processing(
        num_gpu_workers=num_gpu_workers,
        gpu_worker_devices=["cpu"] * num_gpu_workers,
        sort_chunks=False,
        batch_size=2,
    )

    outputs = torch.cat(ld_to_dl(stream)["outputs"])
    # Compared as sets: only the produced values are checked, not their order.
    assert set(outputs.tolist()) == {2 * i for i in range(15)}
65
66
67
# fmt: off
@pytest.mark.parametrize(
    "sort,num_cpu_workers,batch_kwargs,expected",
    [
        (False, 1, {"batch_size": 10, "batch_by": "words"}, [3, 1, 3, 1, 3, 1]),  # noqa: E501
        (False, 1, {"batch_size": 10, "batch_by": "padded_words"}, [2, 1, 1, 2, 1, 1, 2, 1, 1]),  # noqa: E501
        (False, 1, {"batch_size": 10, "batch_by": "docs"}, [10, 2]),  # noqa: E501
        (False, 2, {"batch_size": 10, "batch_by": "words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]),  # noqa: E501
        (False, 2, {"batch_size": 10, "batch_by": "padded_words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]),  # noqa: E501
        (False, 2, {"batch_size": 10, "batch_by": "docs"}, [6, 6]),  # noqa: E501
        (True, 2, {"batch_size": 10, "batch_by": "padded_words"}, [3, 3, 2, 1, 1, 1, 1]),  # noqa: E501
        (False, 2, {"batch_size": "10 words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]),  # noqa: E501
    ],
)
# fmt: on
def test_map_with_batching(sort, num_cpu_workers, batch_kwargs, expected):
    """Compare the length of each batch reaching `map_batches(len)` to `expected`."""
    matcher_terms = ["This", "is", "a", "sentence", ".", "Short", "snippet", "too"]
    nlp = edsnlp.blank("eds")
    nlp.add_pipe(
        "eds.matcher",
        name="matcher",
        config={"terms": {"foo": matcher_terms}},
    )

    # Four texts of different lengths, repeated three times (12 samples).
    texts = 3 * [
        "This is a sentence.",
        "Short snippet",
        "Short snippet too",
        "This is a very very long sentence that will make more than 10 words",
    ]
    stream = edsnlp.data.from_iterable(texts)
    if sort:
        # Sort the raw texts by length inside one large batch before the pipeline.
        stream = stream.map_batches(lambda x: sorted(x, key=len), batch_size=1000)
    stream = (
        stream.map_pipeline(nlp)
        .map_batches(len)
        .set_processing(
            num_cpu_workers=num_cpu_workers,
            **batch_kwargs,
            chunk_size=1000,  # deprecated
            split_into_batches_after="matcher",
            show_progress=True,
        )
    )
    assert list(stream) == expected
112
113
114
def test_repr(frozen_ml_nlp, tmp_path):
    """Check that "Stream" appears in the repr of a multi-step stream."""
    texts = ["ceci est un test", "ceci est un autre test"]
    output_path = tmp_path / "out_test.jsonl"

    stream = edsnlp.data.from_iterable(texts, converter=frozen_ml_nlp.make_doc)
    stream = stream.map(lambda x: x)
    stream = stream.map_pipeline(frozen_ml_nlp, batch_size=2)
    stream = stream.map_batches(lambda b: sorted(b, key=len))
    stream = stream.set_processing(num_cpu_workers=2)
    # NOTE(review): execute=False presumably returns the stream without running
    # the write -- confirm against the writer's documentation.
    stream = stream.write_json(output_path, lines=True, execute=False)

    assert "Stream" in repr(stream)
125
126
127
@pytest.mark.parametrize("shuffle_reader", [True, False])
def test_shuffle_before_generator(shuffle_reader):
    """Check `shuffle` when it is applied before a generator-based `map`."""

    def gen_fn(x):
        # Emit every item twice.
        for _ in range(2):
            yield x

    stream = (
        edsnlp.data.from_iterable([1, 2, 3, 4, 5])
        .map(lambda x: x)
        .shuffle(seed=42, shuffle_reader=shuffle_reader)
        .map(gen_fn)
    )

    # The expected reader flag and op count depend on `shuffle_reader`.
    if shuffle_reader:
        expected_reader_shuffle, expected_num_ops = "dataset", 2
    else:
        expected_reader_shuffle, expected_num_ops = False, 5
    assert stream.reader.shuffle == expected_reader_shuffle
    assert len(stream.ops) == expected_num_ops

    # The same order is expected for both parametrizations.
    assert list(stream) == [4, 4, 2, 2, 3, 3, 5, 5, 1, 1]
142
143
144
def test_shuffle_after_generator():
    """Check `shuffle` when it is applied after a generator-based `map`."""

    def gen_fn(x):
        # Emit every item twice.
        for _ in range(2):
            yield x

    stream = (
        edsnlp.data.from_iterable([1, 2, 3, 4, 5])
        .map(lambda x: x)
        .map(gen_fn)
        .shuffle(seed=43)
    )

    assert stream.reader.shuffle == "dataset"
    assert len(stream.ops) == 5
    # Duplicates are no longer adjacent in the expected order.
    assert list(stream) == [1, 2, 4, 3, 1, 3, 5, 5, 4, 2]
158
159
160
def test_shuffle_frozen_ml_pipeline(run_in_test_dir, frozen_ml_nlp):
    """Check that a "fragment" shuffle keeps 7 ops and is recorded on the reader."""
    stream = edsnlp.data.read_parquet(
        "../resources/docs.parquet",
        converter="omop",
    ).map_pipeline(frozen_ml_nlp, batch_size=2)
    assert len(stream.ops) == 7

    shuffled = stream.shuffle(batch_by="fragment")
    # The op count is the same before and after the shuffle call.
    assert len(shuffled.ops) == 7
    assert shuffled.reader.shuffle == "fragment"
167
168
169
def test_unknown_shuffle():
    """Check that `shuffle("unknown")` raises a `ValueError`."""
    stream = edsnlp.data.from_iterable([1, 2, 3, 4, 5]).map(lambda x: x)
    # Only the shuffle call itself is expected to raise.
    with pytest.raises(ValueError):
        stream.shuffle("unknown")
175
176
177
def test_int_shuffle():
    """Check the order produced by a seeded shuffle given the "2 docs" spec."""
    numbers = list(range(1, 11))  # 1 through 10
    stream = (
        edsnlp.data.from_iterable(numbers)
        .map(lambda x: x)
        .shuffle("2 docs", seed=42)
    )
    shuffled = list(stream)
    assert shuffled == [2, 1, 4, 3, 5, 6, 8, 7, 10, 9]
183
184
185
def test_parallel_preprocess_stop(run_in_test_dir, frozen_ml_nlp):
    """Take 10 items from a `loop=True` stream run by one spawned CPU worker.

    The iterator is deleted right after the count is checked.
    """
    nlp = frozen_ml_nlp
    stream = (
        edsnlp.data.read_parquet("../resources/docs.parquet", "omop", loop=True)
        .map(edsnlp.pipes.split(regex="\n+"))
        .map(nlp.preprocess, kwargs=dict(supervision=True))
        .batchify("128 words")
        .map(nlp.collate)
        .set_processing(num_cpu_workers=1, process_start_method="spawn")
    )

    it = iter(stream)
    # Pairing with range(10) stops the iteration after ten items.
    total = sum(1 for _ in zip(it, range(10)))

    assert total == 10
    del it