import re
from datetime import date, datetime

import sqlalchemy

from ehrql import create_dataset
from ehrql.backends.emis import EMISBackend
from ehrql.tables import PatientFrame, Series, emis, table_from_rows
from ehrql.tables.raw import emis as emis_raw
from ehrql.utils.sqlalchemy_query_utils import CreateTableAs, GeneratedTable
from tests.lib.emis_schema import (
    ImmunisationAllOrgsV2,
    MedicationAllOrgsV2,
    ObservationAllOrgsV2,
    OnsView,
    PatientAllOrgsV2,
)

from .helpers import (
    assert_tests_exhaustive,
    assert_types_correct,
    get_all_backend_columns,
    register_test_for,
)


def test_backend_columns_have_correct_types(trino_database):
    """Check the backend's declared column types against what the database reports."""
    columns_with_types = get_all_backend_columns_with_types(trino_database)
    assert_types_correct(columns_with_types, trino_database)


def get_all_backend_columns_with_types(trino_database):
    """
    For every column on every table we expose in the backend, yield the SQLAlchemy type
    instance we expect to use for that column together with the type information that
    database has for that column so we can check they're compatible

    Yields `(table_name, column_name, column_type, column_args)` tuples, where
    `column_args` is `{"type": <data_type reported by information_schema>}`.

    This is a generator: no database work happens until it is iterated.
    """
    column_types = {}
    # One entry per backend table:
    #   (backend table name, temp table name, temp table, CREATE TABLE AS query)
    temp_tables = []
    for table, columns in get_all_backend_columns(EMISBackend()):
        column_types.update({(table, c.key): c.type for c in columns})
        # Construct a query which selects every column in the table
        select_query = sqlalchemy.select(*[c.label(c.key) for c in columns])
        # Write the results of that query into a temporary table (it will be empty but
        # that's fine, we just want the types)
        # Trino doesn't support actual temporary tables, so really this temp table is
        # a real table that we drop after the test
        temp_table_name = f"temp_{table}"
        temp_table = GeneratedTable.from_query(temp_table_name, select_query)
        # Keep the backend table name alongside the temp table rather than trying to
        # recover it from the temp table's name afterwards
        temp_tables.append(
            (
                table,
                temp_table_name,
                temp_table,
                CreateTableAs(temp_table, select_query),
            )
        )

    # Create all the underlying tables in the database without populating them
    trino_database.setup(metadata=PatientAllOrgsV2.metadata)

    # Get the column names and types for all columns in a given table
    column_info_query = sqlalchemy.text(
        """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_name=:t
        """
    )

    with trino_database.engine().connect() as connection:
        for table_name, temp_table_name, temp_table, create_query in temp_tables:
            # Create our "temporary" table
            connection.execute(create_query)
            try:
                results = list(
                    connection.execute(column_info_query, {"t": temp_table_name})
                )
                for column, type_name in results:
                    column_type = column_types[table_name, column]
                    column_args = {"type": type_name}
                    yield table_name, column, column_type, column_args
            finally:
                # Drop the temp table. As it's really a permanent table we do this
                # even if something above fails or the caller stops iterating early,
                # so that we don't leave it behind in the database
                temp_table.drop(trino_database.engine())


@register_test_for(emis.clinical_events)
def test_clinical_events(select_all_emis):
    """Observations come back with a date, a string SNOMED code and a numeric value."""
    results = select_all_emis(
        PatientAllOrgsV2(registration_id="1"),
        PatientAllOrgsV2(registration_id="2"),
        ObservationAllOrgsV2(
            registration_id="1",
            effective_date=datetime(2020, 10, 20, 14, 30, 5),
            snomed_concept_id=123,
            value_pq_1=0.5,
        ),
        ObservationAllOrgsV2(
            registration_id="2",
            effective_date=datetime(2022, 1, 15, 12, 30, 5),
            snomed_concept_id=567,
            value_pq_1=None,
        ),
    )
    assert results == [
        {
            "patient_id": "1",
            "date": date(2020, 10, 20),
            "snomedct_code": "123",
            "numeric_value": 0.5,
        },
        {
            "patient_id": "2",
            "date": date(2022, 1, 15),
            "snomedct_code": "567",
            "numeric_value": None,
        },
    ]


