--- a +++ b/tests/integration/measures/test_calculate.py @@ -0,0 +1,422 @@ +import random +from collections import defaultdict +from datetime import date, timedelta +from unittest import mock + +import pytest + +from ehrql import months, years +from ehrql.measures import INTERVAL, Measures, get_measure_results +from ehrql.measures.calculate import MeasuresTimeout +from ehrql.tables import EventFrame, PatientFrame, Series, table + + +@table +class patients(PatientFrame): + sex = Series(str) + + +@table +class addresses(EventFrame): + date = Series(date) + region = Series(str) + + +@table +class events(EventFrame): + date = Series(date) + code = Series(str) + value = Series(int) + + +def test_get_measure_results(engine): + events_in_interval = events.where(events.date.is_during(INTERVAL)) + event_count = events_in_interval.count_for_patient() + foo_event_count = events_in_interval.where(events.code == "foo").count_for_patient() + had_event = events_in_interval.exists_for_patient() + event_value = events_in_interval.value.sum_for_patient() + region = addresses.sort_by(addresses.date).last_for_patient().region + + intervals = years(3).starting_on("2020-01-01") + measures = Measures() + + measures.define_measure( + "foo_events_by_sex", + numerator=foo_event_count, + denominator=event_count, + group_by=dict(sex=patients.sex), + intervals=intervals, + ) + measures.define_measure( + "foo_events_by_region", + numerator=foo_event_count, + denominator=event_count, + group_by=dict(region=region), + intervals=intervals, + ) + measures.define_measure( + "had_event_by_sex", + numerator=had_event, + denominator=patients.exists_for_patient(), + group_by=dict(sex=patients.sex), + intervals=intervals, + ) + measures.define_measure( + "event_value_by_region", + numerator=event_value, + denominator=patients.exists_for_patient(), + group_by=dict(region=region), + intervals=intervals, + ) + measures.define_measure( + "had_event_by_sex_and_region", + numerator=had_event, + denominator=patients.exists_for_patient(), + group_by=dict( + sex=patients.sex, + region=region, + ), + intervals=intervals, + ) + measures.define_measure( + "foo_events", + numerator=foo_event_count, + denominator=event_count, + intervals=intervals, + ) + + patient_data, address_data, event_data = generate_data(intervals) + engine.populate( + {patients: patient_data, addresses: address_data, events: event_data} + ) + + results = get_measure_results(engine.query_engine(), measures) + results = list(results) + # Verify that we don't get any duplicate rows in the results + assert len(results) == len(set(results)) + + expected = calculate_measure_results( + intervals, patient_data, address_data, event_data + ) + expected = list(expected) + # We don't care about the order of the results + assert set(results) == set(expected) + + +def test_get_measures_interval_dependent_denominator(engine): + # Test results when an interval denominator is dependent on the specific interval + # (i.e. values in other intervals affect the inclusion in this interval population) + # i.e. the union of all measure denominators will exclude some patients + intervals = years(2).starting_on("2020-01-01") + measures = Measures() + + is_female = patients.sex == "female" + had_event_in_interval = events.where( + events.date.is_during(INTERVAL) + ).exists_for_patient() + had_event_outside_interval = events.where( + events.date.is_before(INTERVAL.start_date) + | events.date.is_after(INTERVAL.end_date) + ).exists_for_patient() + measures.define_measure( + "female_by_events_outside_interval_only", + numerator=is_female, + denominator=had_event_outside_interval & ~(had_event_in_interval), + intervals=intervals, + ) + + patient_data = [ + dict(patient_id=1, sex="male"), + dict(patient_id=2, sex="female"), + dict(patient_id=3, sex="male"), + dict(patient_id=4, sex="female"), + ] + event_data = [ + # Interval 1 includes only patient 2 (female) in the population (has an event in interval 2 only) + # Interval 2 includes only patient 1 (male) in the population (has an event in interval 1 only) + dict(patient_id=1, code="abc", date=date(2020, 2, 1)), + dict(patient_id=2, code="abc", date=date(2021, 2, 1)), + # Patient 3 and 4 have events in both intervals, so aren't included in the population for + # either + dict(patient_id=3, code="abc", date=date(2020, 2, 1)), + dict(patient_id=4, code="abc", date=date(2020, 2, 1)), + dict(patient_id=3, code="abc", date=date(2021, 2, 1)), + dict(patient_id=4, code="abc", date=date(2021, 2, 1)), + ] + engine.populate({patients: patient_data, events: event_data}) + results = get_measure_results(engine.query_engine(), measures) + + expected = [ + # interval 1 has 1 female patient in the population - numerator 1, denominator 1 + ( + "female_by_events_outside_interval_only", + date(2020, 1, 1), + date(2020, 12, 31), + 1.0, + 1, + 1, + ), + # interval 2 has 1 male patient in the population - numerator 0, denominator 1 + ( + "female_by_events_outside_interval_only", + date(2021, 1, 1), + date(2021, 12, 31), + 0.0, + 0, + 1, + ), + ] + + assert set(results) == set(expected) + + +def test_get_measures_same_numerator_and_denominator(engine): + # Ensure that calculations are handled correctly when the same column + # is used as both numerator and denominator + intervals = years(2).starting_on("2020-01-01") + measures = Measures() + measures.define_measure( + "test", + numerator=patients.exists_for_patient(), + denominator=patients.exists_for_patient(), + intervals=intervals, + ) + patient_data = [dict(patient_id=1), dict(patient_id=2)] + engine.populate({patients: patient_data}) + results = set(get_measure_results(engine.query_engine(), measures)) + expected = { + ("test", date(2020, 1, 1), date(2020, 12, 31), 1.0, 2, 2), + ("test", date(2021, 1, 1), date(2021, 12, 31), 1.0, 2, 2), + } + assert results == expected + + +def test_get_measures_duplicate_group_bys(engine): + # Ensure that calculations are handled correctly when there are measures + # in the same group (sharing a denominator and intervals) with the same + # group bys. These can be handled by a single grouping set in the SQL + # query; duplicate grouping sets in the query result in duplicate + # rows in the result + events_in_interval = events.where(events.date.is_during(INTERVAL)) + event_count = events_in_interval.count_for_patient() + foo_event_count = events_in_interval.where(events.code == "foo").count_for_patient() + bar_event_count = events_in_interval.where(events.code == "bar").count_for_patient() + + intervals = years(1).starting_on("2020-01-01") + measures = Measures() + + measures.define_measure( + "foo_events", + numerator=foo_event_count, + denominator=event_count, + intervals=intervals, + ) + measures.define_measure( + "foo_events_by_sex", + numerator=foo_event_count, + denominator=event_count, + group_by=dict(sex=patients.sex), + intervals=intervals, + ) + measures.define_measure( + "bar_events", + numerator=bar_event_count, + denominator=event_count, + intervals=intervals, + ) + measures.define_measure( + "bar_events_by_sex", + numerator=bar_event_count, + denominator=event_count, + group_by=dict(sex=patients.sex), + intervals=intervals, + ) + + patient_data = [dict(patient_id=1, sex="male"), dict(patient_id=2, sex="female")] + address_data = [ + dict(patient_id=1, date=date(2020, 1, 1), region="London"), + dict(patient_id=1, date=date(2020, 1, 1), region="The North"), + ] + event_data = [ + dict(patient_id=1, date=date(2020, 2, 1), code="foo"), + dict(patient_id=1, date=date(2020, 2, 1), code="bar"), + dict(patient_id=2, date=date(2020, 2, 1), code="foo"), + dict(patient_id=2, date=date(2020, 2, 1), code="bar"), + ] + engine.populate( + {patients: patient_data, addresses: address_data, events: event_data} + ) + results = list(get_measure_results(engine.query_engine(), measures)) + # Verify that we don't get any duplicate rows in the results + assert len(results) == len(set(results)) + + expected = { + ("foo_events", date(2020, 1, 1), date(2020, 12, 31), 0.5, 2, 4, None), + ("foo_events_by_sex", date(2020, 1, 1), date(2020, 12, 31), 0.5, 1, 2, "male"), + ( + "foo_events_by_sex", + date(2020, 1, 1), + date(2020, 12, 31), + 0.5, + 1, + 2, + "female", + ), + ("bar_events", date(2020, 1, 1), date(2020, 12, 31), 0.5, 2, 4, None), + ("bar_events_by_sex", date(2020, 1, 1), date(2020, 12, 31), 0.5, 1, 2, "male"), + ( + "bar_events_by_sex", + date(2020, 1, 1), + date(2020, 12, 31), + 0.5, + 1, + 2, + "female", + ), + } + assert set(results) == expected + + +@mock.patch("ehrql.measures.calculate.time") +def test_get_measure_results_with_timeout(patched_time, in_memory_engine): + events_in_interval = events.where(events.date.is_during(INTERVAL)) + event_count = events_in_interval.count_for_patient() + foo_event_count = events_in_interval.where(events.code == "foo").count_for_patient() + + intervals = months(60).starting_on("2000-01-01") + measures = Measures() + + measures.define_measure( + "foo_events", + numerator=foo_event_count, + denominator=event_count, + intervals=intervals, + group_by=dict( + sex=patients.sex, + ), + ) + + patient_data, _, event_data = generate_data(intervals) + in_memory_engine.populate({patients: patient_data, events: event_data}) + + patched_time.time.side_effect = [0.0, 1000.0, 1000000.0] + results = get_measure_results(in_memory_engine.query_engine(), measures) + with pytest.raises(MeasuresTimeout, match="time limit"): + results = list(results) + + +def generate_data(intervals): + rnd = random.Random(20230518) + # Generate some random patients + patient_data = [ + dict( + patient_id=patient_id, + sex=rnd.choice(["male", "female"]), + ) + for patient_id in range(1, 50) + ] + # Generate some addresses (at least one) for each patient + # Make sure that address dates for the same patient are different; otherwise + # we can't be sure which region will be returned as the last + address_data = [] + interval_range = (intervals[0][0], intervals[-1][1]) + for patient in patient_data: + address_dates = set() + for _ in range(rnd.randint(1, 3)): + address_dates.add(random_date_in_interval(rnd, interval_range)) + + for address_date in address_dates: + address_data.append( + dict( + patient_id=patient["patient_id"], + date=address_date, + region=rnd.choice(["London", "The North", "The Countryside"]), + ) + ) + # For each interval and patient, generate some events (possibly zero) + event_data = [] + for interval in intervals: + for patient in patient_data: + # Choose a number of events, biased towards zero + event_count = max(rnd.randint(-10, 10), 0) + event_data.extend( + dict( + patient_id=patient["patient_id"], + code=rnd.choice(["abc", "def", "foo"]), + date=random_date_in_interval(rnd, interval), + value=rnd.randint(0, 10), + ) + for _ in range(event_count) + ) + return patient_data, address_data, event_data + + +def random_date_in_interval(rnd, interval): + days_in_interval = (interval[1] - interval[0]).days + offset = rnd.randint(0, days_in_interval) + return interval[0] + timedelta(days=offset) + + +def calculate_measure_results(intervals, patient_data, address_data, event_data): + nums = defaultdict(int) + dens = defaultdict(int) + + for interval, patient, address, events in group_events( + intervals, patient_data, address_data, event_data + ): + event_count = len(events) + foo_count = len([e for e in events if e["code"] == "foo"]) + had_event = 1 if events else 0 + event_value = sum([e["value"] for e in events], start=0) + + nums[("foo_events_by_sex", interval, patient["sex"], None)] += foo_count + dens[("foo_events_by_sex", interval, patient["sex"], None)] += event_count + nums[("foo_events_by_region", interval, None, address["region"])] += foo_count + dens[("foo_events_by_region", interval, None, address["region"])] += event_count + nums[("had_event_by_sex", interval, patient["sex"], None)] += had_event + dens[("had_event_by_sex", interval, patient["sex"], None)] += 1 + nums[("event_value_by_region", interval, None, address["region"])] += ( + event_value + ) + dens[("event_value_by_region", interval, None, address["region"])] += 1 + nums[ + ("had_event_by_sex_and_region", interval, patient["sex"], address["region"]) + ] += had_event + dens[ + ("had_event_by_sex_and_region", interval, patient["sex"], address["region"]) + ] += 1 + nums[("foo_events", interval, None, None)] += foo_count + dens[("foo_events", interval, None, None)] += event_count + + for key, numerator in nums.items(): + measure, interval, sex, region = key + denominator = dens[key] + ratio = numerator / denominator + yield ( + measure, + interval[0], + interval[1], + ratio, + numerator, + denominator, + sex, + region, + ) + + +def group_events(intervals, patient_data, address_data, event_data): + "Group events by interval and patient" + for patient in patient_data: + patient_events = [ + e for e in event_data if e["patient_id"] == patient["patient_id"] + ] + patient_addresses = sorted( + [a for a in address_data if a["patient_id"] == patient["patient_id"]], + key=lambda a: a["date"], + ) + address = patient_addresses[-1] + for interval in intervals: + interval_events = [ + e for e in patient_events if interval[0] <= e["date"] <= interval[1] + ] + yield interval, patient, address, interval_events