a b/tests/femr_test_tools.py
1
import datetime
2
from typing import Any, List, Optional, Tuple, Union
3
4
import datasets
5
import meds
6
7
from femr.labelers import Labeler
8
9
# Each entry pairs one event with the label expected at that event.
#   * 1st elem: the event as `(time_tuple, code, value)`, where `time_tuple`
#     holds the positional arguments passed to `datetime.datetime`.
#   * 2nd elem: a bool for a real label, 'skip' when no label is expected,
#     or None when the label is censored.
# None is part of the Union so that censored entries type-check.
EventsWithLabels = List[Tuple[Tuple[Tuple, int, Any], Union[bool, str, None]]]
11
12
13
DUMMY_EVENTS = [
14
    ((1995, 1, 3), meds.birth_code, None),
15
    ((2010, 1, 1), 1, "test_value"),
16
    ((2010, 1, 1), 1, "test_value"),
17
    ((2010, 1, 5), 2, 1),
18
    ((2010, 6, 5), 3, None),
19
    ((2010, 8, 5), 2, None),
20
    ((2011, 7, 5), 2, None),
21
    ((2012, 10, 5), 3, None),
22
    ((2015, 6, 5, 0), 2, None),
23
    ((2015, 6, 5, 10, 10), 2, None),
24
    ((2015, 6, 15, 11), 3, None),
25
    ((2016, 1, 1), 2, None),
26
    ((2016, 3, 1, 10, 10, 10), 4, None),
27
]
28
29
NUM_EVENTS = len(DUMMY_EVENTS)
30
NUM_PATIENTS = 10
31
32
33
def create_patients_dataset(num_patients: int, events: List[Tuple[Tuple, int, Any]] = DUMMY_EVENTS) -> datasets.Dataset:
    """Build a dataset of `num_patients` patients that all share the timeline in `events`.

    Each entry of `events` is a `(time_tuple, code, value)` triple:

    * `time_tuple` is unpacked into `datetime.datetime(...)`.
    * `code` is converted to a string when it is an int and kept as-is otherwise.
    * `value` selects the measurement field: None adds no value field, a string
      is stored under "text_value", anything else under "numeric_value".

    Patients get the ids `0 .. num_patients - 1`, and every patient row refers to
    the same converted event list.
    """
    timeline: List[meds.Event] = []

    for event in events:
        time_parts, raw_code, raw_value = event[0], event[1], event[2]

        measurement = {"code": str(raw_code) if isinstance(raw_code, int) else raw_code}
        if isinstance(raw_value, str):
            measurement["text_value"] = raw_value
        elif raw_value is not None:
            measurement["numeric_value"] = raw_value

        timeline.append({"time": datetime.datetime(*time_parts), "measurements": [measurement]})

    patient_ids = list(range(num_patients))
    return datasets.Dataset.from_dict({"patient_id": patient_ids, "events": [timeline for _ in patient_ids]})
54
55
56
def assert_labels_are_accurate(
    labeled_patients: List[meds.Label],
    patient_id: int,
    true_labels: List[Tuple[datetime.datetime, Optional[bool]]],
    help_text: str = "",
):
    """Assert that the labels generated for `patient_id` equal `true_labels`, entry by entry.

    Args:
        labeled_patients: Labels for any number of patients; only the entries whose
            "patient_id" equals `patient_id` are checked.
        patient_id: The patient whose labels are verified.
        true_labels: Expected `(prediction_time, boolean_value)` pairs, in the order
            the generated labels are expected to appear.
        help_text: Extra context appended to the assertion message on failure.

    Raises:
        AssertionError: If the number of labels differs, or if any label's
            "boolean_value" or "prediction_time" differs from the expected pair.
    """
    # Keep only this patient's labels, preserving their original order.
    patient_labels: List[meds.Label] = list(filter(lambda lbl: lbl["patient_id"] == patient_id, labeled_patients))

    # The counts must agree before comparing pairwise; otherwise zip() would hide extras.
    num_generated, num_expected = len(patient_labels), len(true_labels)
    assert (
        num_generated == num_expected
    ), f"len(generated): {num_generated} != len(expected): {num_expected} | {help_text}"

    # Compare each generated label with the expected pair at the same position.
    for position, (generated, expected) in enumerate(zip(patient_labels, true_labels)):
        assert generated["boolean_value"] == expected[1] and generated["prediction_time"] == expected[0], (
            f"patient_id={patient_id}, label_idx={position}, label={generated}  |  "
            f"{generated} (Assigned) != {expected} (Expected)  |  "
            f"{help_text}"
        )
76
77
78
def run_test_for_labeler(
    labeler: Labeler,
    events_with_labels: EventsWithLabels,
    true_outcome_times: Optional[List[datetime.datetime]] = None,
    true_prediction_times: Optional[List[datetime.datetime]] = None,
    help_text: str = "",
) -> None:
    """Apply `labeler` to synthetic patients and check the labels of every patient.

    `NUM_PATIENTS` identical patients are built from the events in
    `events_with_labels`. The expected labels are derived from the entries whose
    label is a bool; entries with any other label (e.g. 'skip' or None) contribute
    no expected label.

    Args:
        labeler: The labeler under test; its `apply` method is called on the dataset.
        events_with_labels: `(event, label)` pairs describing the shared timeline
            and the label expected at each event.
        true_outcome_times: Currently unused; kept so existing callers keep working.
        true_prediction_times: Optional replacement prediction times, matched
            pairwise (in order) with the bool-labelled events. If the two lists
            differ in length, the extra entries of the longer one are ignored.
        help_text: Extra context appended to assertion messages on failure.

    Raises:
        AssertionError: If any patient's generated labels differ from the expected ones.
    """
    # `create_patients_dataset` is annotated to return a `datasets.Dataset`.
    patients: datasets.Dataset = create_patients_dataset(NUM_PATIENTS, [x[0] for x in events_with_labels])

    # By default a label is expected at the time of the event that carries it.
    true_labels: List[Tuple[datetime.datetime, Optional[bool]]] = [
        (datetime.datetime(*x[0][0]), x[1]) for x in events_with_labels if isinstance(x[1], bool)
    ]
    if true_prediction_times is not None:
        # If manually specified prediction times, adjust labels from occurring at `event.start`
        # e.g. we may make predictions at `event.end` or `event.start + 1 day`
        true_labels = [(tp, tl[1]) for (tl, tp) in zip(true_labels, true_prediction_times)]

    labeled_patients: List[meds.Label] = labeler.apply(patients)

    # Every patient shares the same timeline, so each one must get the same labels.
    for patient in patients:
        assert_labels_are_accurate(
            labeled_patients,
            patient["patient_id"],
            true_labels,
            help_text=help_text,
        )