# tests/labelers/test_CodeLabelers.py
import datetime
import pathlib
from typing import List, Set

# Needed to import `tools` for local testing
from femr_test_tools import EventsWithLabels, run_test_for_labeler

from femr.labelers import TimeHorizon
from femr.labelers.omop import (
    AKICodeLabeler,
    AnemiaCodeLabeler,
    CodeLabeler,
    HyperkalemiaCodeLabeler,
    HypoglycemiaCodeLabeler,
    HyponatremiaCodeLabeler,
    LupusCodeLabeler,
    MortalityCodeLabeler,
    NeutropeniaCodeLabeler,
    OMOPConceptCodeLabeler,
    ThrombocytopeniaCodeLabeler,
)
22
23
#############################################
24
#############################################
25
#
26
# Generic CodeLabeler
27
#
28
#############################################
29
#############################################
30
31
32
def test_outcome_codes(tmp_path: pathlib.Path):
    """Exercise CodeLabeler with one, zero, and several outcome codes.

    No ``prediction_codes`` argument is passed in any scenario. Each expected
    label is a bool or one of the marker strings ("duplicate", "skip",
    "out of range") handed through to ``run_test_for_labeler``.
    """
    horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=10))

    # Scenario 1: a single outcome code, "2".
    single_outcome_events: EventsWithLabels = [
        (((2015, 1, 3), 2, None), "duplicate"),
        (((2015, 1, 3), 4, None), "duplicate"),
        (((2015, 1, 3), 1, None), "duplicate"),
        (((2015, 1, 3), 3, None), "skip"),
        (((2015, 10, 5), 1, None), False),
        (((2018, 1, 3), 2, None), "skip"),
        (((2018, 3, 1), 4, None), False),
        (((2018, 3, 3), 1, None), False),
        (((2018, 5, 2), 5, None), True),
        (((2018, 5, 3), 2, None), "skip"),
        (((2018, 5, 3, 11), 1, None), False),
        (((2018, 5, 4), 1, None), "duplicate"),
        (((2018, 5, 4), 4, None), False),
        (((2018, 11, 1), 5, None), False),
        (((2018, 12, 4), 1, None), False),
        (((2018, 12, 30), 4, None), "out of range"),
    ]
    run_test_for_labeler(
        CodeLabeler(["2"], horizon),
        single_outcome_events,
        help_text="test_outcome_codes_one",
    )

    # Scenario 2: no outcome codes at all, so nothing can ever be positive.
    no_outcome_events: EventsWithLabels = [
        (((2015, 1, 3), 2, None), "duplicate"),
        (((2015, 1, 3), 4, None), "duplicate"),
        (((2015, 1, 3), 1, None), "duplicate"),
        (((2015, 1, 3), 3, None), False),
        (((2015, 10, 5), 1, None), False),
        (((2018, 1, 3), 2, None), False),
        (((2018, 3, 1), 4, None), False),
        (((2018, 3, 3), 1, None), False),
        (((2018, 5, 2), 5, None), False),
        (((2018, 5, 3), 2, None), False),
        (((2018, 5, 3, 11), 1, None), False),
        (((2018, 5, 4), 1, None), "duplicate"),
        (((2018, 5, 4), 4, None), False),
        (((2018, 11, 1), 5, None), False),
        (((2018, 12, 4), 1, None), False),
        (((2018, 12, 30), 4, None), "out of range"),
    ]
    run_test_for_labeler(
        CodeLabeler([], horizon),
        no_outcome_events,
        help_text="test_outcome_codes_zero",
    )

    # Scenario 3: two outcome codes, "1" and "4".
    multi_outcome_events: EventsWithLabels = [
        (((2015, 1, 3), 2, None), "duplicate"),
        (((2015, 1, 3), 4, None), "duplicate"),
        (((2015, 1, 3), 1, None), "duplicate"),
        (((2015, 1, 3), 3, None), "skip"),
        (((2015, 10, 5), 1, None), "skip"),
        (((2018, 1, 3), 2, None), False),
        (((2018, 3, 1), 4, None), "skip"),
        (((2018, 3, 3), 1, None), "skip"),
        (((2018, 5, 2), 5, None), False),
        (((2018, 5, 3), 2, None), False),
        (((2018, 6, 2), 0, None), True),
        (((2018, 6, 3, 11), 1, None), "skip"),
        (((2018, 6, 3, 23), 3, None), False),
        (((2018, 9, 1), 3, None), True),
        (((2018, 9, 4), 4, None), "skip"),
        (((2018, 11, 1), 5, None), False),
        (((2018, 12, 3), 0, None), True),
        (((2018, 12, 4), 4, None), "skip"),
        (((2018, 12, 30), 0, None), "out of range"),
    ]
    run_test_for_labeler(
        CodeLabeler(["1", "4"], horizon),
        multi_outcome_events,
        help_text="test_outcome_codes_multiple",
    )
