--- a +++ b/tests/unit/test_debugger.py @@ -0,0 +1,436 @@ +import json +import textwrap +from datetime import date + +import pytest + +from ehrql import create_dataset, show +from ehrql.debugger import ( + activate_debug_context, + elements_are_related_series, + related_patient_columns_to_records, +) +from ehrql.query_engines.in_memory_database import PatientColumn +from ehrql.tables import EventFrame, PatientFrame, Series, table + + +def date_serializer(obj): + if isinstance(obj, date): + return obj.isoformat() + raise TypeError("Type not serializable") # pragma: no cover + + +def json_render_function(sequence, head=0, tail=0): + """Render as JSON, useful for testing.""" + return json.dumps(sequence, indent=4, default=date_serializer) + + +def test_show(capsys): + expected_output = textwrap.dedent( + """ + Show line 3: + + """ + ).strip() + + exec( + textwrap.dedent( + """ + # line 2 + show("Hello") + # line 4 + """ + ) + ) + + captured = capsys.readouterr() + assert captured.err.strip().startswith(expected_output), captured.err + + +def test_show_with_label(capsys): + expected_output = textwrap.dedent( + """ + Show line 3: Number + + """ + ).strip() + + exec( + textwrap.dedent( + """ + # line 2 + show(14, label="Number") + # line 4 + """ + ) + ) + + captured = capsys.readouterr() + assert captured.err.strip().startswith(expected_output), captured.err + + +def test_show_fails_for_non_ehrql_object(dummy_tables_path): + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ): + with pytest.raises(TypeError): + show("Hello") + + +def test_related_patient_columns_to_records_full_join(): + c1 = PatientColumn.parse( + """ + 1 | 101 + 2 | 102 + 4 | 104 + """ + ) + c2 = PatientColumn.parse( + """ + 1 | 201 + 2 | 202 + 3 | 203 + """ + ) + r = list(related_patient_columns_to_records([c1, c2])) + r_expected = [ + {"patient_id": 1, "series_1": 101, "series_2": 201}, + {"patient_id": 2, "series_1": 102, "series_2": 202}, + {"patient_id": 3, "series_1": "", "series_2": 203}, + {"patient_id": 4, "series_1": 104, "series_2": ""}, + ] + assert r == r_expected + + +@table +class patients(PatientFrame): + date_of_birth = Series(date) + date_of_death = Series(date) + sex = Series(str) + + +@table +class events(EventFrame): + date = Series(date) + code = Series(str) + test_result = Series(int) + + +def init_dataset(**kwargs): + dataset = create_dataset() + for key, value in kwargs.items(): + setattr(dataset, key, value) + return dataset + + +@pytest.fixture(scope="session") +def dummy_tables_path(tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("dummy_tables") + tmp_path.joinpath("patients.csv").write_text( + textwrap.dedent( + """\ + patient_id,date_of_birth,date_of_death,sex + 1,1970-01-01,,male + 2,1980-01-01,2020-01-01,female + """ + ) + ) + tmp_path.joinpath("events.csv").write_text( + textwrap.dedent( + """\ + patient_id,date,code,test_result + 1,2010-01-01,abc,32 + 1,2020-01-01,def, + 2,2005-01-01,abc,40 + """ + ) + ) + return tmp_path + + +@pytest.mark.parametrize( + "expression,contents", + [ + ( + patients, + [ + { + "patient_id": 1, + "date_of_birth": "1970-01-01", + "date_of_death": "", + "sex": "male", + }, + { + "patient_id": 2, + "date_of_birth": "1980-01-01", + "date_of_death": "2020-01-01", + "sex": "female", + }, + ], + ), + ( + patients.date_of_birth, + [ + {"patient_id": 1, "value": "1970-01-01"}, + {"patient_id": 2, "value": "1980-01-01"}, + ], + ), + ( + init_dataset( + dob=patients.date_of_birth, + count=events.count_for_patient(), + dod=patients.date_of_death, + ), + [ + {"patient_id": 1, "dob": "1970-01-01", "count": 2, "dod": ""}, + { + "patient_id": 2, + "dob": "1980-01-01", + "count": 1, + "dod": "2020-01-01", + }, + ], + ), + ], +) +def test_activate_debug_context(dummy_tables_path, expression, contents): + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ) as ctx: + assert json.loads(ctx.render(expression)) == contents + + +@pytest.mark.parametrize( + "elements,expected", + [ + ((patients.date_of_birth, patients.sex), True), + ((events.date, events.code), True), + ((patients.date_of_birth, events.count_for_patient()), True), + ((patients, patients.date_of_birth), False), + ((patients.date_of_birth, events.date), False), + ((patients.date_of_birth, {"some": "dict"}, patients.sex), False), + ], +) +def test_elements_are_related_series(elements, expected): + assert elements_are_related_series(elements) == expected + + +def test_render_related_patient_series(dummy_tables_path): + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ) as ctx: + rendered = ctx.render( + patients.date_of_birth, + patients.sex, + events.count_for_patient(), + patients.date_of_death, + ) + assert json.loads(rendered) == [ + { + "patient_id": 1, + "series_1": "1970-01-01", + "series_2": "male", + "series_3": 2, + "series_4": "", + }, + { + "patient_id": 2, + "series_1": "1980-01-01", + "series_2": "female", + "series_3": 1, + "series_4": "2020-01-01", + }, + ] + + +def test_render_related_event_series(dummy_tables_path): + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ) as ctx: + rendered = ctx.render(events.date, events.code, events.test_result) + assert json.loads(rendered) == [ + { + "patient_id": 1, + "row_id": 1, + "series_1": "2010-01-01", + "series_2": "abc", + "series_3": 32, + }, + { + "patient_id": 1, + "row_id": 2, + "series_1": "2020-01-01", + "series_2": "def", + "series_3": "", + }, + { + "patient_id": 2, + "row_id": 3, + "series_1": "2005-01-01", + "series_2": "abc", + "series_3": 40, + }, + ] + + +def test_render_dataset_event_tables_with_population(dummy_tables_path): + dataset = create_dataset() + dataset.define_population(patients.sex == "male") + dataset.add_event_table("test", date=events.date, code=events.code) + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ) as ctx: + rendered = ctx.render(dataset.test) + assert json.loads(rendered) == [ + { + "patient_id": 1, + "row_id": 1, + "date": "2010-01-01", + "code": "abc", + }, + { + "patient_id": 1, + "row_id": 2, + "date": "2020-01-01", + "code": "def", + }, + ] + + +def test_render_dataset_event_tables_without_population(dummy_tables_path): + dataset = create_dataset() + dataset.add_event_table("test", date=events.date, code=events.code) + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ) as ctx: + rendered = ctx.render(dataset.test) + assert json.loads(rendered) == [ + { + "patient_id": 1, + "row_id": 1, + "date": "2010-01-01", + "code": "abc", + }, + { + "patient_id": 1, + "row_id": 2, + "date": "2020-01-01", + "code": "def", + }, + { + "patient_id": 2, + "row_id": 3, + "date": "2005-01-01", + "code": "abc", + }, + ] + + +def test_render_date_difference(dummy_tables_path): + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ) as ctx: + rendered = ctx.render(patients.date_of_death - events.date) + assert json.loads(rendered) == [ + {"patient_id": 1, "row_id": 1, "value": ""}, + {"patient_id": 1, "row_id": 2, "value": ""}, + {"patient_id": 2, "row_id": 3, "value": "5478 days"}, + ] + + +def test_render_related_date_difference_patient_series(dummy_tables_path): + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ) as ctx: + rendered = ctx.render( + "2024-01-01" - patients.date_of_birth, + patients.sex, + ) + assert json.loads(rendered) == [ + {"patient_id": 1, "series_1": "19723 days", "series_2": "male"}, + {"patient_id": 2, "series_1": "16071 days", "series_2": "female"}, + ] + + +def test_render_related_date_difference_event_series(dummy_tables_path): + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ) as ctx: + rendered = ctx.render( + events.date - patients.date_of_birth, + events.code, + ) + assert json.loads(rendered) == [ + {"patient_id": 1, "row_id": 1, "series_1": "14610 days", "series_2": "abc"}, + {"patient_id": 1, "row_id": 2, "series_1": "18262 days", "series_2": "def"}, + {"patient_id": 2, "row_id": 3, "series_1": "9132 days", "series_2": "abc"}, + ] + + +@pytest.mark.parametrize( + "example_input", + [ + ((patients, patients.date_of_birth)), + ((patients.date_of_birth, events.date)), + ((patients.date_of_birth, {"some": "dict"}, patients.sex)), + ((init_dataset(), patients.date_of_birth, patients.sex)), + ], +) +def test_show_fails_for_mismatched_inputs(example_input): + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ): + with pytest.raises(TypeError): + assert show(*example_input) + + +@pytest.mark.parametrize( + "example_input", + [ + ((patients.date_of_birth, events.count_for_patient())), + ((patients.date_of_birth, patients.sex)), + ((events.date, events.code)), + ((patients.date_of_birth, patients.sex == "male")), + ((events.date, events.code == "123400")), + ], +) +def test_show_does_not_raise_error_for_series_from_same_domain( + dummy_tables_path, example_input +): + with activate_debug_context( + dummy_tables_path=dummy_tables_path, + render_function=json_render_function, + ): + show(example_input[0], *example_input[1:]) + + +def test_show_not_run_outside_debug_context(capsys): + expected_output = textwrap.dedent( + """ + Show line 3: + - show() ignored because we're not running in debug mode + """ + ).strip() + + exec( + textwrap.dedent( + """ + # line 2 + show(patients.date_of_birth, patients.sex) + # line 4 + """ + ) + ) + + captured = capsys.readouterr() + assert captured.err.strip() == expected_output, captured.err