--- a +++ b/tests/unit/measures/test_measures.py @@ -0,0 +1,304 @@ +import re +from datetime import date + +import pytest + +from ehrql import Error, Measures, months +from ehrql.measures.measures import Measure, create_measures +from ehrql.tables import PatientFrame, Series, table + + +@table +class patients(PatientFrame): + is_interesting = Series(bool) + score = Series(int) + category = Series(str) + style = Series(str) + + +def test_create_measures(): + assert isinstance(create_measures(), Measures) + + +def test_define_measures(): + intervals = [ + (date(2021, 1, 1), date(2021, 1, 31)), + (date(2021, 2, 1), date(2021, 2, 28)), + ] + + measures = Measures() + + measures.define_defaults( + numerator=patients.score, + denominator=patients.is_interesting, + intervals=intervals, + ) + + measures.define_measure( + name="test_1", + group_by={"category": patients.category}, + ) + measures.define_measure( + name="test_2", + group_by={"style": patients.style}, + ) + + assert measures.disclosure_control_config.enabled + + assert len(measures) == 2 + + assert list(measures) == [ + Measure( + name="test_1", + numerator=patients.score._qm_node, + denominator=patients.is_interesting._qm_node, + group_by={ + "category": patients.category._qm_node, + }, + intervals=tuple(intervals), + ), + Measure( + name="test_2", + numerator=patients.score._qm_node, + denominator=patients.is_interesting._qm_node, + group_by={ + "style": patients.style._qm_node, + }, + intervals=tuple(intervals), + ), + ] + + +def test_define_measures_with_default_group_by(): + # Regression test for #1423. + + intervals = [ + (date(2021, 1, 1), date(2021, 1, 31)), + (date(2021, 2, 1), date(2021, 2, 28)), + ] + + measures = Measures() + + measures.define_defaults( + numerator=patients.score, + denominator=patients.is_interesting, + intervals=intervals, + group_by={"category": patients.category}, + ) + + measures.define_measure( + name="test_1", + ) + measures.define_measure( + name="test_2", + ) + + assert measures.disclosure_control_config.enabled + + assert len(measures) == 2 + + assert list(measures) == [ + Measure( + name="test_1", + numerator=patients.score._qm_node, + denominator=patients.is_interesting._qm_node, + group_by={ + "category": patients.category._qm_node, + }, + intervals=tuple(intervals), + ), + Measure( + name="test_2", + numerator=patients.score._qm_node, + denominator=patients.is_interesting._qm_node, + group_by={ + "category": patients.category._qm_node, + }, + intervals=tuple(intervals), + ), + ] + + +def test_cannot_redefine_defaults(): + measures = Measures() + measures.define_defaults(numerator=patients.score) + with pytest.raises(Error, match="Defaults already set"): + measures.define_defaults(numerator=patients.is_interesting) + + +def test_must_define_all_attributes(): + measures = Measures() + measures.define_defaults(numerator=patients.score) + with pytest.raises( + Error, + match="No value supplied for 'intervals' and no default defined", + ): + measures.define_measure( + name="test", + denominator=patients.is_interesting, + ) + + +def test_names_must_be_unique(): + measures = Measures() + + measures.define_defaults( + denominator=patients.is_interesting, + intervals=[(date(2021, 1, 1), date(2021, 1, 31))], + group_by={"category": patients.category}, + ) + + measures.define_measure(name="test", numerator=patients.score) + with pytest.raises(Error, match="Measure already defined with name: test"): + measures.define_measure(name="test", numerator=patients.is_interesting) + + +@pytest.mark.parametrize( + "name", + [ + "", + "12346No Spaces", + "No-Dashes", + ], +) +def test_names_must_be_valid(name): + measures = Measures() + with pytest.raises( + Error, + match=( + "must start with a letter and contain only alphanumeric characters" + " and underscores" + ), + ): + measures.define_measure( + name=name, + numerator=patients.score, + denominator=patients.is_interesting, + intervals=[(date(2021, 1, 1), date(2021, 1, 31))], + group_by={"category": patients.category}, + ) + + +def test_group_by_columns_must_be_consistent(): + measures = Measures() + + measures.define_defaults( + numerator=patients.score, + denominator=patients.is_interesting, + intervals=[(date(2021, 1, 1), date(2021, 1, 31))], + ) + + measures.define_measure( + name="test_1", + group_by={ + "style": patients.style, + "category": patients.category, + }, + ) + with pytest.raises( + Error, match="Inconsistent definition for `group_by` column: category" + ): + measures.define_measure( + name="test_2", + group_by={ + "style": patients.style, + "category": patients.is_interesting, + }, + ) + + +def test_numerator_must_be_patient_series(): + measures = Measures() + + with pytest.raises( + TypeError, + match="invalid numerator:\nExpecting an ehrQL series, got type 'object'", + ): + measures.define_measure( + name="test", + numerator=object(), + denominator=patients.is_interesting, + intervals=[(date(2021, 1, 1), date(2021, 1, 31))], + group_by={"category": patients.category}, + ) + + +def test_denominator_must_be_int_or_bool_series(): + measures = Measures() + + with pytest.raises( + TypeError, + match=( + "invalid denominator:\n" + "Expecting a boolean or integer series, got series of type 'str'" + ), + ): + measures.define_measure( + name="test", + numerator=patients.score, + denominator=patients.style, + intervals=[(date(2021, 1, 1), date(2021, 1, 31))], + group_by={"category": patients.category}, + ) + + +@pytest.mark.parametrize( + "group_by,error", + [ + (1234, "`group_by` must be a dictionary"), + ({1234: patients.category}, "group_by` names must be strings"), + ({"My Group": patients.category}, "alphanumeric characters and underscores"), + ({"measure": patients.category}, "disallowed `group_by` column name: measure"), + ( + {"my_group": 1234}, + "invalid `group_by` value for 'my_group':\n" + "Expecting an ehrQL series, got type 'int'", + ), + ], +) +def test_invalid_group_by_is_rejected(group_by, error): + measures = Measures() + + with pytest.raises((TypeError, Error), match=error): + measures.define_measure( + name="test", + numerator=patients.score, + denominator=patients.is_interesting, + intervals=[(date(2021, 1, 1), date(2021, 1, 31))], + group_by=group_by, + ) + + +@pytest.mark.parametrize( + "intervals,error", + [ + (object(), "`intervals` must be a list"), + ([(1, 2)], "`intervals` must contain pairs of dates"), + ([(date(2020, 1, 1),)], "`intervals` must contain pairs of dates"), + ([(date(2021, 1, 1), date(2020, 1, 1))], "start date must be before end date"), + ( + months(12), + re.escape( + "You must supply a date using `months(value=12).starting_on('<DATE>')`" + " or `months(value=12).ending_on('<DATE>')`" + ), + ), + ], +) +def test_invalid_intervals_are_rejected(intervals, error): + measures = Measures() + + with pytest.raises(Error, match=error): + measures.define_measure( + name="test", + numerator=patients.score, + denominator=patients.is_interesting, + intervals=intervals, + group_by={"category": patients.category}, + ) + + +def test_configure_disclosure_control(): + measures = Measures() + measures.configure_disclosure_control(enabled=False) + assert not measures.disclosure_control_config.enabled