Switch to side-by-side view

--- a
+++ b/tests/unit/test_debugger.py
@@ -0,0 +1,436 @@
+import json
+import textwrap
+from datetime import date
+
+import pytest
+
+from ehrql import create_dataset, show
+from ehrql.debugger import (
+    activate_debug_context,
+    elements_are_related_series,
+    related_patient_columns_to_records,
+)
+from ehrql.query_engines.in_memory_database import PatientColumn
+from ehrql.tables import EventFrame, PatientFrame, Series, table
+
+
+def date_serializer(obj):
+    if isinstance(obj, date):
+        return obj.isoformat()
+    raise TypeError("Type not serializable")  # pragma: no cover
+
+
+def json_render_function(sequence, head=0, tail=0):
+    """Render as JSON, useful for testing."""
+    return json.dumps(sequence, indent=4, default=date_serializer)
+
+
+def test_show(capsys):
+    expected_output = textwrap.dedent(
+        """
+        Show line 3:
+
+        """
+    ).strip()
+
+    exec(
+        textwrap.dedent(
+            """
+            # line 2
+            show("Hello")
+            # line 4
+            """
+        )
+    )
+
+    captured = capsys.readouterr()
+    assert captured.err.strip().startswith(expected_output), captured.err
+
+
+def test_show_with_label(capsys):
+    expected_output = textwrap.dedent(
+        """
+        Show line 3: Number
+
+        """
+    ).strip()
+
+    exec(
+        textwrap.dedent(
+            """
+            # line 2
+            show(14, label="Number")
+            # line 4
+            """
+        )
+    )
+
+    captured = capsys.readouterr()
+    assert captured.err.strip().startswith(expected_output), captured.err
+
+
+def test_show_fails_for_non_ehrql_object(dummy_tables_path):
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ):
+        with pytest.raises(TypeError):
+            show("Hello")
+
+
+def test_related_patient_columns_to_records_full_join():
+    c1 = PatientColumn.parse(
+        """
+        1 | 101
+        2 | 102
+        4 | 104
+        """
+    )
+    c2 = PatientColumn.parse(
+        """
+        1 | 201
+        2 | 202
+        3 | 203
+        """
+    )
+    r = list(related_patient_columns_to_records([c1, c2]))
+    r_expected = [
+        {"patient_id": 1, "series_1": 101, "series_2": 201},
+        {"patient_id": 2, "series_1": 102, "series_2": 202},
+        {"patient_id": 3, "series_1": "", "series_2": 203},
+        {"patient_id": 4, "series_1": 104, "series_2": ""},
+    ]
+    assert r == r_expected
+
+
+@table
+class patients(PatientFrame):
+    date_of_birth = Series(date)
+    date_of_death = Series(date)
+    sex = Series(str)
+
+
+@table
+class events(EventFrame):
+    date = Series(date)
+    code = Series(str)
+    test_result = Series(int)
+
+
+def init_dataset(**kwargs):
+    dataset = create_dataset()
+    for key, value in kwargs.items():
+        setattr(dataset, key, value)
+    return dataset
+
+
+@pytest.fixture(scope="session")
+def dummy_tables_path(tmp_path_factory):
+    tmp_path = tmp_path_factory.mktemp("dummy_tables")
+    tmp_path.joinpath("patients.csv").write_text(
+        textwrap.dedent(
+            """\
+            patient_id,date_of_birth,date_of_death,sex
+            1,1970-01-01,,male
+            2,1980-01-01,2020-01-01,female
+            """
+        )
+    )
+    tmp_path.joinpath("events.csv").write_text(
+        textwrap.dedent(
+            """\
+            patient_id,date,code,test_result
+            1,2010-01-01,abc,32
+            1,2020-01-01,def,
+            2,2005-01-01,abc,40
+            """
+        )
+    )
+    return tmp_path
+
+
+@pytest.mark.parametrize(
+    "expression,contents",
+    [
+        (
+            patients,
+            [
+                {
+                    "patient_id": 1,
+                    "date_of_birth": "1970-01-01",
+                    "date_of_death": "",
+                    "sex": "male",
+                },
+                {
+                    "patient_id": 2,
+                    "date_of_birth": "1980-01-01",
+                    "date_of_death": "2020-01-01",
+                    "sex": "female",
+                },
+            ],
+        ),
+        (
+            patients.date_of_birth,
+            [
+                {"patient_id": 1, "value": "1970-01-01"},
+                {"patient_id": 2, "value": "1980-01-01"},
+            ],
+        ),
+        (
+            init_dataset(
+                dob=patients.date_of_birth,
+                count=events.count_for_patient(),
+                dod=patients.date_of_death,
+            ),
+            [
+                {"patient_id": 1, "dob": "1970-01-01", "count": 2, "dod": ""},
+                {
+                    "patient_id": 2,
+                    "dob": "1980-01-01",
+                    "count": 1,
+                    "dod": "2020-01-01",
+                },
+            ],
+        ),
+    ],
+)
+def test_activate_debug_context(dummy_tables_path, expression, contents):
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ) as ctx:
+        assert json.loads(ctx.render(expression)) == contents
+
+
+@pytest.mark.parametrize(
+    "elements,expected",
+    [
+        ((patients.date_of_birth, patients.sex), True),
+        ((events.date, events.code), True),
+        ((patients.date_of_birth, events.count_for_patient()), True),
+        ((patients, patients.date_of_birth), False),
+        ((patients.date_of_birth, events.date), False),
+        ((patients.date_of_birth, {"some": "dict"}, patients.sex), False),
+    ],
+)
+def test_elements_are_related_series(elements, expected):
+    assert elements_are_related_series(elements) == expected
+
+
+def test_render_related_patient_series(dummy_tables_path):
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ) as ctx:
+        rendered = ctx.render(
+            patients.date_of_birth,
+            patients.sex,
+            events.count_for_patient(),
+            patients.date_of_death,
+        )
+    assert json.loads(rendered) == [
+        {
+            "patient_id": 1,
+            "series_1": "1970-01-01",
+            "series_2": "male",
+            "series_3": 2,
+            "series_4": "",
+        },
+        {
+            "patient_id": 2,
+            "series_1": "1980-01-01",
+            "series_2": "female",
+            "series_3": 1,
+            "series_4": "2020-01-01",
+        },
+    ]
+
+
+def test_render_related_event_series(dummy_tables_path):
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ) as ctx:
+        rendered = ctx.render(events.date, events.code, events.test_result)
+    assert json.loads(rendered) == [
+        {
+            "patient_id": 1,
+            "row_id": 1,
+            "series_1": "2010-01-01",
+            "series_2": "abc",
+            "series_3": 32,
+        },
+        {
+            "patient_id": 1,
+            "row_id": 2,
+            "series_1": "2020-01-01",
+            "series_2": "def",
+            "series_3": "",
+        },
+        {
+            "patient_id": 2,
+            "row_id": 3,
+            "series_1": "2005-01-01",
+            "series_2": "abc",
+            "series_3": 40,
+        },
+    ]
+
+
+def test_render_dataset_event_tables_with_population(dummy_tables_path):
+    dataset = create_dataset()
+    dataset.define_population(patients.sex == "male")
+    dataset.add_event_table("test", date=events.date, code=events.code)
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ) as ctx:
+        rendered = ctx.render(dataset.test)
+    assert json.loads(rendered) == [
+        {
+            "patient_id": 1,
+            "row_id": 1,
+            "date": "2010-01-01",
+            "code": "abc",
+        },
+        {
+            "patient_id": 1,
+            "row_id": 2,
+            "date": "2020-01-01",
+            "code": "def",
+        },
+    ]
+
+
+def test_render_dataset_event_tables_without_population(dummy_tables_path):
+    dataset = create_dataset()
+    dataset.add_event_table("test", date=events.date, code=events.code)
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ) as ctx:
+        rendered = ctx.render(dataset.test)
+    assert json.loads(rendered) == [
+        {
+            "patient_id": 1,
+            "row_id": 1,
+            "date": "2010-01-01",
+            "code": "abc",
+        },
+        {
+            "patient_id": 1,
+            "row_id": 2,
+            "date": "2020-01-01",
+            "code": "def",
+        },
+        {
+            "patient_id": 2,
+            "row_id": 3,
+            "date": "2005-01-01",
+            "code": "abc",
+        },
+    ]
+
+
+def test_render_date_difference(dummy_tables_path):
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ) as ctx:
+        rendered = ctx.render(patients.date_of_death - events.date)
+    assert json.loads(rendered) == [
+        {"patient_id": 1, "row_id": 1, "value": ""},
+        {"patient_id": 1, "row_id": 2, "value": ""},
+        {"patient_id": 2, "row_id": 3, "value": "5478 days"},
+    ]
+
+
+def test_render_related_date_difference_patient_series(dummy_tables_path):
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ) as ctx:
+        rendered = ctx.render(
+            "2024-01-01" - patients.date_of_birth,
+            patients.sex,
+        )
+    assert json.loads(rendered) == [
+        {"patient_id": 1, "series_1": "19723 days", "series_2": "male"},
+        {"patient_id": 2, "series_1": "16071 days", "series_2": "female"},
+    ]
+
+
+def test_render_related_date_difference_event_series(dummy_tables_path):
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ) as ctx:
+        rendered = ctx.render(
+            events.date - patients.date_of_birth,
+            events.code,
+        )
+    assert json.loads(rendered) == [
+        {"patient_id": 1, "row_id": 1, "series_1": "14610 days", "series_2": "abc"},
+        {"patient_id": 1, "row_id": 2, "series_1": "18262 days", "series_2": "def"},
+        {"patient_id": 2, "row_id": 3, "series_1": "9132 days", "series_2": "abc"},
+    ]
+
+
+@pytest.mark.parametrize(
+    "example_input",
+    [
+        ((patients, patients.date_of_birth)),
+        ((patients.date_of_birth, events.date)),
+        ((patients.date_of_birth, {"some": "dict"}, patients.sex)),
+        ((init_dataset(), patients.date_of_birth, patients.sex)),
+    ],
+)
+def test_show_fails_for_mismatched_inputs(example_input):
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ):
+        with pytest.raises(TypeError):
+            assert show(*example_input)
+
+
+@pytest.mark.parametrize(
+    "example_input",
+    [
+        ((patients.date_of_birth, events.count_for_patient())),
+        ((patients.date_of_birth, patients.sex)),
+        ((events.date, events.code)),
+        ((patients.date_of_birth, patients.sex == "male")),
+        ((events.date, events.code == "123400")),
+    ],
+)
+def test_show_does_not_raise_error_for_series_from_same_domain(
+    dummy_tables_path, example_input
+):
+    with activate_debug_context(
+        dummy_tables_path=dummy_tables_path,
+        render_function=json_render_function,
+    ):
+        show(example_input[0], *example_input[1:])
+
+
+def test_show_not_run_outside_debug_context(capsys):
+    expected_output = textwrap.dedent(
+        """
+        Show line 3:
+         - show() ignored because we're not running in debug mode
+        """
+    ).strip()
+
+    exec(
+        textwrap.dedent(
+            """
+            # line 2
+            show(patients.date_of_birth, patients.sex)
+            # line 4
+            """
+        )
+    )
+
+    captured = capsys.readouterr()
+    assert captured.err.strip() == expected_output, captured.err