a b/tests/labelers/test_TimeHorizonEventLabeler.py
1
# flake8: noqa: E402
2
3
import datetime
4
import os
5
import pathlib
6
import sys
7
import warnings
8
from typing import List
9
10
import meds
11
from femr_test_tools import EventsWithLabels, run_test_for_labeler
12
13
from femr.labelers import TimeHorizon, TimeHorizonEventLabeler
14
15
16
class DummyLabeler(TimeHorizonEventLabeler):
    """Test-only labeler whose outcomes are measurements with a code in `self.outcome_codes`.

    Prediction times are the distinct event times of the patient, in ascending order.
    """

    def __init__(self, outcome_codes: List[int], time_horizon: TimeHorizon, allow_same_time: bool = True):
        # Codes are stored as strings; `get_outcome_times` tests membership against these.
        self.outcome_codes: List[str] = [str(code) for code in outcome_codes]
        self.time_horizon: TimeHorizon = time_horizon
        self.allow_same_time = allow_same_time

    def allow_same_time_labels(self) -> bool:
        """Return the `allow_same_time` flag given at construction."""
        return self.allow_same_time

    def get_prediction_times(self, patient: meds.Patient) -> List[datetime.datetime]:
        """Return each distinct event time of `patient`, sorted ascending."""
        distinct_times = {event["time"] for event in patient["events"]}
        return sorted(distinct_times)

    def get_time_horizon(self) -> TimeHorizon:
        """Return the time horizon given at construction."""
        return self.time_horizon

    def get_outcome_times(self, patient: meds.Patient) -> List[datetime.datetime]:
        """Return the event time once per measurement whose code is in `self.outcome_codes`.

        Times follow event order and are not deduplicated.
        """
        outcome_times: List[datetime.datetime] = []
        for event in patient["events"]:
            outcome_times.extend(
                event["time"] for measurement in event["measurements"] if measurement["code"] in self.outcome_codes
            )
        return outcome_times
40
41
42
def test_no_outcomes(tmp_path: pathlib.Path):
    """Outcome code 100 is carried by no event below, so no row expects a True label."""
    horizon_start = datetime.timedelta(days=0)
    horizon_end = datetime.timedelta(days=180)
    labeler = DummyLabeler([100], TimeHorizon(horizon_start, horizon_end))
    # Presumably each row is ((time tuple, code, value), expected label); the exact
    # format and the string markers are interpreted by femr_test_tools.
    timeline: EventsWithLabels = [
        # fmt: off
        (((2015, 1, 3), 2, None), "duplicate"),
        (((2015, 1, 3), 1, None), "duplicate"),
        (((2015, 1, 3), 3, None), False),
        (((2015, 10, 5), 1, None), False),
        (((2018, 1, 3), 2, None), False),
        (((2018, 3, 3), 1, None), False),
        (((2018, 5, 3), 2, None), False),
        (((2018, 5, 3, 11), 1, None), False),
        (((2018, 5, 4), 1, None), False),
        (((2018, 12, 4), 1, None), "out of range"),
        # fmt: on
    ]
    run_test_for_labeler(labeler, timeline, help_text="test_no_outcomes")
61
62
63
def test_horizon_0_180_days(tmp_path: pathlib.Path):
    """Horizon of (0, 180) days, outcome code 2, default `allow_same_time` (True)."""
    horizon_start = datetime.timedelta(days=0)
    horizon_end = datetime.timedelta(days=180)
    labeler = DummyLabeler([2], TimeHorizon(horizon_start, horizon_end))
    # Presumably each row is ((time tuple, code, value), expected label); the exact
    # format and the string markers are interpreted by femr_test_tools.
    timeline: EventsWithLabels = [
        # fmt: off
        (((2015, 1, 3), 2, None), "duplicate"),
        (((2015, 1, 3), 1, None), "duplicate"),
        (((2015, 1, 3), 3, None), True),
        (((2015, 10, 5), 1, None), False),
        (((2018, 1, 3), 2, None), True),
        (((2018, 3, 3), 1, None), True),
        (((2018, 5, 3), 2, None), True),
        (((2018, 5, 3, 11), 1, None), False),
        (((2018, 5, 4), 1, None), False),
        (((2018, 12, 4), 1, None), "out of range"),
        # fmt: on
    ]
    run_test_for_labeler(labeler, timeline, help_text="test_horizon_0_180_days")
