Diff of /tests/test_pipeline.py [000000] .. [cad161]

Switch to unified view

a b/tests/test_pipeline.py
1
import os
2
import subprocess
3
import sys
4
from io import BytesIO
5
6
import pytest
7
from confit import Config
8
from confit.errors import ConfitValidationError
9
from confit.registry import validate_arguments
10
from spacy.tokens import Doc
11
12
import edsnlp
13
import edsnlp.pipes as eds
14
from edsnlp import Pipeline, registry
15
from edsnlp.core.registries import CurriedFactory
16
from edsnlp.pipes.base import BaseComponent
17
18
# torch is an optional dependency: keep this module importable without it.
# Tests that need torch check `torch is None` in their `skipif` markers.
try:
    import torch.nn
except ImportError:
    torch = None
22
23
24
class CustomClass:
    """Plain callable that returns the document it receives.

    It is passed to ``add_pipe`` in ``test_add_pipe_component``, which expects
    a ``ValueError``.
    """

    def __call__(self, doc: Doc) -> Doc:
        # Identity: hand the document back untouched.
        return doc
29
30
31
def test_add_pipe_factory():
    """Pipes added by factory name are registered under the requested names."""
    nlp = edsnlp.blank("eds")

    for factory_name, pipe_name in (
        ("eds.normalizer", "normalizer"),
        ("eds.sentences", "sentences"),
    ):
        nlp.add_pipe(factory_name, name=pipe_name)
        assert pipe_name in nlp.pipe_names
        assert nlp.has_pipe(pipe_name)

    # Looking up a name that was never added must raise.
    with pytest.raises(ValueError):
        nlp.get_pipe("missing-pipe")
43
44
45
def test_add_pipe_component():
    """Already-instantiated components can be added and retrieved by name."""
    nlp = edsnlp.blank("eds")

    nlp.add_pipe(eds.normalizer(nlp=nlp), name="normalizer")
    assert "normalizer" in nlp.pipe_names
    assert nlp.has_pipe("normalizer")
    # Attribute access and `get_pipe` must return the very same object.
    assert nlp.pipes.normalizer is nlp.get_pipe("normalizer")

    nlp.add_pipe(eds.sentences(nlp=nlp), name="sentences")
    assert "sentences" in nlp.pipe_names
    assert nlp.has_pipe("sentences")
    assert nlp.pipes.sentences is nlp.get_pipe("sentences")

    # Adding an instantiated component together with a `config` is rejected.
    # NOTE(review): presumably because the config cannot be applied to an
    # object that is already built -- confirm against `Pipeline.add_pipe`.
    with pytest.raises(ValueError):
        already_built = eds.sentences(nlp=nlp, name="sentences")
        nlp.add_pipe(already_built, config={"punct_chars": ".?!"})

    # An arbitrary callable is rejected as well.
    with pytest.raises(ValueError):
        nlp.add_pipe(CustomClass())
65
66
67
def test_sequence(frozen_ml_nlp: Pipeline):
    """Check the ordered content of ``pipeline`` and ``torch_components()``."""
    nlp = frozen_ml_nlp

    assert len(nlp.pipeline) == 3

    expected_pipeline = [
        (name, nlp.get_pipe(name)) for name in ("sentences", "transformer", "ner")
    ]
    assert list(nlp.pipeline) == expected_pipeline

    # Only "transformer" and "ner" are expected among the torch components.
    expected_torch = [(name, nlp.get_pipe(name)) for name in ("transformer", "ner")]
    assert list(nlp.torch_components()) == expected_torch