102
103
104
def test_prediction_codes(tmp_path: pathlib.Path):
    """Exercise CodeLabeler when ``prediction_codes`` is supplied.

    All scenarios share a 10-day horizon:
      1. one outcome code ("2") with prediction codes "4"/"5";
      2. three outcome codes ("2", "6", "7") with prediction codes "4"/"5";
      3. the same three outcome codes with an empty ``prediction_codes`` list,
         over the same events as scenario 2, where every expected label is "skip".
    """
    horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=10))

    # One outcome + multiple predictions
    one_outcome_events: EventsWithLabels = [
        (((2015, 1, 3), 2, None), "skip"),
        (((2015, 1, 3), 4, None), "skip"),
        (((2015, 1, 3), 1, None), "skip"),
        (((2015, 1, 3), 3, None), "skip"),
        (((2015, 10, 5), 1, None), "skip"),
        (((2018, 1, 3), 2, None), "skip"),
        (((2018, 3, 1), 4, None), False),
        (((2018, 3, 3), 1, None), "skip"),
        (((2018, 5, 2), 5, None), True),
        (((2018, 5, 3), 2, None), "skip"),
        (((2018, 5, 3, 11), 1, None), "skip"),
        (((2018, 5, 4), 4, None), False),
        (((2018, 5, 4), 1, None), "skip"),
        (((2018, 11, 1), 5, None), False),
        (((2018, 12, 4), 1, None), "skip"),
        (((2018, 12, 30), 4, None), "out of range"),
    ]
    run_test_for_labeler(
        CodeLabeler(["2"], horizon, prediction_codes=["4", "5"]),
        one_outcome_events,
        help_text="prediction_codes_one_outcomes",
    )

    # Multiple outcomes + multiple predictions
    multi_outcome_events: EventsWithLabels = [
        (((2010, 1, 1), 2, None), "skip"),
        (((2010, 1, 3), 4, None), True),
        (((2010, 1, 8), 6, None), "skip"),
        (((2010, 2, 1), 5, None), True),
        (((2010, 2, 9), 7, None), "skip"),
        (((2010, 2, 11), 4, None), False),
        (((2015, 1, 3), 2, None), "skip"),
        (((2015, 1, 3), 4, None), "skip"),
        (((2015, 1, 3), 1, None), "skip"),
        (((2015, 1, 3), 3, None), "skip"),
        (((2015, 10, 5), 1, None), "skip"),
        (((2018, 1, 3), 2, None), "skip"),
        (((2018, 3, 1), 4, None), True),
        (((2018, 3, 2), 7, None), "skip"),
        (((2018, 3, 3), 1, None), "skip"),
        (((2018, 5, 2), 5, None), True),
        (((2018, 5, 3), 2, None), "skip"),
        (((2018, 5, 3, 11), 1, None), "skip"),
        (((2018, 5, 4), 4, None), False),
        (((2018, 5, 4), 1, None), "skip"),
        (((2018, 11, 1), 5, None), False),
        (((2018, 12, 4), 1, None), "skip"),
        (((2018, 12, 30), 4, None), "out of range"),
    ]
    # Multiple outcomes + no predictions: identical event stream, every label "skip".
    # Built before any labeler runs so it cannot observe changes made to the list above.
    no_prediction_events: EventsWithLabels = [(event, "skip") for event, _ in multi_outcome_events]

    run_test_for_labeler(
        CodeLabeler(["2", "6", "7"], horizon, prediction_codes=["4", "5"]),
        multi_outcome_events,
        help_text="prediction_codes_multiple_outcomes",
    )

    run_test_for_labeler(
        CodeLabeler(["2", "6", "7"], horizon, prediction_codes=[]),
        no_prediction_events,
        help_text="prediction_codes_zero_predictions",
    )
