[cad161]: / tests / pipelines / trainable / test_extractive_qa.py

Download this file

102 lines (93 with data), 4.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
from spacy.tokens import Span
import edsnlp
import edsnlp.pipes as eds
def test_ner():
nlp = edsnlp.blank("eds")
nlp.add_pipe(
eds.extractive_qa(
embedding=eds.transformer(
model="prajjwal1/bert-tiny",
window=20,
stride=10,
),
# During training, where do we get the gold entities from ?
target_span_getter=["ner-gold"],
# Which prompts for each label ?
questions={
"PERSON": "Quels sont les personnages ?",
"GIFT": "Quels sont les cadeaux ?",
},
questions_attribute="question",
# During prediction, where do we set the predicted entities ?
span_setter="ents",
),
)
doc = nlp(
"L'aîné eut le Moulin, le second eut l'âne, et le plus jeune n'eut que le Chat."
)
doc._.question = {
"FAVORITE": ["Qui a eu de l'argent ?"],
}
# doc[0:2], doc[4:5], doc[6:8], doc[9:11], doc[13:16], doc[20:21]
doc.spans["ner-gold"] = [
Span(doc, 0, 2, "PERSON"), # L'aîné
Span(doc, 4, 5, "GIFT"), # Moulin
Span(doc, 6, 8, "PERSON"), # le second
Span(doc, 9, 11, "GIFT"), # l'âne
Span(doc, 13, 16, "PERSON"), # le plus jeune
Span(doc, 20, 21, "GIFT"), # Chat
]
nlp.post_init([doc])
ner = nlp.pipes.extractive_qa
batch = ner.prepare_batch([doc], supervision=True)
results = ner.module_forward(batch)
list(ner.pipe([doc]))[0]
assert results["loss"] is not None
trf_inputs = [
seq.replace(" [PAD]", "")
for seq in ner.embedding.tokenizer.batch_decode(batch["embedding"]["input_ids"])
]
assert trf_inputs == [
"[CLS] quels sont les cadeaux? [SEP] l'aine eut le moulin, le second eut l'ane, et [SEP]", # noqa: E501
"[CLS] quels sont les cadeaux? [SEP] le second eut l'ane, et le plus jeune n'eut que le [SEP]", # noqa: E501
"[CLS] quels sont les cadeaux? [SEP] le plus jeune n'eut que le chat. [SEP]", # noqa: E501
"[CLS] quels sont les personnages? [SEP] l'aine eut le moulin, le second eut l'ane, et [SEP]", # noqa: E501
"[CLS] quels sont les personnages? [SEP] le second eut l'ane, et le plus jeune n'eut que le [SEP]", # noqa: E501
"[CLS] quels sont les personnages? [SEP] le plus jeune n'eut que le chat. [SEP]", # noqa: E501
"[CLS] qui a eu de l'argent? [SEP] l'aine eut le moulin, le second eut l'ane, et [SEP]", # noqa: E501
"[CLS] qui a eu de l'argent? [SEP] le second eut l'ane, et le plus jeune n'eut que le [SEP]", # noqa: E501
"[CLS] qui a eu de l'argent? [SEP] le plus jeune n'eut que le chat. [SEP]", # noqa: E501
]
assert batch["targets"].squeeze(2).tolist() == [
[0, 0, 0, 0, 4, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0],
[2, 3, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 2, 1, 3, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
assert nlp.config.to_yaml_str() == (
"nlp:\n"
" lang: eds\n"
" pipeline:\n"
" - extractive_qa\n"
" tokenizer:\n"
" '@tokenizers': eds.tokenizer\n"
"components:\n"
" extractive_qa:\n"
" '@factory': eds.extractive_qa\n"
" embedding:\n"
" '@factory': eds.transformer\n"
" model: prajjwal1/bert-tiny\n"
" window: 20\n"
" stride: 10\n"
" questions:\n"
" PERSON: Quels sont les personnages ?\n"
" GIFT: Quels sont les cadeaux ?\n"
" questions_attribute: question\n"
" target_span_getter:\n"
" - ner-gold\n"
" span_setter:\n"
" ents: true\n"
" infer_span_setter: false\n"
" mode: joint\n"
" window: 40\n"
" stride: 20\n"
)