82
83
84
def test_horizon_0_180_days_no_same(tmp_path: pathlib.Path):
    """Horizon of (0, 180) days, outcome code 2, with `allow_same_time=False`.

    Rows at a time where some event carries code 2 are marked "same" instead of a
    boolean (presumably meaning no label is expected there; the marker is
    interpreted by femr_test_tools).
    """
    # (0, 180) days, same-time labels disallowed
    time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180))
    labeler = DummyLabeler([2], time_horizon, allow_same_time=False)
    events_with_labels: EventsWithLabels = [
        # fmt: off
        (((2015, 1, 3), 2, None), "duplicate"),
        (((2015, 1, 3), 1, None), "duplicate"),
        (((2015, 1, 3), 3, None), "same"),
        (((2015, 10, 5), 1, None), False),
        (((2018, 1, 3), 2, None), "same"),
        (((2018, 3, 3), 1, None), True),
        (((2018, 5, 3), 2, None), "same"),
        (((2018, 5, 3, 11), 1, None), False),
        (((2018, 5, 4), 1, None), False),
        (((2018, 12, 4), 1, None), "out of range"),
        # fmt: on
    ]
    # help_text names this test; it previously reused "test_horizon_0_180_days",
    # which would attribute output from this test to the other one.
    run_test_for_labeler(labeler, events_with_labels, help_text="test_horizon_0_180_days_no_same")
103
104
105
def test_horizon_1_180_days(tmp_path: pathlib.Path):
    """Horizon of (1, 180) days, outcome code 2, default `allow_same_time` (True)."""
    horizon_start = datetime.timedelta(days=1)
    horizon_end = datetime.timedelta(days=180)
    labeler = DummyLabeler([2], TimeHorizon(horizon_start, horizon_end))
    # Presumably each row is ((time tuple, code, value), expected label); the exact
    # format and the string markers are interpreted by femr_test_tools.
    timeline: EventsWithLabels = [
        # fmt: off
        (((2015, 1, 3), 2, None), "duplicate"),
        (((2015, 1, 3), 1, None), "duplicate"),
        (((2015, 1, 3), 3, None), False),
        (((2015, 10, 5), 1, None), False),
        (((2018, 1, 3), 2, None), True),
        (((2018, 3, 3), 1, None), True),
        (((2018, 5, 3), 2, None), False),
        (((2018, 5, 3, 11), 1, None), False),
        (((2018, 5, 4), 1, None), False),
        (((2018, 12, 4), 1, None), "out of range"),
        # fmt: on
    ]
    run_test_for_labeler(labeler, timeline, help_text="test_horizon_1_180_days")
124
125
126
def test_horizon_180_365_days(tmp_path: pathlib.Path):
    """Horizon of (180, 365) days, outcome code 2, default `allow_same_time` (True)."""
    horizon_start = datetime.timedelta(days=180)
    horizon_end = datetime.timedelta(days=365)
    labeler = DummyLabeler([2], TimeHorizon(horizon_start, horizon_end))
    # Presumably each row is ((time tuple, code, value), expected label); the exact
    # format and the string markers are interpreted by femr_test_tools.
    timeline: EventsWithLabels = [
        # fmt: off
        (((2000, 1, 3), 2, None), True),
        (((2000, 10, 5), 2, None), False),
        (((2002, 1, 5), 2, None), True),
        (((2002, 3, 1), 1, None), True),
        (((2002, 4, 5), 3, None), True),
        (((2002, 4, 12), 1, None), True),
        (((2002, 12, 5), 2, None), False),
        (((2002, 12, 10), 1, None), False),
        (((2004, 1, 10), 2, None), False),
        (((2008, 1, 10), 2, None), "out of range"),
        # fmt: on
    ]
    run_test_for_labeler(labeler, timeline, help_text="test_horizon_180_365_days")
145
146
147
def test_horizon_0_0_days(tmp_path: pathlib.Path):
    """Zero-width horizon, (0, 0) days, with outcome code 2 and default `allow_same_time` (True)."""
    horizon_start = datetime.timedelta(days=0)
    horizon_end = datetime.timedelta(days=0)
    labeler = DummyLabeler([2], TimeHorizon(horizon_start, horizon_end))
    # Presumably each row is ((time tuple, code, value), expected label); the exact
    # format and the string markers are interpreted by femr_test_tools.
    timeline: EventsWithLabels = [
        # fmt: off
        (((2015, 1, 3), 2, None), "duplicate"),
        (((2015, 1, 3), 1, None), True),
        (((2015, 1, 4), 1, None), False),
        (((2015, 1, 5), 2, None), True),
        (((2015, 1, 5, 10), 1, None), False),
        (((2015, 1, 6), 2, None), True),
        # fmt: on
    ]
    run_test_for_labeler(labeler, timeline, help_text="test_horizon_0_0_days")
