[e988c2]: / tests / spec / conftest.py

Download this file

130 lines (101 with data), 4.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import datetime
import re
import pytest
from ehrql import Dataset
@pytest.fixture(params=["execute", "dump_sql"])
def spec_test(request, engine):
    """Return the callable a spec test uses to check a series.

    The fixture is parametrised, so every test using it runs once per mode:

    * ``"execute"``: populate the engine with the table data, extract the
      dataset and compare the results with the expected ones.
    * ``"dump_sql"``: only check that SQL can be generated for the dataset.
      This mode is skipped when ``engine.name`` is ``"in_memory"``.

    Both callables share the signature
    ``(table_data, series, expected_results, population=None)`` where
    ``table_data`` maps table objects to the string form of their rows.

    Raises:
        AssertionError: if the fixture is parametrised with an unknown mode.
    """

    # Test that we can insert the data, run the query, and get the expected results
    def run_test_execute(table_data, series, expected_results, population=None):
        # Populate database tables.
        engine.populate(
            {
                table: parse_table(table._qm_node.schema, s)
                for table, s in table_data.items()
            }
        )

        # Create a Dataset with the specified population and a single variable which is
        # the series under test.
        dataset = make_dataset(table_data, population)
        dataset.v = series

        # If we're comparing floats then we want only approximate equality to account
        # for rounding differences
        if series._type is float:
            expected_results = pytest.approx(expected_results, rel=1e-5)

        # Extract data, and check it's as expected.
        results = [(r["patient_id"], r["v"]) for r in engine.extract(dataset)]
        results_dict = dict(results)
        # Building the dict collapses repeated keys, so a length mismatch means
        # the same patient was returned more than once.
        assert len(results) == len(results_dict), "Duplicate patient IDs found"
        assert results_dict == expected_results

        # Assert types are as expected
        for patient_id, value in results_dict.items():
            if value is not None:
                assert isinstance(value, series._type), (
                    f"Expected {series._type} got {type(value)} in "
                    f"result {{{patient_id}: {value}}}"
                )

    # Test that we can generate SQL with literal parameters for debugging purposes
    def run_test_dump_sql(table_data, series, expected_results, population=None):
        # `expected_results` is accepted only so that both runners share one
        # signature; nothing is executed in this mode so it is not consulted.

        # Create a Dataset with the specified population and a single variable which is
        # the series under test.
        dataset = make_dataset(table_data, population)
        dataset.v = series

        # Check that we can generate SQL without error
        assert engine.dump_dataset_sql(dataset)

    mode = request.param

    if mode == "execute":
        return run_test_execute
    elif mode == "dump_sql":
        if engine.name == "in_memory":
            pytest.skip("in_memory engine produces no SQL")
        return run_test_dump_sql
    else:
        # Raise explicitly rather than `assert False`: a bare assert is removed
        # under `python -O`, which would make the fixture silently return None.
        raise AssertionError(f"Unexpected spec_test mode: {mode!r}")
def make_dataset(table_data, population=None):
    """Create a Dataset with its population already defined.

    If ``population`` is not given, a default is built by OR-ing together
    ``exists_for_patient()`` for every table that is a key of ``table_data``,
    so that tests need not spell out a population containing all patients in
    any of the referenced tables.
    """
    if population is None:
        # Start from False and fold each table's existence check into it.
        population = False
        for table in table_data:
            population = table.exists_for_patient() | population

    result = Dataset()
    result.define_population(population)
    return result
def parse_table(schema, s):
    """Turn the string form of a table into a list of row dicts.

    The first line of ``s`` (after stripping) is a ``|``-separated header, the
    second line is discarded, and every remaining line is parsed as one row.
    The first column is always named ``patient_id`` and typed as ``int``,
    whatever its header text says; other column types come from
    ``schema.column_types``.

    See test_conftest.py for examples.

    Raises:
        ValueError: if ``s`` has fewer than two lines.
    """
    header, _ignored, *data_lines = s.strip().splitlines()

    # Drop whatever the first header cell says and call it patient_id instead.
    _first, *other_headers = header.split("|")
    col_names = ["patient_id", *(cell.strip() for cell in other_headers)]

    schema_types = {name: type_ for name, type_ in schema.column_types}
    column_types = dict(patient_id=int, **schema_types)

    parsed = []
    for data_line in data_lines:
        parsed.append(parse_row(column_types, col_names, data_line))
    return parsed
def parse_row(column_types, col_names, line):
    """Parse one line of row data into a dict keyed by column name.

    Each token is stripped of surrounding whitespace and converted using the
    type registered for its column in ``column_types``. Columns and tokens are
    paired positionally; any surplus on either side is ignored.

    See test_conftest.py for examples.
    """
    # Split on a single '|' only: the look-behind and look-ahead reject any
    # '|' that sits next to another '|', so '||' can appear as content within
    # a field (currently just for the all_diagnoses and all_procedures fields
    # in apcs).
    tokens = re.split(r"(?<!\|)\|(?!\|)", line)

    row = {}
    for col_name, token in zip(col_names, tokens):
        row[col_name] = parse_value(column_types[col_name], token.strip())
    return row
def parse_value(type_, value):
    """Convert the string ``value`` into the Python type for its column.

    An empty string indicates a null value and yields ``None``. If ``type_``
    has a ``_primitive_type`` attribute, the type it returns is used for the
    conversion instead. Booleans are written as ``"T"`` / ``"F"`` and dates in
    ISO format; every other type is called directly on the string.

    Raises:
        KeyError: for a boolean column whose value is neither "T" nor "F".
    """
    if not value:
        return None

    if hasattr(type_, "_primitive_type"):
        target = type_._primitive_type()
    else:
        target = type_

    if target is bool:
        return {"T": True, "F": False}[value]
    if target == datetime.date:
        return datetime.date.fromisoformat(value)
    return target(value)