|
a |
|
b/tests/femr_test_tools.py |
|
|
1 |
import datetime |
|
|
2 |
from typing import Any, List, Optional, Tuple, Union |
|
|
3 |
|
|
|
4 |
import datasets |
|
|
5 |
import meds |
|
|
6 |
|
|
|
7 |
from femr.labelers import Labeler |
|
|
8 |
|
|
|
9 |
# A list of ``(event, expected_label)`` pairs used to drive labeler tests.
# ``event`` is ``(time_tuple, code, value)``: the code may be an int or a string code.
# The 2nd elem of the outer tuple is the expected label: a ``bool`` is a label the labeler
# must produce, 'skip' means no label, and None means censored (also no expected label).
EventsWithLabels = List[Tuple[Tuple[Tuple, Union[int, str], Any], Optional[Union[bool, str]]]]
|
|
11 |
|
|
|
12 |
|
|
|
13 |
# Default event timeline for `create_patients_dataset`.
# Each entry is ``(time_tuple, code, value)``:
#   * ``time_tuple`` is unpacked into ``datetime.datetime`` (some entries include hour,
#     minute, or second components, not just a date);
#   * ``code`` is an int (converted to ``str``) or a non-int code such as ``meds.birth_code``;
#   * ``value`` is None (no value), a ``str`` (text value), or a number (numeric value).
DUMMY_EVENTS = [
    # NOTE(review): presumably the MEDS birth-event code — confirm against the `meds` package.
    ((1995, 1, 3), meds.birth_code, None),
    # Two identical events at the same timestamp (duplicate rows).
    ((2010, 1, 1), 1, "test_value"),
    ((2010, 1, 1), 1, "test_value"),
    ((2010, 1, 5), 2, 1),
    ((2010, 6, 5), 3, None),
    ((2010, 8, 5), 2, None),
    ((2011, 7, 5), 2, None),
    ((2012, 10, 5), 3, None),
    ((2015, 6, 5, 0), 2, None),
    ((2015, 6, 5, 10, 10), 2, None),
    ((2015, 6, 15, 11), 3, None),
    ((2016, 1, 1), 2, None),
    ((2016, 3, 1, 10, 10, 10), 4, None),
]

# Number of entries in the default timeline above.
NUM_EVENTS = len(DUMMY_EVENTS)
# Default number of synthetic patients for tests in this module.
NUM_PATIENTS = 10
|
|
31 |
|
|
|
32 |
|
|
|
33 |
def create_patients_dataset(num_patients: int, events: List[Tuple[Tuple, int, Any]] = DUMMY_EVENTS) -> datasets.Dataset:
    """Build a dataset of ``num_patients`` patients that all share the event timeline ``events``.

    Args:
        num_patients: How many patients to create; their ids are ``0 .. num_patients - 1``.
        events: ``(time_tuple, code, value)`` triples. ``time_tuple`` is unpacked into a
            ``datetime.datetime``; an ``int`` code is converted to ``str`` (other codes are
            used as-is); a ``str`` value is stored as ``text_value``, ``None`` stores no
            value, and anything else is stored as ``numeric_value``.

    Returns:
        A ``datasets.Dataset`` with a ``patient_id`` column and an ``events`` column, where
        each event holds a single measurement.
    """
    timeline: List[meds.Event] = []

    for event in events:
        when, raw_code, raw_value = event[0], event[1], event[2]

        code = str(raw_code) if isinstance(raw_code, int) else raw_code
        measurement = {"code": code}
        if isinstance(raw_value, str):
            measurement["text_value"] = raw_value
        elif raw_value is not None:
            measurement["numeric_value"] = raw_value

        timeline.append({"time": datetime.datetime(*when), "measurements": [measurement]})

    patient_ids = list(range(num_patients))
    # Every patient references the same timeline object; the dataset copies it on construction.
    return datasets.Dataset.from_dict({"patient_id": patient_ids, "events": [timeline] * num_patients})
