[cad161]: / tests / pipelines / trainable / test_ner.py

Download this file

70 lines (63 with data), 1.8 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import pytest
from pytest import mark
from spacy.tokens import Span
import edsnlp
@mark.parametrize(
    "ner_mode,window",
    [
        ("independent", 1),
        ("joint", 0),
        ("joint", 5),
        ("marginal", 0),
    ],
)
def test_ner(ner_mode, window):
    """Exercise the ``eds.ner_crf`` pipe for one ``(ner_mode, window)`` pair.

    The test assembles a blank ``"eds"`` pipeline made of an
    ``eds.transformer`` pipe and an ``eds.ner_crf`` pipe that reuses it as
    its embedding, then checks three things:

    1. the pipeline can process a raw text;
    2. ``nlp.post_init`` emits exactly one warning, with the expected
       message, when the labels set on the component (``LOC``/``ORG``)
       differ from those found in the annotated document
       (``PERSON``/``GIFT``);
    3. after the labels are updated, a supervised batch can be prepared and
       run through the component, and the output carries a ``"loss"``
       entry that is not ``None``.

    Parameters
    ----------
    ner_mode
        Value forwarded as the ``mode`` option of ``eds.ner_crf``.
    window
        Value forwarded as the ``window`` option of ``eds.ner_crf``.
    """
    nlp = edsnlp.blank("eds")

    # Embedding pipe.
    transformer_config = {
        "model": "prajjwal1/bert-tiny",
        "window": 128,
        "stride": 96,
    }
    nlp.add_pipe("eds.transformer", name="transformer", config=transformer_config)

    # NER pipe stacked on top of the transformer registered just above.
    ner_config = {
        "embedding": nlp.get_pipe("transformer"),
        "mode": ner_mode,
        "target_span_getter": ["ents", "ner-preds"],
        "window": window,
    }
    nlp.add_pipe("eds.ner_crf", name="ner", config=ner_config)
    nlp.pipes.ner.compute_confidence_score = True

    ner = nlp.get_pipe("ner")
    ner.update_labels([])

    text = (
        "L'aîné eut le Moulin, le second eut l'âne, et le plus jeune n'eut que le Chat."
    )
    doc = nlp(text)

    # Labels deliberately disjoint from the ones annotated below, so that
    # post_init has a mismatch to report.
    ner.labels = ["LOC", "ORG"]

    # (start token, end token, label) for
    # doc[0:2], doc[4:5], doc[6:8], doc[9:11], doc[13:16], doc[20:21]
    annotated_spans = (
        (0, 2, "PERSON"),
        (4, 5, "GIFT"),
        (6, 8, "PERSON"),
        (9, 11, "GIFT"),
        (13, 16, "PERSON"),
        (20, 21, "GIFT"),
    )
    doc.ents = [Span(doc, start, end, label) for start, end, label in annotated_spans]

    expected_warning = (
        "The labels inferred from the data are different from the labels passed to "
        "the component. Differing labels are ['GIFT', 'LOC', 'ORG', 'PERSON']"
    )
    with pytest.warns() as caught:
        nlp.post_init([doc])
    assert len(caught) == 1
    assert caught[0].message.args[0] == expected_warning

    ner = nlp.get_pipe("ner")
    ner.update_labels(["PERSON", "GIFT"])

    # Forward pass on a supervised batch, followed by a plain inference run.
    prepared = ner.prepare_batch([doc], supervision=True)
    output = ner(prepared)
    list(ner.pipe([doc]))

    assert output["loss"] is not None