--- a +++ b/tests/unit/test_serializer.py @@ -0,0 +1,140 @@ +import datetime +import inspect +from pathlib import Path + +import pytest + +import ehrql.tables +from ehrql import INTERVAL, create_measures, years +from ehrql.file_formats import ( + FILE_FORMATS, + read_rows, + write_rows, +) +from ehrql.query_language import ( + DummyDataConfig, + create_dataset, + get_tables_from_namespace, +) +from ehrql.query_model.column_specs import ColumnSpec +from ehrql.serializer import SerializerError, deserialize, serialize +from ehrql.tables.core import clinical_events, patients +from ehrql.utils.module_utils import get_submodules + + +def as_query_model(query_lang_expr): + return query_lang_expr._qm_node + + +def define_measure(*args, **kwargs): + measures = create_measures() + measures.define_measure(*args, **kwargs) + return list(measures) + + +def get_all_tables(): + for module in get_submodules(ehrql.tables): + yield from ( + as_query_model(frame) for _, frame in get_tables_from_namespace(module) + ) + + +@pytest.mark.parametrize( + "value", + [ + # Primitive types + None, + True, + 5, + 0.5, + "foo", + datetime.date(2023, 10, 2), + # Container types + (1, 2, 3), + frozenset([1, 2, 3]), + {"foo": "bar"}, + # Dicts with non-string keys + {("foo", 1): ("bar", 2)}, + # Types + int, + datetime.date, + # Misc stuff + DummyDataConfig(population_size=10), + # Basic query model structures + as_query_model( + clinical_events.where( + clinical_events.date > patients.date_of_birth + years(10) + ) + .sort_by(clinical_events.date) + .first_for_patient() + .numeric_value + ), + # Basic measures + define_measure( + "test_measure", + numerator=clinical_events.where( + clinical_events.date.is_during(INTERVAL) + ).count_for_patient(), + denominator=patients.exists_for_patient(), + group_by=dict(sex=patients.sex), + intervals=years(3).starting_on("2020-01-01"), + ), + # Test that we can serialize every table in every schema + *get_all_tables(), + ], +) +def test_roundtrip(value): + assert value == deserialize(serialize(value), root_dir=Path.cwd()) + + +def test_dummy_data_config_roundtrip(): + dataset = create_dataset() + kwargs = dict( + population_size=100, + legacy=False, + timeout=100, + additional_population_constraint=(patients.date_of_birth < "2000-01-01"), + ) + dataset.configure_dummy_data(**kwargs) + config = dataset.dummy_data_config + assert config == deserialize(serialize(config), root_dir=Path.cwd()) + # Fail if we add new arguments to `configure_dummy_data` but don't exercise them + # here + assert set(kwargs) == set( + inspect.signature(dataset.configure_dummy_data).parameters + ) + + +# Fixture which generates a rows reader instance for every format we support +@pytest.fixture(params=list(FILE_FORMATS.keys())) +def rows_reader(request, tmp_path): + specs = { + "patient_id": ColumnSpec(int, nullable=False), + "b": ColumnSpec(bool), + "i": ColumnSpec(int, min_value=10, max_value=20), + "c": ColumnSpec(str, categories=("A", "B")), + } + + data = [ + (123, True, 10, "A"), + (456, None, 15, "B"), + (789, False, 20, "A"), + ] + extension = request.param + filename = tmp_path / f"some_file{extension}" + write_rows(filename, data, specs) + yield read_rows(filename, specs) + + +def test_roundtrip_rows_reader(rows_reader): + parent_dir = rows_reader.filename.parent + roundtripped = deserialize(serialize(rows_reader), root_dir=parent_dir) + assert roundtripped is not rows_reader + assert roundtripped == rows_reader + assert list(roundtripped) == list(rows_reader) + + +def test_rows_reader_cannot_be_deserialized_outside_of_root_dir(rows_reader): + serialized = serialize(rows_reader) + with pytest.raises(SerializerError, match="is not contained within the directory"): + deserialize(serialized, root_dir=Path("/some/path"))