78
79
80
def test_disk_serialization(tmp_path, ml_nlp):
    """Round-trip a pipeline through ``to_disk`` and ``edsnlp.load``.

    Checks which files are written, the content of the saved ``config.cfg``,
    and that ``overrides`` given to ``edsnlp.load`` reach the reloaded pipes.
    """
    model_dir = tmp_path / "model"
    transformer_dir = model_dir / "transformer"

    assert ml_nlp.get_pipe("transformer").stride == 96
    ml_nlp.get_pipe("ner").update_labels(["PERSON", "GIFT"])

    os.makedirs(model_dir, exist_ok=True)
    # by default, vocab is excluded
    ml_nlp.to_disk(model_dir, exclude=set())

    assert (model_dir / "config.cfg").exists()
    assert (model_dir / "ner" / "parameters.safetensors").exists()
    assert (transformer_dir / "parameters.safetensors").exists()
    # The transformer weights may be stored under either file name.
    assert any(
        (transformer_dir / filename).exists()
        for filename in ("pytorch_model.bin", "model.safetensors")
    )

    # The saved config is the module-level one, minus the `components`
    # interpolation line and with the model name replaced by a relative path.
    expected_config = config_str.replace("components = ${components}\n", "")
    expected_config = expected_config.replace("prajjwal1/bert-tiny", "./transformer")
    assert (model_dir / "config.cfg").read_text() == expected_config

    reloaded = edsnlp.load(
        model_dir,
        overrides={"components": {"transformer": {"stride": 64}}},
    )
    assert reloaded.get_pipe("ner").labels == ["PERSON", "GIFT"]
    assert reloaded.get_pipe("transformer").stride == 64
113
114
115
# Reference pipeline config (sentences -> transformer -> ner) used by
# `test_disk_serialization` and `test_validate_config`.
config_str = """\
[nlp]
lang = "eds"
pipeline = ["sentences", "transformer", "ner"]
components = ${components}

[nlp.tokenizer]
@tokenizers = "eds.tokenizer"

[components]

[components.sentences]
@factory = "eds.sentences"

[components.transformer]
@factory = "eds.transformer"
model = "prajjwal1/bert-tiny"
window = 128
stride = 96

[components.ner]
@factory = "eds.ner_crf"
embedding = ${components.transformer}
mode = "independent"
target_span_getter = ["ents", "ner-preds"]
labels = ["PERSON", "GIFT"]
infer_span_setter = false
window = 40
stride = 20

[components.ner.span_setter]
ents = true

"""
149
150
151
@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_validate_config():
    """A resolved ``nlp`` config value is accepted for a ``Pipeline`` argument."""

    @validate_arguments
    def check_pipeline(model: Pipeline):
        assert len(model.pipe_names) == 3

    resolved = Config.from_str(config_str).resolve(registry=registry)
    check_pipeline(resolved["nlp"])
158
159
160
def test_torch_module(frozen_ml_nlp: Pipeline):
    """``train(flag)`` sets ``training`` to ``flag`` on every torch component."""
    for flag in (True, False):
        with frozen_ml_nlp.train(flag):
            for _name, component in frozen_ml_nlp.torch_components():
                assert component.training is flag

    # Moving the pipeline to the CPU must simply not raise.
    frozen_ml_nlp.to("cpu")
170
171
172
def test_cache(frozen_ml_nlp: Pipeline):
    """Check the cache content inside ``cache()`` and that it is emptied after.

    Inside the context, exactly two "forward" entries are expected in the
    "default" cache; once the context exits, ``_caches`` must be empty.
    """
    from edsnlp.core.torch_component import _caches

    text = "Ceci est un exemple"
    frozen_ml_nlp(text)

    doc = frozen_ml_nlp.make_doc(text)
    with frozen_ml_nlp.cache():
        for _name, pipe in frozen_ml_nlp.pipeline:
            # This is a hack to get around the ambiguity
            # between the __call__ method of Pytorch modules
            # and the __call__ methods of spacy components
            doc = (
                next(iter(pipe.batch_process([doc])))
                if hasattr(pipe, "batch_process")
                else pipe(doc)
            )

        forward_keys = []
        for key in _caches["default"]:
            if isinstance(key, tuple) and key[0] == "forward":
                forward_keys.append(key)
        assert len(forward_keys) == 2

    assert len(_caches) == 0
