import datetime
import re

import pytest

from ehrql import Dataset


@pytest.fixture(params=["execute", "dump_sql"])
def spec_test(request, engine):
    """Return a callable which runs a single spec test against `engine`.

    Parametrized so each spec runs twice:
      * "execute"  — populate tables, run the query, check the results;
      * "dump_sql" — only check that SQL with literal parameters can be
        generated (used for debugging purposes).
    """

    # Test that we can insert the data, run the query, and get the expected results
    def run_test_execute(table_data, series, expected_results, population=None):
        # Populate database tables.
        engine.populate(
            {
                table: parse_table(table._qm_node.schema, s)
                for table, s in table_data.items()
            }
        )

        # Create a Dataset with the specified population and a single variable
        # which is the series under test.
        dataset = make_dataset(table_data, population)
        dataset.v = series

        # If we're comparing floats then we want only approximate equality to
        # account for rounding differences
        if series._type is float:
            expected_results = pytest.approx(expected_results, rel=1e-5)

        # Extract data, and check it's as expected.
        results = [(r["patient_id"], r["v"]) for r in engine.extract(dataset)]
        results_dict = dict(results)
        assert len(results) == len(results_dict), "Duplicate patient IDs found"
        assert results_dict == expected_results

        # Assert types are as expected
        for patient_id, value in results_dict.items():
            if value is not None:
                assert isinstance(value, series._type), (
                    f"Expected {series._type} got {type(value)} in "
                    f"result {{{patient_id}: {value}}}"
                )

    # Test that we can generate SQL with literal parameters for debugging purposes
    def run_test_dump_sql(table_data, series, expected_results, population=None):
        # Create a Dataset with the specified population and a single variable
        # which is the series under test.
        dataset = make_dataset(table_data, population)
        dataset.v = series

        # Check that we can generate SQL without error
        assert engine.dump_dataset_sql(dataset)

    mode = request.param

    if mode == "execute":
        return run_test_execute
    elif mode == "dump_sql":
        if engine.name == "in_memory":
            pytest.skip("in_memory engine produces no SQL")
        return run_test_dump_sql
    else:
        # Unreachable given the fixture's params; fail loudly if it ever isn't.
        assert False, f"Unknown mode: {mode}"


def make_dataset(table_data, population=None):
    """Build a Dataset with the given population.

    To reduce noise in the tests we provide a default population which contains
    all patients in any tables referenced in the data.
    """
    if population is None:
        population = False
        for table in table_data.keys():
            population = table.exists_for_patient() | population
    dataset = Dataset()
    dataset.define_population(population)
    return dataset


def parse_table(schema, s):
    """Parse string containing table data, returning list of dicts.

    See test_conftest.py for examples.
    """

    # First line is the column header; second is the separator rule, discarded.
    header, _, *lines = s.strip().splitlines()
    col_names = [token.strip() for token in header.split("|")]
    col_names[0] = "patient_id"
    column_types = dict(
        patient_id=int, **{name: type_ for name, type_ in schema.column_types}
    )
    rows = [parse_row(column_types, col_names, line) for line in lines]
    return rows


def parse_row(column_types, col_names, line):
    """Parse string containing row data, returning a dict mapping column
    names to parsed values.

    See test_conftest.py for examples.
    """

    # Regex splits on any '|' character, as long as it's not adjacent
    # to another '|' character using look-ahead and look-behind. This
    # is to allow '||' to appear as content within a field, currently
    # just for the all_diagnoses and all_procedures fields in apcs
    return {
        col_name: parse_value(column_types[col_name], token.strip())
        for col_name, token in zip(col_names, re.split(r"(?<!\|)\|(?!\|)", line))
    }


def parse_value(type_, value):
    """Parse string returning value of correct type for column.

    An empty string indicates a null value.
    """
    if not value:
        return None

    # ehrql column types may wrap a primitive; unwrap before parsing.
    if hasattr(type_, "_primitive_type"):
        type_ = type_._primitive_type()

    if type_ is bool:
        parse = lambda v: {"T": True, "F": False}[v]  # noqa E731
    elif type_ == datetime.date:
        parse = datetime.date.fromisoformat
    else:
        parse = type_

    return parse(value)