Switch to unified view

a b/tests/pipelines/misc/test_quantities.py
1
from itertools import chain
2
3
import pytest
4
import spacy
5
from pytest import fixture, raises
6
from spacy.tokens.span import Span
7
8
from edsnlp.core import PipelineProtocol
9
from edsnlp.pipelines.misc.quantities import QuantitiesMatcher
10
11
text = (
12
    "Le patient fait 1 m 50 kg. La tumeur fait 2.0cm x 3cm. \n"
13
    "Une autre tumeur plus petite fait 2 par 1mm.\n"
14
    "Les trois éléments font 8, 13 et 15dm.\n"
15
    """
16
    Leucocytes ¦mm ¦ ¦4.2 ¦ ¦4.0-10.0
17
    Hémoglobine ¦ ¦9.0 - ¦ g ¦13-14
18
    Hémoglobine ¦ ¦9.0 - ¦ ¦ xxx
19
    """
20
)
21
22
23
@fixture
24
def blank_nlp():
25
    model = spacy.blank("eds")
26
    model.add_pipe("eds.normalizer")
27
    model.add_pipe("eds.sentences")
28
    model.add_pipe("eds.tables")
29
30
    return model
31
32
33
@fixture
34
def matcher(blank_nlp):
35
    matcher = QuantitiesMatcher(blank_nlp, extract_ranges=True, use_tables=True)
36
    return matcher
37
38
39
def test_deprecated_pipe(blank_nlp: PipelineProtocol):
40
    blank_nlp.add_pipe("matcher", config=dict(terms={"patient": "patient"}))
41
    blank_nlp.add_pipe(
42
        "eds.measurements",
43
    )
44
45
    doc = blank_nlp(text)
46
47
    assert len(doc.ents) == 1
48
49
    assert len(doc.spans["quantities"]) == 15
50
    assert len(doc.spans["measurements"]) == 15
51
52
53
def test_deprecated_arg(blank_nlp: PipelineProtocol):
54
    blank_nlp.add_pipe("matcher", config=dict(terms={"patient": "patient"}))
55
    blank_nlp.add_pipe(
56
        "eds.measurements", config=dict(measurements=["size", "weight", "bmi"])
57
    )
58
59
    doc = blank_nlp(text)
60
61
    assert len(doc.ents) == 1
62
63
    assert len(doc.spans["quantities"]) == 15
64
    assert len(doc.spans["measurements"]) == 15
65
66
67
def test_default_factory(blank_nlp: PipelineProtocol):
68
    blank_nlp.add_pipe("matcher", config=dict(terms={"patient": "patient"}))
69
    blank_nlp.add_pipe(
70
        "eds.quantities",
71
        config={"quantities": ["size", "weight", "bmi"], "use_tables": True},
72
    )
73
74
    doc = blank_nlp(text)
75
76
    assert len(doc.ents) == 1
77
78
    assert len(doc.spans["quantities"]) == 15
79
80
81
def test_quantities_component(
82
    blank_nlp: PipelineProtocol,
83
    matcher: QuantitiesMatcher,
84
):
85
    doc = blank_nlp(text)
86
87
    with raises(KeyError):
88
        doc.spans["quantities"]
89
90
    doc = matcher(doc)
91
92
    for span_key in ["quantities", "measurements"]:
93
        m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13 = doc.spans[span_key]
94
95
        assert str(m1._.value) == "1 m"
96
        assert str(m2._.value) == "50 kg"
97
        assert str(m3._.value) == "2.0 cm"
98
        assert str(m4._.value) == "3 cm"
99
        assert str(m5._.value) == "2 mm"
100
        assert str(m6._.value) == "1 mm"
101
        assert str(m7._.value) == "8 dm"
102
        assert str(m8._.value) == "13 dm"
103
        assert str(m9._.value) == "15 dm"
104
        assert str(m10._.value) == "4.2 mm"
105
        assert str(m11._.value) == "4.0-10.0 mm"
106
        assert str(m12._.value) == "9.0 g"
107
        assert str(m13._.value) == "13-14 g"
108
109
110
def test_quantities_component_scaling(
111
    blank_nlp: PipelineProtocol,
112
    matcher: QuantitiesMatcher,
113
):
114
    doc = blank_nlp(text)
115
116
    with raises(KeyError):
117
        doc.spans["quantities"]
118
119
    doc = matcher(doc)
120
121
    m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13 = doc.spans["quantities"]
122
123
    assert abs(m1._.value.cm - 100) < 1e-6
124
    assert abs(m2._.value.mg - 50000000.0) < 1e-6
125
    assert abs(m3._.value.mm - 20) < 1e-6
126
    assert abs(m4._.value.mm - 30) < 1e-6
127
    assert abs(m5._.value.cm - 0.2) < 1e-6
128
    assert abs(m6._.value.cm - 0.1) < 1e-6
129
    assert abs(m7._.value.dm - 8.0) < 1e-6
130
    assert abs(m8._.value.m - 1.3) < 1e-6