162
163
164
def test_horizon_10_10_days(tmp_path: pathlib.Path):
    """Zero-width horizon offset by ten days, (10, 10) days, with outcome code 2."""
    horizon_start = datetime.timedelta(days=10)
    horizon_end = datetime.timedelta(days=10)
    labeler = DummyLabeler([2], TimeHorizon(horizon_start, horizon_end))
    # Presumably each row is ((time tuple, code, value), expected label); the exact
    # format and the string markers are interpreted by femr_test_tools.
    timeline: EventsWithLabels = [
        # fmt: off
        (((2015, 1, 3), 2, None), False),
        (((2015, 1, 13), 1, None), True),
        (((2015, 1, 23), 2, None), True),
        (((2015, 2, 2), 2, None), False),
        (((2015, 3, 10), 1, None), True),
        (((2015, 3, 20), 2, None), False),
        (((2015, 3, 29), 2, None), "out of range"),
        (((2015, 3, 30), 1, None), "out of range"),
        # fmt: on
    ]
    run_test_for_labeler(labeler, timeline, help_text="test_horizon_10_10_days")
181
182
183
def test_horizon_0_1000000_days(tmp_path: pathlib.Path):
    """Very wide horizon, (0, 1000000) days, with outcome code 2."""
    horizon_start = datetime.timedelta(days=0)
    horizon_end = datetime.timedelta(days=1000000)
    labeler = DummyLabeler([2], TimeHorizon(horizon_start, horizon_end))
    # Presumably each row is ((time tuple, code, value), expected label); the exact
    # format and the string markers are interpreted by femr_test_tools.
    timeline: EventsWithLabels = [
        # fmt: off
        (((2000, 1, 3), 2, None), True),
        (((2001, 10, 5), 1, None), True),
        (((2020, 10, 5), 2, None), True),
        (((2021, 10, 5), 1, None), True),
        (((2050, 1, 10), 2, None), True),
        (((2051, 1, 10), 1, None), False),
        (((5000, 1, 10), 1, None), "out of range"),
        # fmt: on
    ]
    run_test_for_labeler(labeler, timeline, help_text="test_horizon_0_1000000_days")
199
200
201
def test_horizon_5_10_hours(tmp_path: pathlib.Path):
    """Sub-day horizon of (5 hours, 10 hours 30 minutes) with outcome code 2.

    The timeline is split into one group of rows per calendar year, 2015 through 2019.
    """
    horizon_start = datetime.timedelta(hours=5)
    horizon_end = datetime.timedelta(hours=10, minutes=30)
    labeler = DummyLabeler([2], TimeHorizon(horizon_start, horizon_end))
    # Presumably each row is ((time tuple, code, value), expected label); the exact
    # format and the string markers are interpreted by femr_test_tools.
    timeline: EventsWithLabels = [
        # fmt: off
        (((2015, 1, 1, 0, 0), 1, None), True),
        (((2015, 1, 1, 10, 29), 2, None), False),
        (((2015, 1, 1, 10, 30), 1, None), False),
        (((2015, 1, 1, 10, 31), 1, None), False),
        #
        (((2016, 1, 1, 0, 0), 1, None), True),
        (((2016, 1, 1, 10, 29), 1, None), False),
        (((2016, 1, 1, 10, 30), 2, None), False),
        (((2016, 1, 1, 10, 31), 1, None), False),
        #
        (((2017, 1, 1, 0, 0), 1, None), False),
        (((2017, 1, 1, 10, 29), 1, None), False),
        (((2017, 1, 1, 10, 30), 1, None), False),
        (((2017, 1, 1, 10, 31), 2, None), False),
        #
        (((2018, 1, 1, 0, 0), 1, None), False),
        (((2018, 1, 1, 4, 59, 59), 2, None), False),
        (((2018, 1, 1, 5), 1, None), False),
        #
        (((2019, 1, 1, 0, 0), 1, None), True),
        (((2019, 1, 1, 4, 59, 59), 1, None), "out of range"),
        (((2019, 1, 1, 5), 2, None), "out of range"),
        # fmt: on
    ]
    run_test_for_labeler(labeler, timeline, help_text="test_horizon_5_10_hours")
232
233
234
def test_horizon_infinite(tmp_path: pathlib.Path):
    """Open-ended horizon: start of 10 days and an end of `None`, with outcome code 2."""
    horizon_start = datetime.timedelta(days=10)
    # The second TimeHorizon argument is None, i.e. no end is supplied.
    labeler = DummyLabeler([2], TimeHorizon(horizon_start, None))
    # Presumably each row is ((time tuple, code, value), expected label); the exact
    # format is interpreted by femr_test_tools.
    timeline: EventsWithLabels = [
        # fmt: off
        (((1950, 1, 3), 1, None), True),
        (((2000, 1, 3), 1, None), True),
        (((2001, 10, 5), 1, None), True),
        (((2020, 10, 5), 1, None), True),
        (((2021, 10, 5), 1, None), True),
        (((2050, 1, 10), 2, None), True),
        (((2050, 1, 20), 2, None), False),
        (((2051, 1, 10), 1, None), False),
        (((5000, 1, 10), 1, None), False),
        # fmt: on
    ]
    run_test_for_labeler(labeler, timeline, help_text="test_horizon_infinite")