tests/test_pipeline.py

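# Unit tests for the edsnlp Pipeline API: adding and retrieving pipes,
# disk / torch serialization, config parsing and validation, torch-component
# handling, and loading models from the Hugging Face hub.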
import os
import subprocess
import sys
from io import BytesIO

import pytest
from confit import Config
from confit.errors import ConfitValidationError
from confit.registry import validate_arguments
from spacy.tokens import Doc

import edsnlp
import edsnlp.pipes as eds
from edsnlp import Pipeline, registry
from edsnlp.core.registries import CurriedFactory
from edsnlp.pipes.base import BaseComponent

try:
    import torch.nn
except ImportError:
    torch = None


class CustomClass:
    def __call__(self, doc: Doc) -> Doc:
        return doc


def test_add_pipe_factory():
    model = edsnlp.blank("eds")
    model.add_pipe("eds.normalizer", name="normalizer")
    assert "normalizer" in model.pipe_names
    assert model.has_pipe("normalizer")

    model.add_pipe("eds.sentences", name="sentences")
    assert "sentences" in model.pipe_names
    assert model.has_pipe("sentences")

    with pytest.raises(ValueError):
        model.get_pipe("missing-pipe")


def test_add_pipe_component():
    model = edsnlp.blank("eds")
    model.add_pipe(eds.normalizer(nlp=model), name="normalizer")
    assert "normalizer" in model.pipe_names
    assert model.has_pipe("normalizer")
    assert model.pipes.normalizer is model.get_pipe("normalizer")

    model.add_pipe(eds.sentences(nlp=model), name="sentences")
    assert "sentences" in model.pipe_names
    assert model.has_pipe("sentences")
    assert model.pipes.sentences is model.get_pipe("sentences")

    # Passing a config together with an already instantiated component is rejected
    with pytest.raises(ValueError):
        model.add_pipe(
            eds.sentences(nlp=model, name="sentences"),
            config={"punct_chars": ".?!"},
        )

    # So is adding an object that is not an edsnlp/spacy component
    with pytest.raises(ValueError):
        model.add_pipe(CustomClass())


def test_sequence(frozen_ml_nlp: Pipeline):
    assert len(frozen_ml_nlp.pipeline) == 3
    assert list(frozen_ml_nlp.pipeline) == [
        ("sentences", frozen_ml_nlp.get_pipe("sentences")),
        ("transformer", frozen_ml_nlp.get_pipe("transformer")),
        ("ner", frozen_ml_nlp.get_pipe("ner")),
    ]
    assert list(frozen_ml_nlp.torch_components()) == [
        ("transformer", frozen_ml_nlp.get_pipe("transformer")),
        ("ner", frozen_ml_nlp.get_pipe("ner")),
    ]


def test_disk_serialization(tmp_path, ml_nlp):
    nlp = ml_nlp

    assert nlp.get_pipe("transformer").stride == 96
    ner = nlp.get_pipe("ner")
    ner.update_labels(["PERSON", "GIFT"])

    os.makedirs(tmp_path / "model", exist_ok=True)
    # by default, vocab is excluded
    nlp.to_disk(tmp_path / "model", exclude=set())

    assert (tmp_path / "model" / "config.cfg").exists()
    assert (tmp_path / "model" / "ner" / "parameters.safetensors").exists()
    assert (tmp_path / "model" / "transformer" / "parameters.safetensors").exists()
    # fmt: off
    assert (
        (tmp_path / "model" / "transformer" / "pytorch_model.bin").exists() or
        (tmp_path / "model" / "transformer" / "model.safetensors").exists()
    )
    # fmt: on

    # The exported config should match config_str (defined below), minus the
    # interpolated components line and with the HF model name replaced by the
    # local "./transformer" directory
    assert (tmp_path / "model" / "config.cfg").read_text() == (
        config_str.replace("components = ${components}\n", "").replace(
            "prajjwal1/bert-tiny", "./transformer"
        )
    )

    nlp = edsnlp.load(
        tmp_path / "model",
        overrides={"components": {"transformer": {"stride": 64}}},
    )
    assert nlp.get_pipe("ner").labels == ["PERSON", "GIFT"]
    assert nlp.get_pipe("transformer").stride == 64


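# Reference config for the ML test pipeline: test_validate_config resolves it,
# and test_disk_serialization compares it (after substitutions) with the
# config.cfg written by nlp.to_disk.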
config_str = """\
[nlp]
lang = "eds"
pipeline = ["sentences", "transformer", "ner"]
components = ${components}

[nlp.tokenizer]
@tokenizers = "eds.tokenizer"

[components]

[components.sentences]
@factory = "eds.sentences"

[components.transformer]
@factory = "eds.transformer"
model = "prajjwal1/bert-tiny"
window = 128
stride = 96

[components.ner]
@factory = "eds.ner_crf"
embedding = ${components.transformer}
mode = "independent"
target_span_getter = ["ents", "ner-preds"]
labels = ["PERSON", "GIFT"]
infer_span_setter = false
window = 40
stride = 20

[components.ner.span_setter]
ents = true

"""


@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_validate_config():
    @validate_arguments
    def function(model: Pipeline):
        assert len(model.pipe_names) == 3

    function(Config.from_str(config_str).resolve(registry=registry)["nlp"])


def test_torch_module(frozen_ml_nlp: Pipeline):
    with frozen_ml_nlp.train(True):
        for name, component in frozen_ml_nlp.torch_components():
            assert component.training is True

    with frozen_ml_nlp.train(False):
        for name, component in frozen_ml_nlp.torch_components():
            assert component.training is False

    frozen_ml_nlp.to("cpu")


def test_cache(frozen_ml_nlp: Pipeline):
    from edsnlp.core.torch_component import _caches

    text = "Ceci est un exemple"
    frozen_ml_nlp(text)

    doc = frozen_ml_nlp.make_doc(text)
    with frozen_ml_nlp.cache():
        for name, pipe in frozen_ml_nlp.pipeline:
            # This is a hack to get around the ambiguity
            # between the __call__ method of Pytorch modules
            # and the __call__ methods of spacy components
            if hasattr(pipe, "batch_process"):
                doc = next(iter(pipe.batch_process([doc])))
            else:
                doc = pipe(doc)
        trf_forward_cache_entries = [
            key
            for key in _caches["default"]
            if isinstance(key, tuple) and key[0] == "forward"
        ]
        assert len(trf_forward_cache_entries) == 2

    # The shared cache is cleared when leaving the cache() context
    assert len(_caches) == 0


def test_select_pipes(frozen_ml_nlp: Pipeline):
    text = "Ceci est un exemple"
    with frozen_ml_nlp.select_pipes(enable=["transformer", "ner"]):
        assert len(frozen_ml_nlp.disabled) == 1
        assert not frozen_ml_nlp(text).has_annotation("SENT_START")
    assert len(frozen_ml_nlp.disabled) == 0


@pytest.mark.skip(reason="Deprecated behavior")
def test_different_names():
    nlp = edsnlp.blank("eds")

    extractor = eds.sentences(nlp=nlp, name="custom_name")

    with pytest.raises(ValueError) as exc_info:
        nlp.add_pipe(extractor, name="sentences")

    assert (
        "The provided name 'sentences' does not "
        "match the name of the component 'custom_name'."
    ) in str(exc_info.value)


@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_load_config(run_in_test_dir):
    nlp = edsnlp.load("training/qlf_config.yml")
    assert nlp.pipe_names == [
        "normalizer",
        "sentencizer",
        "covid",
        "qualifier",
    ]


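# Config with an invalid `mode` value for eds.ner_crf, used below to check
# that Pipeline.from_config reports a single, readable validation error.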
fail_config = """
[nlp]
lang = "eds"
pipeline = ["transformer", "ner"]

[nlp.tokenizer]
@tokenizers = "eds.tokenizer"

[components]

[components.transformer]
@factory = "eds.transformer"
model = "prajjwal1/bert-tiny"
window = 128
stride = 96

[components.ner]
@factory = "eds.ner_crf"
embedding = ${components.transformer}
mode = "error-mode"
span_setter = "ents"
"""


@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_config_validation_error():
    with pytest.raises(ConfitValidationError) as e:
        Pipeline.from_config(Config.from_str(fail_config))

    assert "1 validation error for" in str(e.value)
    assert "got 'error-mode'" in str(e.value)


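# Dummy factory that receives the same sub-component several times (directly
# and through list/dict interpolation) in the YAML config below.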
@edsnlp.registry.factory.register("test_wrapper", spacy_compatible=False)
class WrapperComponent:
    def __init__(self, *, copy_list, copy_dict, sub):
        pass


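# YAML config in which the same invalid sub-component (terms: 100.0) is
# interpolated in several places; validation must still fail, and merging a
# fix for the single faulty value must make the config resolvable.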
fail_config_sub = """
nlp:
  lang: "eds"
  components:
    wrapper:
      "@factory": "test_wrapper"

      copy_list:
        - ${nlp.components.wrapper.sub}

      copy_dict:
        key: ${nlp.components.wrapper.sub}

      sub:
        "@factory": "eds.matcher"
        terms: 100.0  # clearly wrong

matcher_copy: ${nlp.components.wrapper.sub}
"""


def test_config_sub_validation_error():
    with pytest.raises(ConfitValidationError):
        Pipeline.from_config(Config.from_yaml_str(fail_config_sub))

    fix = {"nlp": {"components": {"wrapper": {"sub": {"terms": {"pattern": ["ok"]}}}}}}
    Pipeline.from_config(Config.from_yaml_str(fail_config_sub).merge(fix))


def test_add_pipe_validation_error():
    model = edsnlp.blank("eds")
    with pytest.raises(ConfitValidationError) as e:
        model.add_pipe("eds.covid", name="extractor", config={"foo": "bar"})

    assert str(e.value) == (
        "1 validation error for "
        "edsnlp.pipes.ner.covid.factory.create_component()\n"
        "-> extractor.foo\n"
        "   unexpected keyword argument"
    )


def test_spacy_component():
    nlp = edsnlp.blank("fr")
    nlp.add_pipe("sentencizer")


def test_rule_based_pipeline():
    nlp = edsnlp.blank("eds")
    nlp.add_pipe("eds.normalizer")
    nlp.add_pipe("eds.covid")

    assert nlp.pipe_names == ["normalizer", "covid"]
    assert nlp.get_pipe("normalizer") == nlp.pipeline[0][1]
    assert nlp.has_pipe("covid")

    with pytest.raises(ValueError) as exc_info:
        nlp.get_pipe("unknown")

    assert str(exc_info.value) == "Pipe 'unknown' not found in pipeline."

    doc = nlp.make_doc("Mon patient a le covid")

    new_doc = nlp(doc)

    # The pipeline annotates the document in place
    assert len(doc.ents) == 1
    assert new_doc is doc

    assert nlp.get_pipe_meta("covid").assigns == ["doc.ents", "doc.spans"]


def test_torch_save(ml_nlp):
    import torch

    ml_nlp.get_pipe("ner").update_labels(["LOC", "PER"])
    buffer = BytesIO()
    torch.save(ml_nlp, buffer)
    buffer.seek(0)
    # weights_only=False is required to unpickle the full Pipeline object
    nlp = torch.load(buffer, weights_only=False)
    assert nlp.get_pipe("ner").labels == ["LOC", "PER"]
    assert len(list(nlp("Une phrase. Deux phrases.").sents)) == 2


def test_parameters(frozen_ml_nlp):
    assert len(list(frozen_ml_nlp.parameters())) == 42


def test_missing_factory(nlp):
    with pytest.raises(ValueError) as exc_info:
        nlp.add_pipe("__test_missing_pipe__")

    assert "__test_missing_pipe__" in str(exc_info.value)


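# Factory used to test "curried" instantiation: calling the registered class
# without an `nlp` argument returns a CurriedFactory, which is only turned
# into a real component once it is added to a pipeline.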
@edsnlp.registry.factory("custom-curry-test")
class CustomComponent(BaseComponent):
    def __init__(self, nlp, name):
        self.nlp = nlp

    def __call__(self, doc):
        return doc


def test_curried_nlp_pipe():
    nlp = edsnlp.blank("eds")
    nlp.add_pipe(eds.sentences(name="my-sentences"))
    nlp.add_pipe(eds.normalizer())
    nlp.add_pipe(eds.sections(), name="sections")
    pipe = CustomComponent()

    assert isinstance(pipe, CurriedFactory)
    err = (
        f"This component CurriedFactory({pipe.factory}) has not been instantiated "
        f"yet, likely because it was missing an `nlp` pipeline argument. You should "
        f"either:\n"
        "- add it to a pipeline: `pipe = nlp.add_pipe(pipe)`\n"
        "- or fill its `nlp` argument: `pipe = factory(nlp=nlp, ...)`"
    )
    # A curried component cannot be called before it is attached to a pipeline
    with pytest.raises(TypeError) as exc_info:
        pipe("Demo texte")
    assert str(exc_info.value) == err

    with pytest.raises(TypeError) as exc_info:
        pipe.forward("Demo texte")
    assert str(exc_info.value) == err

    nlp.add_pipe(pipe, name="custom")

    assert nlp.pipes.custom.nlp is nlp

    assert nlp.pipe_names == ["my-sentences", "normalizer", "sections", "custom"]


@pytest.mark.skipif(
    sys.version_info < (3, 8),
    reason="Can't run on GH CI with Python 3.7",
)
@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_huggingface():
    nlp = edsnlp.load(
        "AP-HP/dummy-ner",
        auto_update=True,
        install_dependencies=True,
    )
    doc = nlp("On lui prescrit du paracetamol à 500mg.")
    assert doc.ents[0].text == "paracetamol"
    assert doc.ents[1].text == "500mg"

    # Try loading it twice for coverage
    edsnlp.load(
        "AP-HP/dummy-ner",
        auto_update=True,
        install_dependencies=True,
    )

    # Clean up the dependency installed via install_dependencies=True
    subprocess.run(["pip", "uninstall", "dummy-pip-package", "-y"], check=True)


@pytest.mark.skipif(
    sys.version_info < (3, 8),
    reason="Can't run on GH CI with Python 3.7",
)
def test_missing_huggingface():
    with pytest.raises(ValueError) as exc_info:
        edsnlp.load(
            "AP-HP/does-not-exist",
            auto_update=True,
        )

    assert "The load function expects either :" in str(exc_info.value)


def test_repr(frozen_ml_nlp):
    with frozen_ml_nlp.select_pipes(disable=["sentences"]):
        assert (
            repr(frozen_ml_nlp)
            == """\
Pipeline(lang=eds, pipes={
  "sentences": [disabled] eds.sentences,
  "transformer": eds.transformer,
  "ner": eds.ner_crf
})"""
        )


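# Factory whose component does not need the enclosing `nlp` object: it can be
# instantiated directly or resolved from a config section on its own.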
@edsnlp.registry.factory.register("test_nlp_less", spacy_compatible=False)
class NlpLessComponent:
    def __init__(self, nlp=None, name: str = "nlp_less", *, value: int):
        self.value = value
        self.name = name

    def __call__(self, doc):
        return doc


def test_nlp_less_component():
    component = NlpLessComponent(value=42)
    assert component.value == 42

    config = """
[component]
@factory = "test_nlp_less"
value = 42
"""
    component = Config.from_str(config).resolve(registry=registry)["component"]
    assert component.value == 42