196
197
198
def test_select_pipes(frozen_ml_nlp: Pipeline):
    """``select_pipes(enable=...)`` disables the other pipe only inside the block."""
    sample = "Ceci est un exemple"

    with frozen_ml_nlp.select_pipes(enable=["transformer", "ner"]):
        assert len(frozen_ml_nlp.disabled) == 1
        doc = frozen_ml_nlp(sample)
        # No sentence boundaries are set while only these two pipes are enabled.
        assert not doc.has_annotation("SENT_START")

    # Leaving the context re-enables everything.
    assert len(frozen_ml_nlp.disabled) == 0
204
205
206
@pytest.mark.skip(reason="Deprecated behavior")
def test_different_names():
    """(Skipped) Adding a component under a name other than its own raises."""
    nlp = edsnlp.blank("eds")
    component = eds.sentences(nlp=nlp, name="custom_name")

    with pytest.raises(ValueError) as exc_info:
        nlp.add_pipe(component, name="sentences")

    expected_fragment = (
        "The provided name 'sentences' does not "
        "match the name of the component 'custom_name'."
    )
    assert expected_fragment in str(exc_info.value)
219
220
221
@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_load_config(run_in_test_dir):
    """Loading ``training/qlf_config.yml`` yields the four expected pipes, in order."""
    expected_names = ["normalizer", "sentencizer", "covid", "qualifier"]
    loaded = edsnlp.load("training/qlf_config.yml")
    assert loaded.pipe_names == expected_names
230
231
232
# Deliberately invalid config: `mode = "error-mode"` is the value that
# `test_config_validation_error` expects to see reported.
fail_config = """
[nlp]
lang = "eds"
pipeline = ["transformer", "ner"]

[nlp.tokenizer]
@tokenizers = "eds.tokenizer"

[components]

[components.transformer]
@factory = "eds.transformer"
model = "prajjwal1/bert-tiny"
window = 128
stride = 96

[components.ner]
@factory = "eds.ner_crf"
embedding = ${components.transformer}
mode = "error-mode"
span_setter = "ents"
"""
254
255
256
@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_config_validation_error():
    """The invalid ``mode`` in ``fail_config`` is reported as one validation error."""
    with pytest.raises(ConfitValidationError) as exc_info:
        Pipeline.from_config(Config.from_str(fail_config))

    message = str(exc_info.value)
    assert "1 validation error for" in message
    assert "got 'error-mode'" in message
263
264
265
@edsnlp.registry.factory.register("test_wrapper", spacy_compatible=False)
class WrapperComponent:
    """Factory registered as "test_wrapper"; accepts its arguments and ignores them."""

    def __init__(self, *, copy_list, copy_dict, sub):
        """Accept the three keyword-only arguments without storing anything."""
269
270
271
# YAML config whose nested `sub` component has an invalid `terms` value; the
# same sub-config is referenced from a list, a dict and a sibling component.
fail_config_sub = """
nlp:
    lang: "eds"
    components:
        wrapper:
            "@factory": "test_wrapper"

            copy_list:
                - ${nlp.components.wrapper.sub}

            copy_dict:
                key: ${nlp.components.wrapper.sub}

            sub:
                "@factory": "eds.matcher"
                terms: 100.0  # clearly wrong

        matcher_copy: ${nlp.components.wrapper.sub}
"""
290
291
292
def test_config_sub_validation_error():
    """The invalid nested ``terms`` raises; merging a valid value makes it load."""
    with pytest.raises(ConfitValidationError):
        Pipeline.from_config(Config.from_yaml_str(fail_config_sub))

    # Override only the offending value and check the config now loads.
    valid_terms = {"pattern": ["ok"]}
    fix = {"nlp": {"components": {"wrapper": {"sub": {"terms": valid_terms}}}}}
    patched = Config.from_yaml_str(fail_config_sub).merge(fix)
    Pipeline.from_config(patched)