193
194
195
#############################################
196
#############################################
197
#
198
# Generic OMOPConceptCodeLabeler
199
#
200
#############################################
201
#############################################
202
203
204
class DummyOntology_Base:
    """Minimal ontology stand-in; subclasses override ``get_children`` to add edges."""

    def get_children(self, code: str) -> Set[str]:
        # The base ontology has no parent/child edges.
        return set()

    def get_all_children(self, code: str) -> Set[str]:
        """Return ``code`` together with every descendant reachable via ``get_children``."""
        subtrees = [self.get_all_children(child) for child in self.get_children(code)]
        return {code}.union(*subtrees)
213
214
215
class DummyOntology_OMOPConcept(DummyOntology_Base):
    """Fake ontology with edges A -> {A_CHILD, A_CHILD2}, A_CHILD -> {A_CHILD_CHILD}, B -> {B_CHILD}."""

    def get_children(self, parent_code: str) -> Set[str]:
        edges = {
            "OMOP_CONCEPT_A": ("OMOP_CONCEPT_A_CHILD", "OMOP_CONCEPT_A_CHILD2"),
            "OMOP_CONCEPT_B": ("OMOP_CONCEPT_B_CHILD",),
            "OMOP_CONCEPT_A_CHILD": ("OMOP_CONCEPT_A_CHILD_CHILD",),
        }
        # A fresh set per call, so callers may mutate the result safely.
        return set(edges.get(parent_code, ()))
225
226
227
class DummyLabeler_OMOPConcept(OMOPConceptCodeLabeler):
    """Concrete OMOPConceptCodeLabeler rooted at the two fake concepts A and B."""

    original_omop_concept_codes = ["OMOP_CONCEPT_A", "OMOP_CONCEPT_B"]
232
233
234
def test_omop_concept_code_labeler(tmp_path: pathlib.Path):
    """The labeler's outcome codes should be its root concepts plus all their descendants."""
    horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=10))
    labeler = DummyLabeler_OMOPConcept(  # type: ignore
        DummyOntology_OMOPConcept(), horizon, prediction_codes=["1", "2"]
    )

    # Includes the grandchild A_CHILD_CHILD, so expansion must be transitive.
    expected_outcomes = {
        "OMOP_CONCEPT_A",
        "OMOP_CONCEPT_A_CHILD",
        "OMOP_CONCEPT_A_CHILD2",
        "OMOP_CONCEPT_A_CHILD_CHILD",
        "OMOP_CONCEPT_B",
        "OMOP_CONCEPT_B_CHILD",
    }
    assert set(labeler.outcome_codes) == expected_outcomes
    assert labeler.prediction_codes == ["1", "2"]
    assert labeler.get_time_horizon() == horizon
248
249
250
#############################################
251
#############################################
252
#
253
# Specific instances of CodeLabeler
254
#
255
#############################################
256
#############################################
257
258
259
#############################################
260
# MortalityCodeLabeler
261
#############################################
262
263
264
class DummyOntology_Mortality(DummyOntology_Base):
    """Fake ontology with no parent/child edges: every code is a leaf."""

    def get_children(self, parent_code: str) -> Set[str]:
        leaf: Set[str] = set()
        return leaf
267
268
269
def test_MortalityCodeLabeler() -> None:
    """Check MortalityCodeLabeler's labels with a 180-day horizon.

    The dummy ontology reports no children, so the outcome codes are presumably
    just the labeler's own base concept codes. In the expected labels below,
    events coded "SNOMED/419620001" are "skip" and earlier events within 180 days
    of one are True, so that code plays the role of the outcome here.
    """
    time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180))
    events_with_labels: EventsWithLabels = [
        (((1995, 1, 3), 0, 34.5), False),
        (((2000, 1, 1), 1, "test_value"), True),
        (((2000, 1, 5), 2, 1), True),
        (((2000, 6, 5), "SNOMED/419620001", True), "skip"),
        (((2005, 2, 5), 2, None), False),
        (((2005, 7, 5), 2, None), False),
        (((2010, 10, 5), 1, None), False),
        (((2015, 2, 5, 0), 2, None), False),
        (((2015, 7, 5, 0), 0, None), True),
        (((2015, 11, 5, 10, 10), 2, None), True),
        (((2015, 11, 15, 11), "SNOMED/419620001", None), "skip"),
        (((2020, 1, 1), 2, None), "out of range"),
        (((2020, 3, 1, 10, 10, 10), 2, None), "out of range"),
    ]

    ontology = DummyOntology_Mortality()

    # Run labeler
    labeler = MortalityCodeLabeler(ontology, time_horizon)  # type: ignore

    run_test_for_labeler(labeler, events_with_labels, help_text="MortalityLabeler")
