Switch to side-by-side view

--- a
+++ b/tests/integration/backends/test_emis.py
@@ -0,0 +1,612 @@
+import re
+from datetime import date, datetime
+
+import sqlalchemy
+
+from ehrql import create_dataset
+from ehrql.backends.emis import EMISBackend
+from ehrql.tables import PatientFrame, Series, emis, table_from_rows
+from ehrql.tables.raw import emis as emis_raw
+from ehrql.utils.sqlalchemy_query_utils import CreateTableAs, GeneratedTable
+from tests.lib.emis_schema import (
+    ImmunisationAllOrgsV2,
+    MedicationAllOrgsV2,
+    ObservationAllOrgsV2,
+    OnsView,
+    PatientAllOrgsV2,
+)
+
+from .helpers import (
+    assert_tests_exhaustive,
+    assert_types_correct,
+    get_all_backend_columns,
+    register_test_for,
+)
+
+
+def test_backend_columns_have_correct_types(trino_database):
+    columns_with_types = get_all_backend_columns_with_types(trino_database)
+    assert_types_correct(columns_with_types, trino_database)
+
+
+def get_all_backend_columns_with_types(trino_database):
+    """
+    For every column on every table we expose in the backend, yield the SQLAlchemy type
+    instance we expect to use for that column together with the type information that
+    database has for that column so we can check they're compatible
+    """
+    table_names = set()
+    column_types = {}
+    queries = []
+    for table, columns in get_all_backend_columns(EMISBackend()):
+        table_names.add(table)
+        column_types.update({(table, c.key): c.type for c in columns})
+        # Construct a query which selects every column in the table
+        select_query = sqlalchemy.select(*[c.label(c.key) for c in columns])
+        # Write the results of that query into a temporary table (it will be empty but
+        # that's fine, we just want the types)
+        # Trino doesn't support actual temporary tables, so really this temp table is
+        # a real table that we drop after the test
+        temp_table_name = f"temp_{table}"
+        temp_table = GeneratedTable.from_query(temp_table_name, select_query)
+        queries.append(
+            (temp_table_name, temp_table, CreateTableAs(temp_table, select_query))
+        )
+    # Create all the underlying tables in the database without populating them
+    trino_database.setup(metadata=PatientAllOrgsV2.metadata)
+    with trino_database.engine().connect() as connection:
+        # Create our "temporary" tables
+        for temp_table_name, temp_table, query in queries:
+            connection.execute(query)
+            # Get the column names, types and collations for all columns in those tables
+            query = sqlalchemy.text(
+                """
+                SELECT column_name, data_type
+                FROM information_schema.columns
+                WHERE table_name=:t
+                """
+            )
+            results = list(connection.execute(query, {"t": temp_table_name}))
+            table_name = temp_table_name.replace("temp_", "")
+            for column, type_name in results:
+                column_type = column_types[table_name, column]
+                column_args = {"type": type_name}
+                yield table_name, column, column_type, column_args
+
+            # Drop the temp table
+            temp_table.drop(trino_database.engine())
+
+
+@register_test_for(emis.clinical_events)
+def test_clinical_events(select_all_emis):
+    results = select_all_emis(
+        PatientAllOrgsV2(registration_id="1"),
+        PatientAllOrgsV2(registration_id="2"),
+        ObservationAllOrgsV2(
+            registration_id="1",
+            effective_date=datetime(2020, 10, 20, 14, 30, 5),
+            snomed_concept_id=123,
+            value_pq_1=0.5,
+        ),
+        ObservationAllOrgsV2(
+            registration_id="2",
+            effective_date=datetime(2022, 1, 15, 12, 30, 5),
+            snomed_concept_id=567,
+            value_pq_1=None,
+        ),
+    )
+    assert results == [
+        {
+            "patient_id": "1",
+            "date": date(2020, 10, 20),
+            "snomedct_code": "123",
+            "numeric_value": 0.5,
+        },
+        {
+            "patient_id": "2",
+            "date": date(2022, 1, 15),
+            "snomedct_code": "567",
+            "numeric_value": None,
+        },
+    ]
+
+
+@register_test_for(emis.medications)
+def test_medications(select_all_emis):
+    results = select_all_emis(
+        PatientAllOrgsV2(registration_id="1"),
+        PatientAllOrgsV2(registration_id="2"),
+        MedicationAllOrgsV2(
+            registration_id="1",
+            effective_date=datetime(2020, 10, 20, 14, 30, 5),
+            snomed_concept_id=123,
+        ),
+        MedicationAllOrgsV2(
+            registration_id="2",
+            effective_date=datetime(2022, 1, 15, 12, 30, 5),
+            snomed_concept_id=567,
+        ),
+    )
+    assert results == [
+        {
+            "patient_id": "1",
+            "date": date(2020, 10, 20),
+            "dmd_code": "123",
+        },
+        {
+            "patient_id": "2",
+            "date": date(2022, 1, 15),
+            "dmd_code": "567",
+        },
+    ]
+
+
+@register_test_for(emis.ons_deaths)
+def test_ons_deaths(select_all_emis):
+    results = select_all_emis(
+        PatientAllOrgsV2(registration_id="1", nhs_no="nhs1"),
+        PatientAllOrgsV2(registration_id="2", nhs_no="nhs2"),
+        PatientAllOrgsV2(registration_id="3", nhs_no="nhs3"),
+        # duplicate registration_id, patient omitted
+        PatientAllOrgsV2(registration_id="4", nhs_no="nhs4"),
+        PatientAllOrgsV2(registration_id="4", nhs_no="nhs4"),
+        OnsView(
+            pseudonhsnumber="nhs1",
+            upload_date="20230101",
+            reg_stat_dod="20220101",
+            icd10u="xyz",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        # older upload date, ignored
+        OnsView(
+            pseudonhsnumber="nhs1",
+            upload_date="20220101",
+            reg_stat_dod="20210101",
+            icd10u="wxy",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        # same patient, different date of death; earliest dod is selected
+        OnsView(
+            pseudonhsnumber="nhs2",
+            upload_date="20230101",
+            reg_stat_dod="20220101",
+            icd10u="xyz",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        OnsView(
+            pseudonhsnumber="nhs2",
+            upload_date="20230101",
+            reg_stat_dod="20220102",
+            icd10u="xyz",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        # same patient, same date of death; lexically smallest cause of death is selected
+        OnsView(
+            pseudonhsnumber="nhs3",
+            upload_date="20230101",
+            reg_stat_dod="20220101",
+            icd10u="abc",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        OnsView(
+            pseudonhsnumber="nhs3",
+            upload_date="20230101",
+            reg_stat_dod="20220101",
+            icd10u="xyz",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        # duplicate in patients table, excluded
+        OnsView(
+            pseudonhsnumber="nhs4",
+            upload_date="20230101",
+            reg_stat_dod="20220101",
+            icd10u="xyz",
+            icd10001="abc",
+            icd10002="def",
+        ),
+    )
+    assert results == [
+        {
+            "patient_id": "1",
+            "date": date(2022, 1, 1),
+            "underlying_cause_of_death": "xyz",
+            "cause_of_death_01": "abc",
+            "cause_of_death_02": "def",
+            "cause_of_death_03": None,
+            "cause_of_death_04": None,
+            "cause_of_death_05": None,
+            "cause_of_death_06": None,
+            "cause_of_death_07": None,
+            "cause_of_death_08": None,
+            "cause_of_death_09": None,
+            "cause_of_death_10": None,
+            "cause_of_death_11": None,
+            "cause_of_death_12": None,
+            "cause_of_death_13": None,
+            "cause_of_death_14": None,
+            "cause_of_death_15": None,
+        },
+        {
+            "patient_id": "2",
+            "date": date(2022, 1, 1),
+            "underlying_cause_of_death": "xyz",
+            "cause_of_death_01": "abc",
+            "cause_of_death_02": "def",
+            "cause_of_death_03": None,
+            "cause_of_death_04": None,
+            "cause_of_death_05": None,
+            "cause_of_death_06": None,
+            "cause_of_death_07": None,
+            "cause_of_death_08": None,
+            "cause_of_death_09": None,
+            "cause_of_death_10": None,
+            "cause_of_death_11": None,
+            "cause_of_death_12": None,
+            "cause_of_death_13": None,
+            "cause_of_death_14": None,
+            "cause_of_death_15": None,
+        },
+        {
+            "patient_id": "3",
+            "date": date(2022, 1, 1),
+            "underlying_cause_of_death": "abc",
+            "cause_of_death_01": "abc",
+            "cause_of_death_02": "def",
+            "cause_of_death_03": None,
+            "cause_of_death_04": None,
+            "cause_of_death_05": None,
+            "cause_of_death_06": None,
+            "cause_of_death_07": None,
+            "cause_of_death_08": None,
+            "cause_of_death_09": None,
+            "cause_of_death_10": None,
+            "cause_of_death_11": None,
+            "cause_of_death_12": None,
+            "cause_of_death_13": None,
+            "cause_of_death_14": None,
+            "cause_of_death_15": None,
+        },
+    ]
+
+
+@register_test_for(emis_raw.ons_deaths)
+def test_ons_deaths_raw(select_all_emis):
+    results = select_all_emis(
+        PatientAllOrgsV2(registration_id="1", nhs_no="nhs1"),
+        PatientAllOrgsV2(registration_id="2", nhs_no="nhs2"),
+        PatientAllOrgsV2(registration_id="3", nhs_no="nhs3"),
+        # duplicate registration_id, patient omitted
+        PatientAllOrgsV2(registration_id="4", nhs_no="nhs4"),
+        PatientAllOrgsV2(registration_id="4", nhs_no="nhs4"),
+        OnsView(
+            pseudonhsnumber="nhs1",
+            upload_date="20230101",
+            reg_stat_dod="20220101",
+            icd10u="xyz",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        # older upload date, ignored
+        OnsView(
+            pseudonhsnumber="nhs1",
+            upload_date="20220101",
+            reg_stat_dod="20210101",
+            icd10u="wxy",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        # same patient, different date of death; earliest dod is selected
+        OnsView(
+            pseudonhsnumber="nhs2",
+            upload_date="20230101",
+            reg_stat_dod="20220101",
+            icd10u="xyz",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        OnsView(
+            pseudonhsnumber="nhs2",
+            upload_date="20230101",
+            reg_stat_dod="20220102",
+            icd10u="xyz",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        # same patient, same date of death; lexically smallest cause of death is selected
+        OnsView(
+            pseudonhsnumber="nhs3",
+            upload_date="20230101",
+            reg_stat_dod="20220101",
+            icd10u="abc",
+            icd10001="abc",
+            icd10002="def",
+        ),
+        OnsView(
+            pseudonhsnumber="nhs3",
+            upload_date="20230101",
+            reg_stat_dod="20220101",
+            icd10u="xyz",
+            icd10001="abc",
+            icd10002="def",
+        ),
+    )
+
+    # results include duplicates, but still omit earlier uploads and duplicate
+    # registrations
+    results_summary = [(result["patient_id"], result["date"]) for result in results]
+    assert results_summary == [
+        ("1", date(2022, 1, 1)),
+        ("2", date(2022, 1, 1)),
+        ("2", date(2022, 1, 2)),
+        ("3", date(2022, 1, 1)),
+        ("3", date(2022, 1, 1)),
+    ]
+
+
+@register_test_for(emis.patients)
+def test_patients(select_all_emis):
+    results = select_all_emis(
+        PatientAllOrgsV2(
+            registration_id="1",
+            date_of_birth=date(2020, 1, 1),
+            gender=1,
+            hashed_organisation="1A2B3C",
+            registered_date=date(2021, 3, 1),
+            rural_urban=1,
+            imd_rank=500,
+        ),
+        # duplicate registration ids are ignored
+        PatientAllOrgsV2(
+            registration_id="2",
+            date_of_birth=date(2020, 1, 1),
+            gender=1,
+            hashed_organisation="1A2B3C",
+            registered_date=date(2021, 3, 1),
+        ),
+        PatientAllOrgsV2(
+            registration_id="2",
+            date_of_birth=date(2020, 1, 1),
+            gender=1,
+            hashed_organisation="1A2B3C",
+            registered_date=date(2021, 3, 1),
+        ),
+        PatientAllOrgsV2(
+            registration_id="3",
+            date_of_birth=date(1960, 1, 1),
+            date_of_death=date(2020, 1, 1),
+            gender=2,
+            hashed_organisation="1A2B3C",
+            registered_date=date(1960, 3, 1),
+        ),
+        PatientAllOrgsV2(
+            registration_id="4",
+            date_of_birth=date(2020, 1, 1),
+            gender=0,
+            hashed_organisation="1A2B3C",
+            registered_date=date(2021, 3, 1),
+        ),
+        PatientAllOrgsV2(
+            registration_id="5",
+            date_of_birth=date(1978, 10, 13),
+            gender=9,
+            hashed_organisation="1A2B3C",
+            registered_date=date(2021, 3, 1),
+        ),
+    )
+
+    expected = [
+        {
+            "patient_id": "1",
+            "date_of_birth": date(2020, 1, 1),
+            "sex": "male",
+            "date_of_death": None,
+            "registration_start_date": date(2021, 3, 1),
+            "registration_end_date": None,
+            "practice_pseudo_id": "1A2B3C",
+            "rural_urban_classification": 1,
+            "imd_rounded": 500,
+        },
+        {
+            "patient_id": "3",
+            "date_of_birth": date(1960, 1, 1),
+            "sex": "female",
+            "date_of_death": date(2020, 1, 1),
+            "registration_start_date": date(1960, 3, 1),
+            "registration_end_date": None,
+            "practice_pseudo_id": "1A2B3C",
+            "rural_urban_classification": None,
+            "imd_rounded": None,
+        },
+        {
+            "patient_id": "4",
+            "date_of_birth": date(2020, 1, 1),
+            "sex": "unknown",
+            "date_of_death": None,
+            "registration_start_date": date(2021, 3, 1),
+            "registration_end_date": None,
+            "practice_pseudo_id": "1A2B3C",
+            "rural_urban_classification": None,
+            "imd_rounded": None,
+        },
+        {
+            "patient_id": "5",
+            "date_of_birth": date(1978, 10, 13),
+            "sex": "unknown",
+            "date_of_death": None,
+            "registration_start_date": date(2021, 3, 1),
+            "registration_end_date": None,
+            "practice_pseudo_id": "1A2B3C",
+            "rural_urban_classification": None,
+            "imd_rounded": None,
+        },
+    ]
+    assert results == expected
+
+
+@register_test_for(emis.practice_registrations)
+def test_practice_registrations(select_all_emis):
+    results = select_all_emis(
+        PatientAllOrgsV2(
+            registration_id="1",
+            hashed_organisation="1f",
+            registered_date=date(2021, 3, 1),
+            registration_end_date=date(2022, 4, 2),
+        ),
+        PatientAllOrgsV2(
+            registration_id="1",
+            hashed_organisation="10A",
+            registered_date=date(2022, 4, 3),
+            registration_end_date=None,
+        ),
+        PatientAllOrgsV2(
+            registration_id="2",
+            hashed_organisation="123ABC",
+            registered_date=date(2000, 1, 1),
+            registration_end_date=date(2020, 1, 1),
+        ),
+    )
+
+    expected = [
+        {
+            "patient_id": "1",
+            "start_date": date(2021, 3, 1),
+            "end_date": date(2022, 4, 2),
+            # The core `practice_registrations` table defines `practice_pseudo_id` as an
+            # int, so we have to convert from hex strings to ints here
+            "practice_pseudo_id": 31,
+        },
+        {
+            "patient_id": "1",
+            "start_date": date(2022, 4, 3),
+            "end_date": None,
+            "practice_pseudo_id": 266,
+        },
+        {
+            "patient_id": "2",
+            "start_date": date(2000, 1, 1),
+            "end_date": date(2020, 1, 1),
+            "practice_pseudo_id": 1194684,
+        },
+    ]
+
+    # Trino doesn't return results in a stable order
+    def sort(lst):
+        return sorted(lst, key=lambda i: (i["patient_id"], i["practice_pseudo_id"]))
+
+    assert sort(results) == sort(expected)
+
+
+@register_test_for(emis.vaccinations)
+def test_vaccinations(select_all_emis):
+    results = select_all_emis(
+        PatientAllOrgsV2(registration_id="1"),
+        PatientAllOrgsV2(registration_id="2"),
+        PatientAllOrgsV2(registration_id="3"),
+        ImmunisationAllOrgsV2(
+            registration_id="1",
+            effective_date=datetime(2020, 10, 20, 14, 30, 5),
+            snomed_concept_id=123,
+        ),
+        ImmunisationAllOrgsV2(
+            registration_id="2",
+            effective_date=datetime(2021, 3, 23, 23, 30, 5),
+            snomed_concept_id=456,
+        ),
+        ImmunisationAllOrgsV2(
+            registration_id="2",
+            effective_date=datetime(2022, 1, 15, 12, 30, 5),
+            snomed_concept_id=567,
+        ),
+    )
+    assert results == [
+        {
+            "patient_id": "1",
+            "date": date(2020, 10, 20),
+            "procedure_code": "123",
+        },
+        {
+            "patient_id": "2",
+            "date": date(2021, 3, 23),
+            "procedure_code": "456",
+        },
+        {
+            "patient_id": "2",
+            "date": date(2022, 1, 15),
+            "procedure_code": "567",
+        },
+    ]
+
+
+def test_registered_tests_are_exhaustive():
+    assert_tests_exhaustive(EMISBackend())
+
+
+def test_generated_table_includes_organisation_hash(trino_database):
+    # This tests that EMIS's generated inline and temporary tables include a column
+    # "hashed_organisation", where every row's value is the value of the
+    # EMIS_ORGANISATION_HASH environment variable
+    ORG_HASH = "testing_123"
+
+    # Note that currently inline data tables always make patient_id an integer
+    # so in this test, our patient ids from the backend DB are coerced to ints
+    # In reality, this means inline tables won't be able to handle real EMIS
+    # data (where patient ids are strings) but this will be dealt with
+    # later
+    # https://github.com/opensafely-core/ehrql/issues/743
+    inline_data = [
+        (1, 100),
+        (2, 200),
+    ]
+
+    @table_from_rows(inline_data)
+    class t(PatientFrame):
+        n = Series(int)
+
+    dataset = create_dataset()
+    dataset.define_population(t.exists_for_patient())
+    dataset.n = t.n
+
+    backend = EMISBackend(config={"EMIS_ORGANISATION_HASH": ORG_HASH})
+    query_engine = backend.query_engine_class(
+        trino_database.host_url(),
+        backend=backend,
+    )
+
+    # Monkey patch on our own `execute_query_no_results` method which records the contents of
+    # generated tables
+    orig_execute_query = query_engine.execute_query_no_results
+    found_tables = {}
+
+    def execute_query_no_results(connection, query, *args, **kwargs):
+        # Before we drop any inline or temporary tables we grab the contents of their
+        # `hashed_organisation` column (which also serves as a test that they _have_
+        # such a column)
+        if match := re.search(r"DROP .+\b(\w+_(inline_data|tmp)_\w+)\b", str(query)):
+            table_name = match.group(1)
+            results = connection.execute(
+                sqlalchemy.text(f"SELECT hashed_organisation FROM {table_name}")
+            )
+            found_tables[match] = [row[0] for row in results]
+        orig_execute_query(connection, query, *args, **kwargs)
+
+    query_engine.execute_query_no_results = execute_query_no_results
+
+    # Consume the results to execute all queries
+    for table in query_engine.get_results_tables(dataset._compile()):
+        list(table)
+
+    for results in found_tables.values():
+        # Empty or single-row tables aren't really exercising the code properly so check
+        # we're not inadvertantly using those
+        assert len(results) > 1
+        # Assert that the organisation hash appears in every row
+        assert results == [ORG_HASH] * len(results)
+
+    # Check that we have examples of both the table types we're interested in
+    assert {match.group(2) for match in found_tables.keys()} == {"inline_data", "tmp"}