298
299
300
def test_add_pipe_validation_error():
    """An unknown config key passed to ``add_pipe`` yields the exact error text."""
    nlp = edsnlp.blank("eds")

    with pytest.raises(ConfitValidationError) as exc_info:
        nlp.add_pipe("eds.covid", name="extractor", config={"foo": "bar"})

    expected_message = (
        "1 validation error for "
        "edsnlp.pipes.ner.covid.factory.create_component()\n"
        "-> extractor.foo\n"
        "   unexpected keyword argument"
    )
    assert str(exc_info.value) == expected_message
311
312
313
def test_spacy_component():
    """Adding the "sentencizer" pipe to a blank "fr" pipeline must not raise."""
    pipeline = edsnlp.blank("fr")
    pipeline.add_pipe("sentencizer")
316
317
318
def test_rule_based_pipeline():
    """Build a normalizer + covid pipeline and check lookup, errors and output."""
    nlp = edsnlp.blank("eds")
    for factory_name in ("eds.normalizer", "eds.covid"):
        nlp.add_pipe(factory_name)

    assert nlp.pipe_names == ["normalizer", "covid"]
    assert nlp.get_pipe("normalizer") == nlp.pipeline[0][1]
    assert nlp.has_pipe("covid")

    with pytest.raises(ValueError) as exc_info:
        nlp.get_pipe("unknown")
    assert str(exc_info.value) == "Pipe 'unknown' not found in pipeline."

    doc = nlp.make_doc("Mon patient a le covid")
    processed = nlp(doc)

    # The input document carries the entity and is the object that is returned.
    assert len(doc.ents) == 1
    assert processed is doc

    covid_meta = nlp.get_pipe_meta("covid")
    assert covid_meta.assigns == ["doc.ents", "doc.spans"]
340
341
342
def test_torch_save(ml_nlp):
    """A pipeline survives an in-memory ``torch.save`` / ``torch.load`` round trip."""
    import torch

    ml_nlp.get_pipe("ner").update_labels(["LOC", "PER"])

    stream = BytesIO()
    torch.save(ml_nlp, stream)
    stream.seek(0)  # rewind so the load starts at the beginning
    restored = torch.load(stream, weights_only=False)

    assert restored.get_pipe("ner").labels == ["LOC", "PER"]
    sentences = list(restored("Une phrase. Deux phrases.").sents)
    assert len(sentences) == 2
352
353
354
def test_parameters(frozen_ml_nlp):
    """The fixture pipeline is expected to expose exactly 42 parameters."""
    parameter_count = sum(1 for _ in frozen_ml_nlp.parameters())
    assert parameter_count == 42
356
357
358
def test_missing_factory(nlp):
    """An unknown factory name raises a ``ValueError`` that mentions that name."""
    unknown_factory = "__test_missing_pipe__"

    with pytest.raises(ValueError) as exc_info:
        nlp.add_pipe(unknown_factory)

    assert unknown_factory in str(exc_info.value)
363
364
365
@edsnlp.registry.factory("custom-curry-test")
class CustomComponent(BaseComponent):
    """Minimal component registered as "custom-curry-test".

    It only remembers the pipeline it was built with and returns documents
    unchanged.
    """

    def __init__(self, nlp, name):
        # `name` is accepted but not stored.
        self.nlp = nlp

    def __call__(self, doc):
        """Return ``doc`` untouched."""
        return doc