294
295
296
#############################################
297
# LupusCodeLabeler
298
#############################################
299
300
301
class DummyOntology_Lupus(DummyOntology_Base):
    """Fake ontology in which only SNOMED/55464009 has children (four leaves)."""

    def get_children(self, parent_code: str) -> Set[str]:
        if parent_code != "SNOMED/55464009":
            return set()
        return {"SNOMED_55464009", "Lupus_child_seven", "Lupus_child_nine", "Lupus_child_ten"}
307
308
309
def test_LupusCodeLabeler() -> None:
    """Check LupusCodeLabeler's ontology expansion and its labels with a 180-day horizon.

    The dummy ontology gives SNOMED/55464009 four leaf children. The assertion
    below expects those four, SNOMED/55464009 itself and SNOMED/201436003 as the
    outcome codes. Events carrying any of them are expected to be "skip".
    """
    time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180))
    events_with_labels: EventsWithLabels = [
        (((1995, 1, 3), 0, 34.5), False),
        (((2000, 1, 1), 1, "test_value"), True),
        (((2000, 1, 5), 2, 1), True),
        (((2000, 5, 5), "SNOMED/201436003", None), "skip"),
        (((2005, 2, 5), 2, None), False),
        (((2005, 7, 5), 2, None), False),
        (((2010, 10, 5), 1, None), True),
        (((2010, 10, 8), "Lupus_child_seven", None), "skip"),
        (((2015, 2, 5, 0), 2, None), False),
        (((2015, 7, 5, 0), 0, None), True),
        (((2015, 11, 5, 10, 10), 2, None), True),
        (((2015, 11, 15, 11), "SNOMED/55464009", None), "skip"),
        (((2020, 1, 1), "Lupus_child_ten", None), "skip"),
        (((2020, 3, 1, 10, 10, 10), 2, None), "out of range"),
    ]

    ontology = DummyOntology_Lupus()
    labeler = LupusCodeLabeler(ontology, time_horizon)  # type: ignore
    # Check that we selected the right codes
    assert set(labeler.outcome_codes) == {
        "SNOMED_55464009",
        "SNOMED/201436003",
        "Lupus_child_nine",
        "SNOMED/55464009",
        "Lupus_child_ten",
        "Lupus_child_seven",
    }

    run_test_for_labeler(labeler, events_with_labels, help_text="LupusCodeLabeler")
342
343
344
#############################################
345
#############################################
346
#
347
# Specific instances of OMOPConceptCodeLabeler
348
#
349
#############################################
350
#############################################
351
352
353
class DummyOntology_OMOPConcept_Specific(DummyOntology_Base):
    """Fake ontology seeded from a labeler's base codes.

    Edges:
      * "child_1" -> {"child_1_1"}
      * new_codes[0] -> {"child_1"}  (if a first code is given)
      * new_codes[1] -> {"child_2"}  (if a second code is given)
    Any further base codes are leaves.
    """

    def __init__(self, new_codes: List[str]):
        # Stored unpadded. The previous "" padding made get_children("") return
        # a child whenever fewer than two codes were supplied.
        self.new_codes = list(new_codes)

    def get_children(self, parent_code: str) -> Set[str]:
        # Checked first, so "child_1" keeps its edge even if it is also a base code.
        if parent_code == "child_1":
            return {"child_1_1"}
        # zip stops at the shorter sequence, so missing base codes are never matched.
        for base_code, child in zip(self.new_codes, ("child_1", "child_2")):
            if parent_code == base_code:
                return {child}
        return set()
365
366
367
def _assert_labvalue_constructor_correct(
    labeler: OMOPConceptCodeLabeler,
    time_horizon: TimeHorizon,
    outcome_codes: set,
):
    """Assert ``labeler`` has the given outcome codes and horizon, and prediction codes ["1", "2"]."""
    actual_outcomes = set(labeler.outcome_codes)
    assert actual_outcomes == outcome_codes

    expected_prediction_codes = ["1", "2"]
    assert labeler.prediction_codes == expected_prediction_codes

    assert labeler.get_time_horizon() == time_horizon
