Switch to side-by-side view

--- a
+++ b/tests/integration/measures/test_calculate.py
@@ -0,0 +1,422 @@
+import random
+from collections import defaultdict
+from datetime import date, timedelta
+from unittest import mock
+
+import pytest
+
+from ehrql import months, years
+from ehrql.measures import INTERVAL, Measures, get_measure_results
+from ehrql.measures.calculate import MeasuresTimeout
+from ehrql.tables import EventFrame, PatientFrame, Series, table
+
+
+@table
+class patients(PatientFrame):
+    sex = Series(str)
+
+
+@table
+class addresses(EventFrame):
+    date = Series(date)
+    region = Series(str)
+
+
+@table
+class events(EventFrame):
+    date = Series(date)
+    code = Series(str)
+    value = Series(int)
+
+
+def test_get_measure_results(engine):
+    events_in_interval = events.where(events.date.is_during(INTERVAL))
+    event_count = events_in_interval.count_for_patient()
+    foo_event_count = events_in_interval.where(events.code == "foo").count_for_patient()
+    had_event = events_in_interval.exists_for_patient()
+    event_value = events_in_interval.value.sum_for_patient()
+    region = addresses.sort_by(addresses.date).last_for_patient().region
+
+    intervals = years(3).starting_on("2020-01-01")
+    measures = Measures()
+
+    measures.define_measure(
+        "foo_events_by_sex",
+        numerator=foo_event_count,
+        denominator=event_count,
+        group_by=dict(sex=patients.sex),
+        intervals=intervals,
+    )
+    measures.define_measure(
+        "foo_events_by_region",
+        numerator=foo_event_count,
+        denominator=event_count,
+        group_by=dict(region=region),
+        intervals=intervals,
+    )
+    measures.define_measure(
+        "had_event_by_sex",
+        numerator=had_event,
+        denominator=patients.exists_for_patient(),
+        group_by=dict(sex=patients.sex),
+        intervals=intervals,
+    )
+    measures.define_measure(
+        "event_value_by_region",
+        numerator=event_value,
+        denominator=patients.exists_for_patient(),
+        group_by=dict(region=region),
+        intervals=intervals,
+    )
+    measures.define_measure(
+        "had_event_by_sex_and_region",
+        numerator=had_event,
+        denominator=patients.exists_for_patient(),
+        group_by=dict(
+            sex=patients.sex,
+            region=region,
+        ),
+        intervals=intervals,
+    )
+    measures.define_measure(
+        "foo_events",
+        numerator=foo_event_count,
+        denominator=event_count,
+        intervals=intervals,
+    )
+
+    patient_data, address_data, event_data = generate_data(intervals)
+    engine.populate(
+        {patients: patient_data, addresses: address_data, events: event_data}
+    )
+
+    results = get_measure_results(engine.query_engine(), measures)
+    results = list(results)
+    # Verify that we don't get any duplicate rows in the results
+    assert len(results) == len(set(results))
+
+    expected = calculate_measure_results(
+        intervals, patient_data, address_data, event_data
+    )
+    expected = list(expected)
+    # We don't care about the order of the results
+    assert set(results) == set(expected)
+
+
+def test_get_measures_interval_dependent_denominator(engine):
+    # Test results when an interval denominator is dependent on the specific interval
+    # (i.e. values in other intervals affect the inclusion in this interval population)
+    # i.e. the union of all measure denominators will exclude some patients
+    intervals = years(2).starting_on("2020-01-01")
+    measures = Measures()
+
+    is_female = patients.sex == "female"
+    had_event_in_interval = events.where(
+        events.date.is_during(INTERVAL)
+    ).exists_for_patient()
+    had_event_outside_interval = events.where(
+        events.date.is_before(INTERVAL.start_date)
+        | events.date.is_after(INTERVAL.end_date)
+    ).exists_for_patient()
+    measures.define_measure(
+        "female_by_events_outside_interval_only",
+        numerator=is_female,
+        denominator=had_event_outside_interval & ~(had_event_in_interval),
+        intervals=intervals,
+    )
+
+    patient_data = [
+        dict(patient_id=1, sex="male"),
+        dict(patient_id=2, sex="female"),
+        dict(patient_id=3, sex="male"),
+        dict(patient_id=4, sex="female"),
+    ]
+    event_data = [
+        # Interval 1 includes only patient 2 (female) in the population (has an event in interval 2 only)
+        # Interval 2 includes only patient 1 (male) in the population (has an event in interval 1 only)
+        dict(patient_id=1, code="abc", date=date(2020, 2, 1)),
+        dict(patient_id=2, code="abc", date=date(2021, 2, 1)),
+        # Patient 3 and 4 have events in both intervals, so aren't included in the population for
+        # either
+        dict(patient_id=3, code="abc", date=date(2020, 2, 1)),
+        dict(patient_id=4, code="abc", date=date(2020, 2, 1)),
+        dict(patient_id=3, code="abc", date=date(2021, 2, 1)),
+        dict(patient_id=4, code="abc", date=date(2021, 2, 1)),
+    ]
+    engine.populate({patients: patient_data, events: event_data})
+    results = get_measure_results(engine.query_engine(), measures)
+
+    expected = [
+        # interval 1 has 1 female patient in the population - numerator 1, denominator 1
+        (
+            "female_by_events_outside_interval_only",
+            date(2020, 1, 1),
+            date(2020, 12, 31),
+            1.0,
+            1,
+            1,
+        ),
+        # interval 2 has 1 male patient in the population - numerator 0, denominator 1
+        (
+            "female_by_events_outside_interval_only",
+            date(2021, 1, 1),
+            date(2021, 12, 31),
+            0.0,
+            0,
+            1,
+        ),
+    ]
+
+    assert set(results) == set(expected)
+
+
+def test_get_measures_same_numerator_and_denominator(engine):
+    # Ensure that calculations are handled correctly when the same column
+    # is used as both numerator and denominator
+    intervals = years(2).starting_on("2020-01-01")
+    measures = Measures()
+    measures.define_measure(
+        "test",
+        numerator=patients.exists_for_patient(),
+        denominator=patients.exists_for_patient(),
+        intervals=intervals,
+    )
+    patient_data = [dict(patient_id=1), dict(patient_id=2)]
+    engine.populate({patients: patient_data})
+    results = set(get_measure_results(engine.query_engine(), measures))
+    expected = {
+        ("test", date(2020, 1, 1), date(2020, 12, 31), 1.0, 2, 2),
+        ("test", date(2021, 1, 1), date(2021, 12, 31), 1.0, 2, 2),
+    }
+    assert results == expected
+
+
+def test_get_measures_duplicate_group_bys(engine):
+    # Ensure that calculations are handled correctly when there are measures
+    # in the same group (sharing a denominator and intervals) with the same
+    # group bys. These can be handled by a single grouping set in the SQL
+    # query; duplicate grouping sets in the query result in duplicate
+    # rows in the result
+    events_in_interval = events.where(events.date.is_during(INTERVAL))
+    event_count = events_in_interval.count_for_patient()
+    foo_event_count = events_in_interval.where(events.code == "foo").count_for_patient()
+    bar_event_count = events_in_interval.where(events.code == "bar").count_for_patient()
+
+    intervals = years(1).starting_on("2020-01-01")
+    measures = Measures()
+
+    measures.define_measure(
+        "foo_events",
+        numerator=foo_event_count,
+        denominator=event_count,
+        intervals=intervals,
+    )
+    measures.define_measure(
+        "foo_events_by_sex",
+        numerator=foo_event_count,
+        denominator=event_count,
+        group_by=dict(sex=patients.sex),
+        intervals=intervals,
+    )
+    measures.define_measure(
+        "bar_events",
+        numerator=bar_event_count,
+        denominator=event_count,
+        intervals=intervals,
+    )
+    measures.define_measure(
+        "bar_events_by_sex",
+        numerator=bar_event_count,
+        denominator=event_count,
+        group_by=dict(sex=patients.sex),
+        intervals=intervals,
+    )
+
+    patient_data = [dict(patient_id=1, sex="male"), dict(patient_id=2, sex="female")]
+    address_data = [
+        dict(patient_id=1, date=date(2020, 1, 1), region="London"),
+        dict(patient_id=1, date=date(2020, 1, 1), region="The North"),
+    ]
+    event_data = [
+        dict(patient_id=1, date=date(2020, 2, 1), code="foo"),
+        dict(patient_id=1, date=date(2020, 2, 1), code="bar"),
+        dict(patient_id=2, date=date(2020, 2, 1), code="foo"),
+        dict(patient_id=2, date=date(2020, 2, 1), code="bar"),
+    ]
+    engine.populate(
+        {patients: patient_data, addresses: address_data, events: event_data}
+    )
+    results = list(get_measure_results(engine.query_engine(), measures))
+    # Verify that we don't get any duplicate rows in the results
+    assert len(results) == len(set(results))
+
+    expected = {
+        ("foo_events", date(2020, 1, 1), date(2020, 12, 31), 0.5, 2, 4, None),
+        ("foo_events_by_sex", date(2020, 1, 1), date(2020, 12, 31), 0.5, 1, 2, "male"),
+        (
+            "foo_events_by_sex",
+            date(2020, 1, 1),
+            date(2020, 12, 31),
+            0.5,
+            1,
+            2,
+            "female",
+        ),
+        ("bar_events", date(2020, 1, 1), date(2020, 12, 31), 0.5, 2, 4, None),
+        ("bar_events_by_sex", date(2020, 1, 1), date(2020, 12, 31), 0.5, 1, 2, "male"),
+        (
+            "bar_events_by_sex",
+            date(2020, 1, 1),
+            date(2020, 12, 31),
+            0.5,
+            1,
+            2,
+            "female",
+        ),
+    }
+    assert set(results) == expected
+
+
+@mock.patch("ehrql.measures.calculate.time")
+def test_get_measure_results_with_timeout(patched_time, in_memory_engine):
+    events_in_interval = events.where(events.date.is_during(INTERVAL))
+    event_count = events_in_interval.count_for_patient()
+    foo_event_count = events_in_interval.where(events.code == "foo").count_for_patient()
+
+    intervals = months(60).starting_on("2000-01-01")
+    measures = Measures()
+
+    measures.define_measure(
+        "foo_events",
+        numerator=foo_event_count,
+        denominator=event_count,
+        intervals=intervals,
+        group_by=dict(
+            sex=patients.sex,
+        ),
+    )
+
+    patient_data, _, event_data = generate_data(intervals)
+    in_memory_engine.populate({patients: patient_data, events: event_data})
+
+    patched_time.time.side_effect = [0.0, 1000.0, 1000000.0]
+    results = get_measure_results(in_memory_engine.query_engine(), measures)
+    with pytest.raises(MeasuresTimeout, match="time limit"):
+        results = list(results)
+
+
+def generate_data(intervals):
+    rnd = random.Random(20230518)
+    # Generate some random patients
+    patient_data = [
+        dict(
+            patient_id=patient_id,
+            sex=rnd.choice(["male", "female"]),
+        )
+        for patient_id in range(1, 50)
+    ]
+    # Generate some addresses (at least one) for each patient
+    # Make sure that address dates for the same patient are different; otherwise
+    # we can't be sure which region will be returned as the last
+    address_data = []
+    interval_range = (intervals[0][0], intervals[-1][1])
+    for patient in patient_data:
+        address_dates = set()
+        for _ in range(rnd.randint(1, 3)):
+            address_dates.add(random_date_in_interval(rnd, interval_range))
+
+        for address_date in address_dates:
+            address_data.append(
+                dict(
+                    patient_id=patient["patient_id"],
+                    date=address_date,
+                    region=rnd.choice(["London", "The North", "The Countryside"]),
+                )
+            )
+    # For each interval and patient, generate some events (possibly zero)
+    event_data = []
+    for interval in intervals:
+        for patient in patient_data:
+            # Choose a number of events, biased towards zero
+            event_count = max(rnd.randint(-10, 10), 0)
+            event_data.extend(
+                dict(
+                    patient_id=patient["patient_id"],
+                    code=rnd.choice(["abc", "def", "foo"]),
+                    date=random_date_in_interval(rnd, interval),
+                    value=rnd.randint(0, 10),
+                )
+                for _ in range(event_count)
+            )
+    return patient_data, address_data, event_data
+
+
+def random_date_in_interval(rnd, interval):
+    days_in_interval = (interval[1] - interval[0]).days
+    offset = rnd.randint(0, days_in_interval)
+    return interval[0] + timedelta(days=offset)
+
+
+def calculate_measure_results(intervals, patient_data, address_data, event_data):
+    nums = defaultdict(int)
+    dens = defaultdict(int)
+
+    for interval, patient, address, events in group_events(
+        intervals, patient_data, address_data, event_data
+    ):
+        event_count = len(events)
+        foo_count = len([e for e in events if e["code"] == "foo"])
+        had_event = 1 if events else 0
+        event_value = sum([e["value"] for e in events], start=0)
+
+        nums[("foo_events_by_sex", interval, patient["sex"], None)] += foo_count
+        dens[("foo_events_by_sex", interval, patient["sex"], None)] += event_count
+        nums[("foo_events_by_region", interval, None, address["region"])] += foo_count
+        dens[("foo_events_by_region", interval, None, address["region"])] += event_count
+        nums[("had_event_by_sex", interval, patient["sex"], None)] += had_event
+        dens[("had_event_by_sex", interval, patient["sex"], None)] += 1
+        nums[("event_value_by_region", interval, None, address["region"])] += (
+            event_value
+        )
+        dens[("event_value_by_region", interval, None, address["region"])] += 1
+        nums[
+            ("had_event_by_sex_and_region", interval, patient["sex"], address["region"])
+        ] += had_event
+        dens[
+            ("had_event_by_sex_and_region", interval, patient["sex"], address["region"])
+        ] += 1
+        nums[("foo_events", interval, None, None)] += foo_count
+        dens[("foo_events", interval, None, None)] += event_count
+
+    for key, numerator in nums.items():
+        measure, interval, sex, region = key
+        denominator = dens[key]
+        ratio = numerator / denominator
+        yield (
+            measure,
+            interval[0],
+            interval[1],
+            ratio,
+            numerator,
+            denominator,
+            sex,
+            region,
+        )
+
+
+def group_events(intervals, patient_data, address_data, event_data):
+    "Group events by interval and patient"
+    for patient in patient_data:
+        patient_events = [
+            e for e in event_data if e["patient_id"] == patient["patient_id"]
+        ]
+        patient_addresses = sorted(
+            [a for a in address_data if a["patient_id"] == patient["patient_id"]],
+            key=lambda a: a["date"],
+        )
+        address = patient_addresses[-1]
+        for interval in intervals:
+            interval_events = [
+                e for e in patient_events if interval[0] <= e["date"] <= interval[1]
+            ]
+            yield interval, patient, address, interval_events