@register_test_for(emis.medications)
def test_medications(select_all_emis):
    """Medications come back with a date and the concept id as a string dm+d code."""
    results = select_all_emis(
        PatientAllOrgsV2(registration_id="1"),
        PatientAllOrgsV2(registration_id="2"),
        MedicationAllOrgsV2(
            registration_id="1",
            effective_date=datetime(2020, 10, 20, 14, 30, 5),
            snomed_concept_id=123,
        ),
        MedicationAllOrgsV2(
            registration_id="2",
            effective_date=datetime(2022, 1, 15, 12, 30, 5),
            snomed_concept_id=567,
        ),
    )
    assert results == [
        {
            "patient_id": "1",
            "date": date(2020, 10, 20),
            "dmd_code": "123",
        },
        {
            "patient_id": "2",
            "date": date(2022, 1, 15),
            "dmd_code": "567",
        },
    ]


@register_test_for(emis.ons_deaths)
def test_ons_deaths(select_all_emis):
    """
    Each patient gets at most one ONS death record: only the latest upload is used,
    then the earliest date of death, then the lexically smallest underlying cause.
    Patients with duplicate registrations are left out entirely.
    """
    results = select_all_emis(
        PatientAllOrgsV2(registration_id="1", nhs_no="nhs1"),
        PatientAllOrgsV2(registration_id="2", nhs_no="nhs2"),
        PatientAllOrgsV2(registration_id="3", nhs_no="nhs3"),
        # duplicate registration_id, patient omitted
        PatientAllOrgsV2(registration_id="4", nhs_no="nhs4"),
        PatientAllOrgsV2(registration_id="4", nhs_no="nhs4"),
        OnsView(
            pseudonhsnumber="nhs1",
            upload_date="20230101",
            reg_stat_dod="20220101",
            icd10u="xyz",
            icd10001="abc",
            icd10002="def",
        ),
        # older upload date, ignored
        OnsView(
            pseudonhsnumber="nhs1",
            upload_date="20220101",
            reg_stat_dod="20210101",
            icd10u="wxy",
            icd10001="abc",
            icd10002="def",
        ),
        # same patient, different date of death; earliest dod is selected
        OnsView(
            pseudonhsnumber="nhs2",
            upload_date="20230101",
            reg_stat_dod="20220101",
            icd10u="xyz",
            icd10001="abc",
            icd10002="def",
        ),
        OnsView(
            pseudonhsnumber="nhs2",
            upload_date="20230101",
            reg_stat_dod="20220102",
            icd10u="xyz",
            icd10001="abc",
            icd10002="def",
        ),
        # same patient, same date of death; lexically smallest cause of death is selected
        OnsView(
            pseudonhsnumber="nhs3",
            upload_date="20230101",
            reg_stat_dod="20220101",
            icd10u="abc",
            icd10001="abc",
            icd10002="def",
        ),
        OnsView(
            pseudonhsnumber="nhs3",
            upload_date="20230101",
            reg_stat_dod="20220101",
            icd10u="xyz",
            icd10001="abc",
            icd10002="def",
        ),
        # duplicate in patients table, excluded
        OnsView(
            pseudonhsnumber="nhs4",
            upload_date="20230101",
            reg_stat_dod="20220101",
            icd10u="xyz",
            icd10001="abc",
            icd10002="def",
        ),
    )
    assert results == [
        {
            "patient_id": "1",
            "date": date(2022, 1, 1),
            "underlying_cause_of_death": "xyz",
            "cause_of_death_01": "abc",
            "cause_of_death_02": "def",
            "cause_of_death_03": None,
            "cause_of_death_04": None,
            "cause_of_death_05": None,
            "cause_of_death_06": None,
            "cause_of_death_07": None,
            "cause_of_death_08": None,
            "cause_of_death_09": None,
            "cause_of_death_10": None,
            "cause_of_death_11": None,
            "cause_of_death_12": None,
            "cause_of_death_13": None,
            "cause_of_death_14": None,
            "cause_of_death_15": None,
        },
        {
            "patient_id": "2",
            "date": date(2022, 1, 1),
            "underlying_cause_of_death": "xyz",
            "cause_of_death_01": "abc",
            "cause_of_death_02": "def",
            "cause_of_death_03": None,
            "cause_of_death_04": None,
            "cause_of_death_05": None,
            "cause_of_death_06": None,
            "cause_of_death_07": None,
            "cause_of_death_08": None,
            "cause_of_death_09": None,
            "cause_of_death_10": None,
            "cause_of_death_11": None,
            "cause_of_death_12": None,
            "cause_of_death_13": None,
            "cause_of_death_14": None,
            "cause_of_death_15": None,
        },
        {
            "patient_id": "3",
            "date": date(2022, 1, 1),
            "underlying_cause_of_death": "abc",
            "cause_of_death_01": "abc",
            "cause_of_death_02": "def",
            "cause_of_death_03": None,
            "cause_of_death_04": None,
            "cause_of_death_05": None,
            "cause_of_death_06": None,
            "cause_of_death_07": None,
            "cause_of_death_08": None,
            "cause_of_death_09": None,
            "cause_of_death_10": None,
            "cause_of_death_11": None,
            "cause_of_death_12": None,
            "cause_of_death_13": None,
            "cause_of_death_14": None,
            "cause_of_death_15": None,
        },
    ]


