--- a +++ b/tests/pipelines/trainable/test_ner.py @@ -0,0 +1,69 @@ +import pytest +from pytest import mark +from spacy.tokens import Span + +import edsnlp + + +@mark.parametrize( + "ner_mode,window", + [ + ("independent", 1), + ("joint", 0), + ("joint", 5), + ("marginal", 0), + ], +) +def test_ner(ner_mode, window): + nlp = edsnlp.blank("eds") + nlp.add_pipe( + "eds.transformer", + name="transformer", + config=dict( + model="prajjwal1/bert-tiny", + window=128, + stride=96, + ), + ) + nlp.add_pipe( + "eds.ner_crf", + name="ner", + config=dict( + embedding=nlp.get_pipe("transformer"), + mode=ner_mode, + target_span_getter=["ents", "ner-preds"], + window=window, + ), + ) + nlp.pipes.ner.compute_confidence_score = True + ner = nlp.get_pipe("ner") + ner.update_labels([]) + doc = nlp( + "L'aîné eut le Moulin, le second eut l'âne, et le plus jeune n'eut que le Chat." + ) + ner.labels = ["LOC", "ORG"] + # doc[0:2], doc[4:5], doc[6:8], doc[9:11], doc[13:16], doc[20:21] + doc.ents = [ + Span(doc, 0, 2, "PERSON"), + Span(doc, 4, 5, "GIFT"), + Span(doc, 6, 8, "PERSON"), + Span(doc, 9, 11, "GIFT"), + Span(doc, 13, 16, "PERSON"), + Span(doc, 20, 21, "GIFT"), + ] + with pytest.warns() as record: + nlp.post_init([doc]) + assert len(record) == 1 + assert record[0].message.args[0] == ( + "The labels inferred from the data are different from the labels passed to " + "the component. Differing labels are ['GIFT', 'LOC', 'ORG', 'PERSON']" + ) + + ner = nlp.get_pipe("ner") + ner.update_labels(["PERSON", "GIFT"]) + batch = ner.prepare_batch([doc], supervision=True) + batch = ner(batch) + + list(ner.pipe([doc])) + + assert batch["loss"] is not None