|
|
54 |
|
|
|
55 |
|
|
|
56 |
def assert_labels_are_accurate(
    labeled_patients: List[meds.Label],
    patient_id: int,
    true_labels: List[Tuple[datetime.datetime, Optional[bool]]],
    help_text: str = "",
):
    """Assert that the labels generated for one patient exactly match the expected labels.

    The labels belonging to ``patient_id`` are selected from ``labeled_patients`` (keeping
    their original order) and compared position-by-position against ``true_labels``.

    Args:
        labeled_patients: Labels produced by a labeler, possibly covering many patients.
        patient_id: The patient whose labels are checked.
        true_labels: Expected ``(prediction_time, boolean_value)`` pairs, in the order the
            labels are expected to appear in ``labeled_patients``.
        help_text: Extra context appended to every failure message.

    Raises:
        AssertionError: If the number of labels differs, or if any label's
            ``prediction_time`` or ``boolean_value`` differs from the expected pair at the
            same position.
    """
    generated_labels: List[meds.Label] = [a for a in labeled_patients if a["patient_id"] == patient_id]

    # Check that the generated and expected label lists have the same length.
    assert len(generated_labels) == len(true_labels), (
        f"patient_id={patient_id}, "
        f"len(generated): {len(generated_labels)} != len(expected): {len(true_labels)} | {help_text}"
    )

    # Check that each label matches its expected (prediction_time, boolean_value) pair.
    for idx, (label, true_label) in enumerate(zip(generated_labels, true_labels)):
        assert label["boolean_value"] == true_label[1] and label["prediction_time"] == true_label[0], (
            f"patient_id={patient_id}, label_idx={idx} | "
            f"{label} (Assigned) != {true_label} (Expected) | "
            f"{help_text}"
        )
|
|
76 |
|
|
|
77 |
|
|
|
78 |
def run_test_for_labeler(
    labeler: Labeler,
    events_with_labels: EventsWithLabels,
    true_outcome_times: Optional[List[datetime.datetime]] = None,
    true_prediction_times: Optional[List[datetime.datetime]] = None,
    help_text: str = "",
) -> None:
    """Run ``labeler`` on synthetic patients built from ``events_with_labels`` and check its labels.

    ``NUM_PATIENTS`` identical patients are created from the events. Every patient is expected
    to receive exactly the labels whose expected value is a ``bool``; entries marked 'skip' or
    None contribute an event but no expected label.

    Args:
        labeler: The labeler under test; its ``apply`` method is called on the dataset.
        events_with_labels: ``(event, expected_label)`` pairs (see ``EventsWithLabels``).
        true_outcome_times: Currently unused; accepted so existing callers keep working.
        true_prediction_times: If given, replaces, in order, the prediction time of each
            expected label (by default a label is expected at its event's time). The pairing
            uses ``zip``, so extra entries on either side are silently dropped.
        help_text: Extra context appended to every failure message.

    Raises:
        AssertionError: If any patient's generated labels do not match the expected labels.
    """
    patients: datasets.Dataset = create_patients_dataset(NUM_PATIENTS, [x[0] for x in events_with_labels])
    true_labels: List[Tuple[datetime.datetime, Optional[bool]]] = [
        (datetime.datetime(*x[0][0]), x[1]) for x in events_with_labels if isinstance(x[1], bool)
    ]
    if true_prediction_times is not None:
        # Manually specified prediction times move each label away from its event's time,
        # e.g. when predictions are made at `event.end` or `event.start + 1 day`.
        true_labels = [(tp, tl[1]) for (tl, tp) in zip(true_labels, true_prediction_times)]
    labeled_patients: List[meds.Label] = labeler.apply(patients)

    # Every patient has the same events, so every patient must receive the same labels.
    for patient in patients:
        assert_labels_are_accurate(
            labeled_patients,
            patient["patient_id"],
            true_labels,
            help_text=help_text,
        )