@register_test_for(emis_raw.ons_deaths)
def test_ons_deaths_raw(select_all_emis):
    """
    The raw ONS deaths table keeps multiple records per patient, but still only uses
    the latest upload and still leaves out patients with duplicate registrations.
    """
    results = select_all_emis(
        PatientAllOrgsV2(registration_id="1", nhs_no="nhs1"),
        PatientAllOrgsV2(registration_id="2", nhs_no="nhs2"),
        PatientAllOrgsV2(registration_id="3", nhs_no="nhs3"),
        # duplicate registration_id, patient omitted
        PatientAllOrgsV2(registration_id="4", nhs_no="nhs4"),
        PatientAllOrgsV2(registration_id="4", nhs_no="nhs4"),
        OnsView(
            pseudonhsnumber="nhs1",
            upload_date="20230101",
            reg_stat_dod="20220101",
            icd10u="xyz",
            icd10001="abc",
            icd10002="def",
        ),
        # older upload date, ignored
        OnsView(
            pseudonhsnumber="nhs1",
            upload_date="20220101",
            reg_stat_dod="20210101",
            icd10u="wxy",
            icd10001="abc",
            icd10002="def",
        ),
        # same patient, different date of death; earliest dod is selected
        OnsView(
            pseudonhsnumber="nhs2",
            upload_date="20230101",
            reg_stat_dod="20220101",
            icd10u="xyz",
            icd10001="abc",
            icd10002="def",
        ),
        OnsView(
            pseudonhsnumber="nhs2",
            upload_date="20230101",
            reg_stat_dod="20220102",
            icd10u="xyz",
            icd10001="abc",
            icd10002="def",
        ),
        # same patient, same date of death; lexically smallest cause of death is selected
        OnsView(
            pseudonhsnumber="nhs3",
            upload_date="20230101",
            reg_stat_dod="20220101",
            icd10u="abc",
            icd10001="abc",
            icd10002="def",
        ),
        OnsView(
            pseudonhsnumber="nhs3",
            upload_date="20230101",
            reg_stat_dod="20220101",
            icd10u="xyz",
            icd10001="abc",
            icd10002="def",
        ),
    )

    # results include duplicates, but still omit earlier uploads and duplicate
    # registrations
    results_summary = [(result["patient_id"], result["date"]) for result in results]
    assert results_summary == [
        ("1", date(2022, 1, 1)),
        ("2", date(2022, 1, 1)),
        ("2", date(2022, 1, 2)),
        ("3", date(2022, 1, 1)),
        ("3", date(2022, 1, 1)),
    ]


