--- a
+++ b/tests/functional/test_generate_dataset.py
@@ -0,0 +1,455 @@
+from datetime import date, datetime
+
+import pytest
+
+from ehrql.file_formats import FILE_FORMATS
+from ehrql.tables import core
+from tests.lib.file_utils import read_file_as_dicts
+from tests.lib.inspect_utils import function_body_as_string
+from tests.lib.tpp_schema import AllowedPatientsWithTypeOneDissent, Patient
+
+
+@function_body_as_string
+def trivial_dataset_definition():
+    from ehrql import create_dataset
+    from ehrql.tables.tpp import patients
+
+    dataset = create_dataset()
+    year = patients.date_of_birth.year
+    dataset.define_population(year >= 1940)
+    dataset.year = year
+
+    dataset.configure_dummy_data(
+        population_size=10,
+        additional_population_constraint=patients.date_of_death.is_null(),
+    )
+
+
+@function_body_as_string
+def trivial_dataset_definition_legacy_dummy_data():
+    from ehrql import create_dataset
+    from ehrql.tables.tpp import patients
+
+    dataset = create_dataset()
+    year = patients.date_of_birth.year
+    dataset.define_population(year >= 1940)
+    dataset.year = year
+
+    dataset.configure_dummy_data(population_size=10, legacy=True)
+
+
+@function_body_as_string
+def parameterised_dataset_definition():
+    from argparse import ArgumentParser
+
+    from ehrql import create_dataset
+    from ehrql.tables.tpp import patients
+
+    parser = ArgumentParser()
+    parser.add_argument("--year", type=int)
+    args = parser.parse_args()
+
+    dataset = create_dataset()
+    year = patients.date_of_birth.year
+    dataset.define_population(year >= args.year)
+    dataset.year = year
+
+
+@pytest.mark.parametrize("extension", list(FILE_FORMATS.keys()))
+def test_generate_dataset_with_tpp_backend(
+    call_cli, tmp_path, mssql_database, extension
+):
+    mssql_database.setup(
+        Patient(Patient_ID=1, DateOfBirth=datetime(1934, 5, 5)),
+        AllowedPatientsWithTypeOneDissent(Patient_ID=1),
+        Patient(Patient_ID=2, DateOfBirth=datetime(1943, 5, 5)),
+        AllowedPatientsWithTypeOneDissent(Patient_ID=2),
+        Patient(Patient_ID=3, DateOfBirth=datetime(1999, 5, 5)),
+        AllowedPatientsWithTypeOneDissent(Patient_ID=3),
+    )
+
+    output_path = tmp_path / f"results.{extension}"
+    dataset_definition_path = tmp_path / "dataset_definition.py"
+    dataset_definition_path.write_text(trivial_dataset_definition)
+
+    call_cli(
+        "generate-dataset",
+        dataset_definition_path,
+        "--output",
+        output_path,
+        "--backend",
+        "tpp",
+        "--dsn",
+        mssql_database.host_url(),
+    )
+    results = read_file_as_dicts(output_path)
+
+    expected = [1943, 1999]
+    if extension in (".csv", ".csv.gz"):
+        expected = [str(v) for v in expected]
+
+    assert len(results) == len(expected)
+    assert {r["year"] for r in results} == set(expected)
+
+
+def test_parameterised_dataset_definition(call_cli, tmp_path, mssql_database):
+    mssql_database.setup(
+        Patient(Patient_ID=1, DateOfBirth=datetime(1934, 5, 5)),
+        AllowedPatientsWithTypeOneDissent(Patient_ID=1),
+        Patient(Patient_ID=2, DateOfBirth=datetime(1943, 5, 5)),
+        AllowedPatientsWithTypeOneDissent(Patient_ID=2),
+        Patient(Patient_ID=3, DateOfBirth=datetime(1999, 5, 5)),
+        AllowedPatientsWithTypeOneDissent(Patient_ID=3),
+    )
+
+    output_path = tmp_path / "results.csv"
+    dataset_definition_path = tmp_path / "dataset_definition.py"
+    dataset_definition_path.write_text(parameterised_dataset_definition)
+
+    call_cli(
+        "generate-dataset",
+        dataset_definition_path,
+        "--output",
+        output_path,
+        "--backend",
+        "tpp",
+        "--dsn",
+        mssql_database.host_url(),
+        "--",
+        "--year",
+        "1940",
+    )
+    results = read_file_as_dicts(output_path)
+
+    expected = ["1943", "1999"]
= ["1943", "1999"] + + assert len(results) == len(expected) + assert {r["year"] for r in results} == set(expected) + + +def test_parameterised_dataset_definition_with_bad_param(tmp_path, call_cli): + dataset_definition_path = tmp_path / "dataset_definition.py" + dataset_definition_path.write_text(parameterised_dataset_definition) + + with pytest.raises(SystemExit): + call_cli( + "generate-dataset", + dataset_definition_path, + "--", + "--ear", + "1940", + ) + assert ( + "dataset_definition.py: error: unrecognized arguments: --ear 1940" + in call_cli.readouterr().err + ) + + +def test_generate_dataset_with_database_error(tmp_path, call_cli, mssql_database): + mssql_database.setup( + Patient(Patient_ID=1, DateOfBirth=datetime(1934, 5, 5)), + AllowedPatientsWithTypeOneDissent(Patient_ID=1), + ) + + @function_body_as_string + def database_operational_error_dataset_definition(): + from ehrql import create_dataset, years + from ehrql.tables.core import patients + + dataset = create_dataset() + dataset.define_population(patients.date_of_birth.year >= 1900) + dataset.extended_dob = patients.date_of_birth + years(9999) + + dataset.configure_dummy_data(population_size=10) + + # This dataset definition triggers an OperationalError by implementing date + # arithmetic that results in an out of bounds date (after 9999-12-31) + dataset_definition_path = tmp_path / "dataset_definition.py" + dataset_definition_path.write_text(database_operational_error_dataset_definition) + + with pytest.raises(SystemExit) as err: + call_cli( + "generate-dataset", + dataset_definition_path, + "--backend", + "tpp", + "--dsn", + mssql_database.host_url(), + ) + assert err.value.code == 5 + + +def test_validate_dummy_data_happy_path(tmp_path, call_cli): + dummy_data_file = tmp_path / "dummy.csv" + dummy_data = "patient_id,year\n1,1971\n2,1992" + dummy_data_file.write_text(dummy_data) + + output_path = tmp_path / "results.csv" + dataset_definition_path = tmp_path / "dataset_definition.py" + dataset_definition_path.write_text(trivial_dataset_definition) + + call_cli( + "generate-dataset", + dataset_definition_path, + "--output", + output_path, + "--dummy-data-file", + dummy_data_file, + ) + results = read_file_as_dicts(output_path) + + assert results == [ + {"patient_id": "1", "year": "1971"}, + {"patient_id": "2", "year": "1992"}, + ] + + +def test_validate_dummy_data_error_path(tmp_path, call_cli): + dummy_data_file = tmp_path / "dummy.csv" + dummy_data = "patient_id,year\n1,1971\n2,foo" + dummy_data_file.write_text(dummy_data) + + dataset_definition_path = tmp_path / "dataset_definition.py" + dataset_definition_path.write_text(trivial_dataset_definition) + + with pytest.raises(SystemExit): + call_cli( + "generate-dataset", + dataset_definition_path, + "--dummy-data-file", + dummy_data_file, + ) + assert "invalid literal for int" in call_cli.readouterr().err + + +@pytest.mark.parametrize( + "dataset_definition_fixture", + ( + trivial_dataset_definition, + trivial_dataset_definition_legacy_dummy_data, + ), +) +def test_generate_dummy_data(tmp_path, call_cli, dataset_definition_fixture): + output_path = tmp_path / "results.csv" + dataset_definition_path = tmp_path / "dataset_definition.py" + dataset_definition_path.write_text(dataset_definition_fixture) + + call_cli( + "generate-dataset", + dataset_definition_path, + "--output", + output_path, + ) + lines = output_path.read_text().splitlines() + + assert lines[0] == "patient_id,year" + assert len(lines) == 11 # 1 header, 10 rows + + +def 
+    dummy_tables_path = tmp_path / "dummy_tables"
+    dummy_tables_path.mkdir()
+    dummy_tables_path.joinpath("patients.csv").write_text(
+        "patient_id,date_of_birth\n8,1985-10-20\n9,1995-05-10"
+    )
+
+    output_path = tmp_path / "results.csv"
+
+    dataset_definition_path = tmp_path / "dataset_definition.py"
+    dataset_definition_path.write_text(trivial_dataset_definition)
+
+    call_cli(
+        "generate-dataset",
+        dataset_definition_path,
+        "--output",
+        output_path,
+        "--dummy-tables",
+        dummy_tables_path,
+    )
+    results = read_file_as_dicts(output_path)
+
+    assert results == [
+        {"patient_id": "8", "year": "1985"},
+        {"patient_id": "9", "year": "1995"},
+    ]
+
+
+def test_generate_dataset_disallows_reading_file_outside_working_directory(
+    tmp_path, monkeypatch, call_cli
+):
+    csv_file = tmp_path / "file.csv"
+    csv_file.write_text("patient_id,i\n1,10\n2,20")
+
+    @function_body_as_string
+    def code():
+        from ehrql import create_dataset
+        from ehrql.tables import PatientFrame, Series, table_from_file
+
+        @table_from_file("<CSV_FILE>")
+        class test_table(PatientFrame):
+            i = Series(int)
+
+        dataset = create_dataset()
+        dataset.define_population(test_table.exists_for_patient())
+        dataset.configure_dummy_data(population_size=2)
+        dataset.i = test_table.i
+
+    code = code.replace('"<CSV_FILE>"', repr(str(csv_file)))
+
+    dataset_file = tmp_path / "sub_dir" / "dataset_definition.py"
+    dataset_file.parent.mkdir(parents=True, exist_ok=True)
+    dataset_file.write_text(code)
+
+    monkeypatch.chdir(dataset_file.parent)
+    with pytest.raises(Exception) as e:
+        call_cli("generate-dataset", dataset_file)
+    assert "is not contained within the directory" in str(e.value)
+
+
+@pytest.mark.parametrize("legacy", [True, False])
+def test_generate_dataset_passes_dummy_data_config(call_cli, tmp_path, caplog, legacy):
+    @function_body_as_string
+    def code():
+        from ehrql import create_dataset
+        from ehrql.tables.core import patients
+
+        dataset = create_dataset()
+        dataset.define_population(patients.exists_for_patient())
+        dataset.sex = patients.sex
+
+        dataset.configure_dummy_data(population_size=2, timeout=3, **{})
+
+    code = code.replace("**{}", "legacy=True" if legacy else "")
+    dataset_file = tmp_path / "dataset_definition.py"
+    dataset_file.write_text(code)
+
+    call_cli(
+        "generate-dataset",
+        dataset_file,
+        "--output",
+        tmp_path / "output.csv",
+    )
+
+    logs = caplog.text
+    assert "Attempting to generate 2 matching patients" in logs
+    assert "timeout: 3s" in logs
+    if legacy:
+        assert "Using legacy dummy data generation" in logs
+    else:
+        assert "Using next generation dummy data generation" in logs
+
+
+def test_generate_dataset_with_test_data_file(call_cli, tmp_path):
+    @function_body_as_string
+    def dataset_definition_with_tests():
+        from ehrql import create_dataset
+        from ehrql.tables.core import patients
+
+        dataset = create_dataset()
+        dataset.define_population(patients.sex == "female")
+
+        test_data = {  # noqa: F841
+            1: {
+                "patients": {"sex": "male"},
+                "expected_in_population": False,
+            },
+            2: {
+                "patients": {"sex": "female"},
+                "expected_in_population": True,
+                "expected_columns": {},
+            },
+        }
+
+    test_data_file = tmp_path / "dataset_definition.py"
+    test_data_file.write_text(dataset_definition_with_tests)
+    output_file = tmp_path / "output.csv"
+
+    captured = call_cli(
+        "generate-dataset",
+        test_data_file,
+        "--output",
+        output_file,
+        "--test-data-file",
+        test_data_file,
+    )
+
+    # Check that the assurance tests were invoked
+    assert "All OK!" in captured.out
OK!" in captured.out + # Check we also generated some output + assert len(output_file.read_text()) > 0 + + +def test_generate_dataset_with_event_level_data(sqlite_engine, call_cli, tmp_path): + engine = sqlite_engine + extension = "csv" + + engine.populate( + { + core.patients: [ + {"patient_id": 1, "date_of_birth": date(1980, 1, 1)}, + {"patient_id": 2, "date_of_birth": date(1990, 1, 1)}, + {"patient_id": 3, "date_of_birth": date(2000, 1, 1)}, + ], + core.clinical_events: [ + {"patient_id": 1, "date": date(2020, 1, 1), "snomedct_code": "123456"}, + {"patient_id": 1, "date": date(2020, 2, 1), "snomedct_code": "123456"}, + {"patient_id": 1, "date": date(2020, 2, 1), "snomedct_code": "923456"}, + {"patient_id": 2, "date": date(2020, 3, 1), "snomedct_code": "123456"}, + {"patient_id": 2, "date": date(2020, 4, 1), "snomedct_code": "123457"}, + {"patient_id": 3, "date": date(2020, 5, 1), "snomedct_code": "123456"}, + {"patient_id": 3, "date": date(2020, 6, 1), "snomedct_code": "123457"}, + {"patient_id": 3, "date": date(2020, 7, 1), "snomedct_code": "123456"}, + ], + } + ) + + @function_body_as_string + def dataset_definition(): + from ehrql import create_dataset + from ehrql.tables.core import clinical_events, patients + + dataset = create_dataset() + dataset.define_population(patients.date_of_birth.year != 1990) + dataset.dob = patients.date_of_birth + events_1 = clinical_events.where( + clinical_events.snomedct_code.is_in(["123456", "123457"]) + ) + events_2 = clinical_events.where( + clinical_events.snomedct_code.is_in(["923456"]) + ) + dataset.add_event_table( + "events_1", date=events_1.date, code=events_1.snomedct_code + ) + dataset.add_event_table( + "events_2", date=events_2.date, code=events_2.snomedct_code + ) + + dataset_definition_path = tmp_path / "dataset_definition.py" + dataset_definition_path.write_text(dataset_definition) + output_path = tmp_path / "results" + + call_cli( + "generate-dataset", + dataset_definition_path, + "--output", + f"{output_path}:{extension}", + "--dsn", + engine.database.host_url(), + "--query-engine", + engine.name, + ) + + assert read_file_as_dicts(output_path / f"dataset.{extension}") == [ + {"patient_id": "1", "dob": "1980-01-01"}, + {"patient_id": "3", "dob": "2000-01-01"}, + ] + assert read_file_as_dicts(output_path / f"events_1.{extension}") == [ + {"patient_id": "1", "date": "2020-01-01", "code": "123456"}, + {"patient_id": "1", "date": "2020-02-01", "code": "123456"}, + {"patient_id": "3", "date": "2020-05-01", "code": "123456"}, + {"patient_id": "3", "date": "2020-06-01", "code": "123457"}, + {"patient_id": "3", "date": "2020-07-01", "code": "123456"}, + ] + assert read_file_as_dicts(output_path / f"events_2.{extension}") == [ + {"patient_id": "1", "date": "2020-02-01", "code": "923456"}, + ]