131
    assert abs(m9._.value.m - 1.5) < 1e-6
132
    assert abs(m10._.value.mm - 4.2) < 1e-6
133
    assert abs(m11._.value.mm[0] - 4.0) < 1e-6
134
    assert abs(m11._.value.mm[1] - 10.0) < 1e-6
135
    assert abs(m12._.value.g - 9) < 1e-6
136
    assert abs(m13._.value.g[0] - 13.0) < 1e-6
137
    assert abs(m13._.value.g[1] - 14.0) < 1e-6
138
139
140
def test_measure_label(
141
    blank_nlp: PipelineProtocol,
142
    matcher: QuantitiesMatcher,
143
):
144
    doc = blank_nlp(text)
145
    doc = matcher(doc)
146
147
    m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13 = doc.spans["quantities"]
148
149
    assert m1.label_ == "size"
150
    assert m2.label_ == "weight"
151
    assert m3.label_ == "size"
152
    assert m4.label_ == "size"
153
    assert m5.label_ == "size"
154
    assert m6.label_ == "size"
155
    assert m7.label_ == "size"
156
    assert m8.label_ == "size"
157
    assert m9.label_ == "size"
158
    assert m10.label_ == "size"
159
    assert m11.label_ == "size"
160
    assert m12.label_ == "weight"
161
    assert m13.label_ == "weight"
162
163
164
def test_quantities_all_input(
165
    blank_nlp: PipelineProtocol,
166
    matcher: QuantitiesMatcher,
167
):
168
    all_text = "On mesure 13 mol/ml de ..." "On compte 16x10*9 ..."
169
    blank_nlp.add_pipe(
170
        "eds.quantities",
171
        config={"quantities": "all", "extract_ranges": True},
172
    )
173
174
    doc = blank_nlp(all_text)
175
176
    m1, m2 = doc.spans["quantities"]
177
178
    assert str(m1._.value) == "13 mol_per_ml"
179
    assert str(m2._.value) == "16 x10*9"
180
181
182
def test_measure_str(
183
    blank_nlp: PipelineProtocol,
184
    matcher: QuantitiesMatcher,
185
):
186
    for text, res in [
187
        ("1m50", "1.5 m"),
188
        ("1,50cm", "1.5 cm"),
189
    ]:
190
        doc = blank_nlp(text)
191
        doc = matcher(doc)
192
193
        assert str(doc.spans["quantities"][0]._.value) == res
194
195
196
def test_measure_repr(
197
    blank_nlp: PipelineProtocol,
198
    matcher: QuantitiesMatcher,
199
):
200
    for text, res in [
201
        (
202
            "1m50",
203
            "Quantity(1.5, 'm')",
204
        ),
205
        (
206
            "1,50cm",
207
            "Quantity(1.5, 'cm')",
208
        ),
209
    ]:
210
        doc = blank_nlp(text)
211
        doc = matcher(doc)
212
213
        assert repr(doc.spans["quantities"][0]._.value) == res
214
215
216
def test_compare(
217
    blank_nlp: PipelineProtocol,
218
    matcher: QuantitiesMatcher,
219
):
220
    m1, m2 = "1m0", "120cm"
221
    m1 = matcher(blank_nlp(m1)).spans["quantities"][0]
222
    m2 = matcher(blank_nlp(m2)).spans["quantities"][0]
223
    assert m1._.value <= m2._.value
224
    assert m2._.value > m1._.value
225
226
    m3 = "Entre deux et trois metres"
227
    m4 = "De 2 à 3 metres"
228
    m3 = matcher(blank_nlp(m3)).spans["quantities"][0]
229
    m4 = matcher(blank_nlp(m4)).spans["quantities"][0]
230
    assert str(m3._.value) == "2-3 m"
231
    assert str(m4._.value) == "2-3 m"
232
    assert m4._.value.cm == (200.0, 300.0)
233
234
    assert m3._.value == m4._.value
235
    assert m3._.value <= m4._.value
236
    assert m3._.value >= m1._.value
237
238
    assert max(list(chain(m1._.value, m2._.value, m3._.value, m4._.value))).cm == 300
239
240
241
def test_unitless(
242
    blank_nlp: PipelineProtocol,
243
    matcher: QuantitiesMatcher,
244
):
245
    for text, res in [
246
        ("BMI: 24 .", "24 kg_per_m2"),
247
        ("Le patient mesure 1.5 ", "1.5 m"),
248
        ("Le patient mesure 152 ", "152 cm"),
249
        ("Le patient pèse 34 ", "34 kg"),
250
    ]:
251
        doc = blank_nlp(text)
252
        doc = matcher(doc)
253
254
        assert str(doc.spans["quantities"][0]._.value) == res
255
256
257
def test_non_matches(
258
    blank_nlp: PipelineProtocol,
259
    matcher: QuantitiesMatcher,
260
):
261
    for text in [
262
        "On délivre à 10 g / h.",
263
        "Le patient grandit de 10 cm par jour ",
264
        "Truc 10cma truc",
265
        "01.42.43.56.78 m",
266
    ]:
267
        doc = blank_nlp(text)
268
        doc = matcher(doc)
269
270
        assert len(doc.spans["quantities"]) == 0
271
272
273
def test_numbers(
274
    blank_nlp: PipelineProtocol,
275
    matcher: QuantitiesMatcher,
276
):
277
    for text, res in [
278
        ("deux m", "2 m"),
279
        ("2 m", "2 m"),
280
        ("⅛ m", "0.125 m"),
281
        ("0 m", "0 m"),
282
        ("55 @ 77777 cm", "77777 cm"),
283
    ]:
284
        doc = blank_nlp(text)
285
        doc = matcher(doc)
286
287
        assert str(doc.spans["quantities"][0]._.value) == res
288
289
290
def test_ranges(
291
    blank_nlp: PipelineProtocol,
292
    matcher: QuantitiesMatcher,
293
):
294
    for text, res, snippet in [
295
        ("Le patient fait entre 1 et 2m", "1-2 m", "entre 1 et 2m"),
296
        ("On mesure de 2 à 2.5 dl d'eau", "2-2.5 dl", "de 2 à 2.5 dl"),
297
    ]:
298
        doc = blank_nlp(text)
299
        doc = matcher(doc)
300
301
        quantity = doc.spans["quantities"][0]
302
        assert str(quantity._.value) == res
303
        assert quantity.text == snippet
304
305
306
def test_merge_align(blank_nlp, matcher):
307
    matcher.merge_mode = "align"
308
    matcher.span_getter = {"candidates": True}
309
    matcher.span_setter = {"ents": True}
310
    doc = blank_nlp(text)
311
    ent = Span(doc, 10, 15, label="size")
312
    doc.spans["candidates"] = [ent]
313
    doc = matcher(doc)
314
315
    assert len(doc.ents) == 1
316
    assert str(ent._.value) == "2.0 cm"
317
318
319
def test_merge_intersect(blank_nlp, matcher: QuantitiesMatcher):
320
    matcher.merge_mode = "intersect"
321
    matcher.span_setter = {**matcher.span_setter, "ents": True}
322
    matcher.span_getter = {"lookup_zones": True}
323
    doc = blank_nlp(text)
324
    ent = Span(doc, 10, 16, label="size")
325
    doc.spans["lookup_zones"] = [ent]
326
    doc = matcher(doc)
327
328
    assert len(doc.ents) == 2
329
    assert len(doc.spans["quantities"]) == 2
330
    assert [doc.ents[0].text, doc.ents[1].text] == ["2.0cm", "3cm"]
331
    assert [doc.ents[0]._.value.cm, doc.ents[1]._.value.cm] == [2.0, 3]
332
333
334
def test_quantity_snippets(blank_nlp, matcher: QuantitiesMatcher):
335
    for text, result in [
336
        ("0.50g", ["0.5 g"]),
337
        ("0.050g", ["0.05 g"]),
338
        ("1 m 50", ["1.5 m"]),
339
        ("1.50 m", ["1.5 m"]),
340
        ("1,50m", ["1.5 m"]),
341
        ("2.0cm x 3cm", ["2.0 cm", "3 cm"]),
342
        ("2 par 1mm", ["2 mm", "1 mm"]),
343
        ("8, 13 et 15dm", ["8 dm", "13 dm", "15 dm"]),
344
        ("1 / 50  kg", ["0.02 kg"]),
345
    ]:
346
        doc = blank_nlp(text)
347
        doc = matcher(doc)
348
349
        assert [str(span._.value) for span in doc.spans["quantities"]] == result
350
351
352
def test_error_management(blank_nlp, matcher: QuantitiesMatcher):
353
    text = """
354
        Leucocytes ¦ ¦ ¦4.2 ¦ ¦4.0-10.0
355
        Hémoglobine ¦ ¦9.0 - ¦ ¦13-14
356
        """
357
    doc = blank_nlp(text)
358
    doc = matcher(doc)
359
360
    assert len(doc.spans["quantities"]) == 0
361
362
363
def test_conversions(blank_nlp, matcher: QuantitiesMatcher):
364
    tests = [
365
        ("20 dm3", "l", 20),
366
        ("20 dm3", "m3", 0.02),
367
        ("20 dm3", "cm3", 20000),
368
        ("10 l", "cm3", 10000),
369
        ("10 l", "cl", 1000),
370
        ("25 kg/m2", "kg_per_cm2", 0.0025),
371
        ("2.4 x10*9µl", "l", 2400),
372
    ]
373
374
    for text, unit, expected in tests:
375
        doc = blank_nlp(text)
376
        doc = matcher(doc)
377
        result = getattr(doc.spans["quantities"][0]._.value, unit)
378
        assert result == pytest.approx(
379
            expected, 1e-6
380
        ), f"{result} != {expected} for {text} in {unit}"