375
376
377
def _create_specific_labvalue_labeler(LabelerClass, outcome_codes: set):
    """Build ``LabelerClass`` over a dummy ontology seeded from its own base codes.

    The labeler gets a 10-day horizon and prediction codes ["1", "2"]. It is
    checked against ``outcome_codes`` and then returned.
    """
    horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=10))
    seeded_ontology = DummyOntology_OMOPConcept_Specific(LabelerClass.original_omop_concept_codes)
    built = LabelerClass(seeded_ontology, horizon, prediction_codes=["1", "2"])  # type: ignore
    _assert_labvalue_constructor_correct(built, horizon, outcome_codes)
    return built
383
384
385
def test_thrombocytopenia(tmp_path: pathlib.Path):
    """ThrombocytopeniaCodeLabeler: expected SNOMED codes plus the dummy ontology's descendants."""
    # NOTE(review): these SNOMED codes are identical to the ones expected in
    # test_hyponatremia; confirm the labeler's base codes are intended.
    snomed_codes = {"SNOMED/89627008", "SNOMED/267447008"}
    dummy_descendants = {"child_1", "child_1_1", "child_2"}
    _create_specific_labvalue_labeler(ThrombocytopeniaCodeLabeler, snomed_codes | dummy_descendants)
388
389
390
def test_hyperkalemia(tmp_path: pathlib.Path):
    """HyperkalemiaCodeLabeler: one expected SNOMED code, so no "child_2" descendant."""
    snomed_codes = {"SNOMED/14140009"}
    dummy_descendants = {"child_1", "child_1_1"}
    _create_specific_labvalue_labeler(HyperkalemiaCodeLabeler, snomed_codes | dummy_descendants)
393
394
395
def test_hypoglycemia(tmp_path: pathlib.Path):
    """HypoglycemiaCodeLabeler: expected SNOMED codes plus the dummy ontology's descendants."""
    snomed_codes = {
        "SNOMED/52767006",
        "SNOMED/84371000119108",
        "SNOMED/120731000119103",
        "SNOMED/190448007",
        "SNOMED/230796005",
        "SNOMED/237633009",
        "SNOMED/237637005",
        "SNOMED/267384006",
        "SNOMED/302866003",
        "SNOMED/421437000",
        "SNOMED/421725003",
        "SNOMED/719216001",
    }
    dummy_descendants = {"child_1", "child_1_1", "child_2"}
    _create_specific_labvalue_labeler(HypoglycemiaCodeLabeler, snomed_codes | dummy_descendants)
414
415
416
def test_hyponatremia(tmp_path: pathlib.Path):
    """HyponatremiaCodeLabeler: expected SNOMED codes plus the dummy ontology's descendants."""
    snomed_codes = {"SNOMED/89627008", "SNOMED/267447008"}
    dummy_descendants = {"child_1", "child_1_1", "child_2"}
    _create_specific_labvalue_labeler(HyponatremiaCodeLabeler, snomed_codes | dummy_descendants)
419
420
421
def test_anemia(tmp_path: pathlib.Path):
    """AnemiaCodeLabeler: expected SNOMED codes plus the dummy ontology's descendants."""
    snomed_codes = {
        "SNOMED/111570005",
        "SNOMED/271737000",
        "SNOMED/691401000119104",
        "SNOMED/691411000119101",
        "SNOMED/713349004",
        "SNOMED/713496008",
        "SNOMED/767657005",
    }
    dummy_descendants = {"child_1", "child_1_1", "child_2"}
    _create_specific_labvalue_labeler(AnemiaCodeLabeler, snomed_codes | dummy_descendants)
435
436
437
def test_neutropenia(tmp_path: pathlib.Path):
    """NeutropeniaCodeLabeler: one expected SNOMED code, so no "child_2" descendant."""
    snomed_codes = {"SNOMED/165517008"}
    dummy_descendants = {"child_1", "child_1_1"}
    _create_specific_labvalue_labeler(NeutropeniaCodeLabeler, snomed_codes | dummy_descendants)
440
441
442
def test_aki(tmp_path: pathlib.Path):
    """AKICodeLabeler: expected SNOMED codes plus the dummy ontology's descendants."""
    snomed_codes = {"SNOMED/14669001", "SNOMED/35455006", "SNOMED/298015003"}
    dummy_descendants = {"child_1", "child_1_1", "child_2"}
    _create_specific_labvalue_labeler(AKICodeLabeler, snomed_codes | dummy_descendants)