372
373
374
def test_curried_nlp_pipe():
    """A component built without ``nlp`` is a ``CurriedFactory`` until it is added.

    Using it before that raises a ``TypeError`` with an explanatory message;
    once added through ``add_pipe`` the real component holds the pipeline.
    """
    nlp = edsnlp.blank("eds")
    nlp.add_pipe(eds.sentences(name="my-sentences"))
    nlp.add_pipe(eds.normalizer())
    nlp.add_pipe(eds.sections(), name="sections")

    curried = CustomComponent()
    assert isinstance(curried, CurriedFactory)

    expected_error = (
        f"This component CurriedFactory({curried.factory}) has not been instantiated "
        "yet, likely because it was missing an `nlp` pipeline argument. You should "
        "either:\n"
        "- add it to a pipeline: `pipe = nlp.add_pipe(pipe)`\n"
        "- or fill its `nlp` argument: `pipe = factory(nlp=nlp, ...)`"
    )

    # Calling the curried object directly fails...
    with pytest.raises(TypeError) as call_error:
        curried("Demo texte")
    assert str(call_error.value) == expected_error

    # ...and so does going through `forward`. The attribute access is kept
    # inside the `raises` block, as it may itself be what raises.
    with pytest.raises(TypeError) as forward_error:
        curried.forward("Demo texte")
    assert str(forward_error.value) == expected_error

    nlp.add_pipe(curried, name="custom")

    assert nlp.pipes.custom.nlp is nlp
    assert nlp.pipe_names == ["my-sentences", "normalizer", "sections", "custom"]
402
403
404
@pytest.mark.skipif(
    sys.version_info < (3, 8),
    reason="Can't run on GH CI with Python 3.7",
)
@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_huggingface():
    """Load the ``AP-HP/dummy-ner`` model by name and check its entities.

    The ``dummy-pip-package`` distribution is uninstalled afterwards, whether
    or not the assertions pass, so a failing run does not leave it in the
    environment.
    """
    try:
        nlp = edsnlp.load(
            "AP-HP/dummy-ner",
            auto_update=True,
            install_dependencies=True,
        )
        doc = nlp("On lui prescrit du paracetamol à 500mg.")
        assert doc.ents[0].text == "paracetamol"
        assert doc.ents[1].text == "500mg"

        # Try loading it twice for coverage
        edsnlp.load(
            "AP-HP/dummy-ner",
            auto_update=True,
            install_dependencies=True,
        )
    finally:
        # Run pip through the current interpreter so the package is removed
        # from the environment the tests run in, not from whichever `pip`
        # happens to be first on PATH.
        subprocess.run(
            [sys.executable, "-m", "pip", "uninstall", "dummy-pip-package", "-y"],
            check=True,
        )
427
428
429
@pytest.mark.skipif(
    sys.version_info < (3, 8),
    reason="Can't run on GH CI with Python 3.7",
)
def test_missing_huggingface():
    """Loading ``AP-HP/does-not-exist`` raises a ``ValueError`` with the usage hint."""
    with pytest.raises(ValueError) as exc_info:
        edsnlp.load("AP-HP/does-not-exist", auto_update=True)

    message = str(exc_info.value)
    assert "The load function expects either :" in message
441
442
443
def test_repr(frozen_ml_nlp):
    """``repr`` lists every pipe and flags the disabled one with ``[disabled]``."""
    expected = """\
Pipeline(lang=eds, pipes={
  "sentences": [disabled] eds.sentences,
  "transformer": eds.transformer,
  "ner": eds.ner_crf
})"""

    with frozen_ml_nlp.select_pipes(disable=["sentences"]):
        assert repr(frozen_ml_nlp) == expected
454
455
456
@edsnlp.registry.factory.register("test_nlp_less", spacy_compatible=False)
class NlpLessComponent:
    """Component registered as "test_nlp_less" whose ``nlp`` argument is optional."""

    def __init__(self, nlp=None, name: str = "nlp_less", *, value: int):
        # `nlp` is accepted but not stored; only `name` and `value` are kept.
        self.name = name
        self.value = value

    def __call__(self, doc):
        """Return ``doc`` unchanged."""
        return doc
464
465
466
def test_nlp_less_component():
    """The component can be built directly or from a config, without a pipeline."""
    # Direct instantiation, no `nlp` given.
    direct = NlpLessComponent(value=42)
    assert direct.value == 42

    # Same component resolved from a config section.
    config_text = """
[component]
@factory = "test_nlp_less"
value = 42
"""
    resolved = Config.from_str(config_text).resolve(registry=registry)
    from_config = resolved["component"]
    assert from_config.value == 42