@register_test_for(emis.patients)
def test_patients(select_all_emis):
    """
    Patients get one row per unique registration id, with the numeric gender mapped
    to "male" (1), "female" (2) or "unknown" (0 and 9 here).
    """
    results = select_all_emis(
        PatientAllOrgsV2(
            registration_id="1",
            date_of_birth=date(2020, 1, 1),
            gender=1,
            hashed_organisation="1A2B3C",
            registered_date=date(2021, 3, 1),
            rural_urban=1,
            imd_rank=500,
        ),
        # duplicate registration ids are ignored
        PatientAllOrgsV2(
            registration_id="2",
            date_of_birth=date(2020, 1, 1),
            gender=1,
            hashed_organisation="1A2B3C",
            registered_date=date(2021, 3, 1),
        ),
        PatientAllOrgsV2(
            registration_id="2",
            date_of_birth=date(2020, 1, 1),
            gender=1,
            hashed_organisation="1A2B3C",
            registered_date=date(2021, 3, 1),
        ),
        PatientAllOrgsV2(
            registration_id="3",
            date_of_birth=date(1960, 1, 1),
            date_of_death=date(2020, 1, 1),
            gender=2,
            hashed_organisation="1A2B3C",
            registered_date=date(1960, 3, 1),
        ),
        PatientAllOrgsV2(
            registration_id="4",
            date_of_birth=date(2020, 1, 1),
            gender=0,
            hashed_organisation="1A2B3C",
            registered_date=date(2021, 3, 1),
        ),
        PatientAllOrgsV2(
            registration_id="5",
            date_of_birth=date(1978, 10, 13),
            gender=9,
            hashed_organisation="1A2B3C",
            registered_date=date(2021, 3, 1),
        ),
    )

    expected = [
        {
            "patient_id": "1",
            "date_of_birth": date(2020, 1, 1),
            "sex": "male",
            "date_of_death": None,
            "registration_start_date": date(2021, 3, 1),
            "registration_end_date": None,
            "practice_pseudo_id": "1A2B3C",
            "rural_urban_classification": 1,
            "imd_rounded": 500,
        },
        {
            "patient_id": "3",
            "date_of_birth": date(1960, 1, 1),
            "sex": "female",
            "date_of_death": date(2020, 1, 1),
            "registration_start_date": date(1960, 3, 1),
            "registration_end_date": None,
            "practice_pseudo_id": "1A2B3C",
            "rural_urban_classification": None,
            "imd_rounded": None,
        },
        {
            "patient_id": "4",
            "date_of_birth": date(2020, 1, 1),
            "sex": "unknown",
            "date_of_death": None,
            "registration_start_date": date(2021, 3, 1),
            "registration_end_date": None,
            "practice_pseudo_id": "1A2B3C",
            "rural_urban_classification": None,
            "imd_rounded": None,
        },
        {
            "patient_id": "5",
            "date_of_birth": date(1978, 10, 13),
            "sex": "unknown",
            "date_of_death": None,
            "registration_start_date": date(2021, 3, 1),
            "registration_end_date": None,
            "practice_pseudo_id": "1A2B3C",
            "rural_urban_classification": None,
            "imd_rounded": None,
        },
    ]
    assert results == expected


@register_test_for(emis.practice_registrations)
def test_practice_registrations(select_all_emis):
    """
    Each patient row becomes one registration, with the hex organisation hash
    returned as an integer `practice_pseudo_id`.
    """
    results = select_all_emis(
        PatientAllOrgsV2(
            registration_id="1",
            hashed_organisation="1f",
            registered_date=date(2021, 3, 1),
            registration_end_date=date(2022, 4, 2),
        ),
        PatientAllOrgsV2(
            registration_id="1",
            hashed_organisation="10A",
            registered_date=date(2022, 4, 3),
            registration_end_date=None,
        ),
        PatientAllOrgsV2(
            registration_id="2",
            hashed_organisation="123ABC",
            registered_date=date(2000, 1, 1),
            registration_end_date=date(2020, 1, 1),
        ),
    )

    expected = [
        {
            "patient_id": "1",
            "start_date": date(2021, 3, 1),
            "end_date": date(2022, 4, 2),
            # The core `practice_registrations` table defines `practice_pseudo_id` as an
            # int, so we have to convert from hex strings to ints here
            "practice_pseudo_id": 31,
        },
        {
            "patient_id": "1",
            "start_date": date(2022, 4, 3),
            "end_date": None,
            "practice_pseudo_id": 266,
        },
        {
            "patient_id": "2",
            "start_date": date(2000, 1, 1),
            "end_date": date(2020, 1, 1),
            "practice_pseudo_id": 1194684,
        },
    ]

    # Trino doesn't return results in a stable order
    def sort(lst):
        return sorted(lst, key=lambda i: (i["patient_id"], i["practice_pseudo_id"]))

    assert sort(results) == sort(expected)


