[e988c2]: / tests / unit / test_serializer.py

Download this file

141 lines (121 with data), 4.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import datetime
import inspect
from pathlib import Path
import pytest
import ehrql.tables
from ehrql import INTERVAL, create_measures, years
from ehrql.file_formats import (
FILE_FORMATS,
read_rows,
write_rows,
)
from ehrql.query_language import (
DummyDataConfig,
create_dataset,
get_tables_from_namespace,
)
from ehrql.query_model.column_specs import ColumnSpec
from ehrql.serializer import SerializerError, deserialize, serialize
from ehrql.tables.core import clinical_events, patients
from ehrql.utils.module_utils import get_submodules
def as_query_model(query_lang_expr):
    """Return the object stored on the ``_qm_node`` attribute of *query_lang_expr*."""
    qm_node = query_lang_expr._qm_node
    return qm_node
def define_measure(*args, **kwargs):
    """Define one measure on a fresh measures collection and return its items.

    All positional and keyword arguments are forwarded unchanged to the
    ``define_measure`` method of the object returned by ``create_measures()``.
    The collection is then iterated and its contents returned as a list.
    """
    collection = create_measures()
    collection.define_measure(*args, **kwargs)
    return [*collection]
def get_all_tables():
    """Yield ``as_query_model(frame)`` for every table in every ``ehrql.tables`` submodule.

    Submodules come from ``get_submodules(ehrql.tables)`` and the tables in
    each one from ``get_tables_from_namespace``; the first element of each
    pair it produces is ignored.
    """
    for module in get_submodules(ehrql.tables):
        for _, frame in get_tables_from_namespace(module):
            yield as_query_model(frame)
@pytest.mark.parametrize(
    "value",
    [
        # Primitive types
        None,
        True,
        5,
        0.5,
        "foo",
        datetime.date(2023, 10, 2),
        # Container types
        (1, 2, 3),
        frozenset({1, 2, 3}),
        {"foo": "bar"},
        # Dicts with non-string keys
        {("foo", 1): ("bar", 2)},
        # Types
        int,
        datetime.date,
        # Misc stuff
        DummyDataConfig(population_size=10),
        # Basic query model structures
        as_query_model(
            clinical_events.where(
                clinical_events.date > patients.date_of_birth + years(10)
            )
            .sort_by(clinical_events.date)
            .first_for_patient()
            .numeric_value
        ),
        # Basic measures
        define_measure(
            "test_measure",
            numerator=clinical_events.where(
                clinical_events.date.is_during(INTERVAL)
            ).count_for_patient(),
            denominator=patients.exists_for_patient(),
            group_by={"sex": patients.sex},
            intervals=years(3).starting_on("2020-01-01"),
        ),
        # Test that we can serialize every table in every schema
        *get_all_tables(),
    ],
)
def test_roundtrip(value):
    """Each parametrised value must compare equal to its serialize/deserialize roundtrip.

    Deserialization is given the current working directory as ``root_dir``.
    """
    serialized = serialize(value)
    restored = deserialize(serialized, root_dir=Path.cwd())
    assert value == restored
def test_dummy_data_config_roundtrip():
    """A dataset's dummy data config must survive a serialize/deserialize roundtrip.

    The config is built by calling ``configure_dummy_data`` with the keyword
    arguments below. The test also asserts that those keywords match the
    method's signature exactly.
    """
    dataset = create_dataset()
    kwargs = {
        "population_size": 100,
        "legacy": False,
        "timeout": 100,
        "additional_population_constraint": (patients.date_of_birth < "2000-01-01"),
    }
    dataset.configure_dummy_data(**kwargs)
    config = dataset.dummy_data_config
    restored = deserialize(serialize(config), root_dir=Path.cwd())
    assert config == restored

    # Guard against new `configure_dummy_data` arguments being added without
    # being exercised by this test.
    accepted_names = set(inspect.signature(dataset.configure_dummy_data).parameters)
    assert set(kwargs) == accepted_names
# Fixture producing one rows reader instance per supported file format
@pytest.fixture(params=[*FILE_FORMATS.keys()])
def rows_reader(request, tmp_path):
    """Write a small table to ``tmp_path`` and yield the result of reading it back.

    The fixture is parametrised over the keys of ``FILE_FORMATS``; each key is
    appended to ``some_file`` to form the file name.
    """
    column_specs = {
        "patient_id": ColumnSpec(int, nullable=False),
        "b": ColumnSpec(bool),
        "i": ColumnSpec(int, min_value=10, max_value=20),
        "c": ColumnSpec(str, categories=("A", "B")),
    }
    rows = [
        (123, True, 10, "A"),
        (456, None, 15, "B"),
        (789, False, 20, "A"),
    ]
    path = tmp_path / f"some_file{request.param}"
    write_rows(path, rows, column_specs)
    yield read_rows(path, column_specs)
def test_roundtrip_rows_reader(rows_reader):
    """A rows reader roundtrips to a distinct but equal object with equal rows.

    The directory containing the reader's file is used as ``root_dir``.
    """
    root = rows_reader.filename.parent
    payload = serialize(rows_reader)
    restored = deserialize(payload, root_dir=root)
    # Must be a new object, not the same instance handed back.
    assert restored is not rows_reader
    assert restored == rows_reader
    assert list(restored) == list(rows_reader)
def test_rows_reader_cannot_be_deserialized_outside_of_root_dir(rows_reader):
    """Deserializing a rows reader with an unrelated ``root_dir`` must raise ``SerializerError``."""
    payload = serialize(rows_reader)
    unrelated_root = Path("/some/path")
    expect_error = pytest.raises(
        SerializerError, match="is not contained within the directory"
    )
    with expect_error:
        deserialize(payload, root_dir=unrelated_root)