@register_test_for(emis.vaccinations)
def test_vaccinations(select_all_emis):
    """Immunisations come back with a date and the concept id as a string code."""
    results = select_all_emis(
        PatientAllOrgsV2(registration_id="1"),
        PatientAllOrgsV2(registration_id="2"),
        PatientAllOrgsV2(registration_id="3"),
        ImmunisationAllOrgsV2(
            registration_id="1",
            effective_date=datetime(2020, 10, 20, 14, 30, 5),
            snomed_concept_id=123,
        ),
        ImmunisationAllOrgsV2(
            registration_id="2",
            effective_date=datetime(2021, 3, 23, 23, 30, 5),
            snomed_concept_id=456,
        ),
        ImmunisationAllOrgsV2(
            registration_id="2",
            effective_date=datetime(2022, 1, 15, 12, 30, 5),
            snomed_concept_id=567,
        ),
    )
    assert results == [
        {
            "patient_id": "1",
            "date": date(2020, 10, 20),
            "procedure_code": "123",
        },
        {
            "patient_id": "2",
            "date": date(2021, 3, 23),
            "procedure_code": "456",
        },
        {
            "patient_id": "2",
            "date": date(2022, 1, 15),
            "procedure_code": "567",
        },
    ]


def test_registered_tests_are_exhaustive():
    """Run the shared exhaustiveness check over the tests registered in this module."""
    assert_tests_exhaustive(EMISBackend())


def test_generated_table_includes_organisation_hash(trino_database):
    # This tests that EMIS's generated inline and temporary tables include a column
    # "hashed_organisation", where every row's value is the value of the
    # EMIS_ORGANISATION_HASH environment variable
    ORG_HASH = "testing_123"

    # Note that currently inline data tables always make patient_id an integer
    # so in this test, our patient ids from the backend DB are coerced to ints
    # In reality, this means inline tables won't be able to handle real EMIS
    # data (where patient ids are strings) but this will be dealt with
    # later
    # https://github.com/opensafely-core/ehrql/issues/743
    inline_data = [
        (1, 100),
        (2, 200),
    ]

    @table_from_rows(inline_data)
    class t(PatientFrame):
        n = Series(int)

    dataset = create_dataset()
    dataset.define_population(t.exists_for_patient())
    dataset.n = t.n

    backend = EMISBackend(config={"EMIS_ORGANISATION_HASH": ORG_HASH})
    query_engine = backend.query_engine_class(
        trino_database.host_url(),
        backend=backend,
    )

    # Monkey patch on our own `execute_query_no_results` method which records the
    # contents of generated tables
    orig_execute_query = query_engine.execute_query_no_results
    # Keyed by the regex match for each DROP statement (rather than by table name) so
    # that the table type captured in group 2 is still available for the final check
    found_tables = {}

    def execute_query_no_results(connection, query, *args, **kwargs):
        # Before we drop any inline or temporary tables we grab the contents of their
        # `hashed_organisation` column (which also serves as a test that they _have_
        # such a column)
        if match := re.search(r"DROP .+\b(\w+_(inline_data|tmp)_\w+)\b", str(query)):
            table_name = match.group(1)
            results = connection.execute(
                sqlalchemy.text(f"SELECT hashed_organisation FROM {table_name}")
            )
            found_tables[match] = [row[0] for row in results]
        # Pass through whatever the real method returns so the wrapper stays
        # transparent to the query engine
        return orig_execute_query(connection, query, *args, **kwargs)

    query_engine.execute_query_no_results = execute_query_no_results

    # Consume the results to execute all queries
    for table in query_engine.get_results_tables(dataset._compile()):
        list(table)

    for results in found_tables.values():
        # Empty or single-row tables aren't really exercising the code properly so check
        # we're not inadvertently using those
        assert len(results) > 1
        # Assert that the organisation hash appears in every row
        assert results == [ORG_HASH] * len(results)

    # Check that we have examples of both the table types we're interested in
    assert {match.group(2) for match in found_tables} == {"inline_data", "tmp"}