Diff of /tests/spec/conftest.py [000000] .. [e988c2]

Switch to unified view

a b/tests/spec/conftest.py
1
import datetime
2
import re
3
4
import pytest
5
6
from ehrql import Dataset
7
8
9
@pytest.fixture(params=["execute", "dump_sql"])
def spec_test(request, engine):
    """Parametrized fixture returning a spec-test runner.

    In "execute" mode the runner loads the table data, evaluates the series
    against the engine, and compares results; in "dump_sql" mode it only
    checks that SQL can be generated for the query.
    """

    def run_test_execute(table_data, series, expected_results, population=None):
        # Load the supplied rows into the engine's database tables.
        engine.populate(
            {
                table: parse_table(table._qm_node.schema, rows)
                for table, rows in table_data.items()
            }
        )

        # Build a dataset with the requested population and the series under
        # test as its single variable.
        dataset = make_dataset(table_data, population)
        dataset.v = series

        # Floats get only approximate comparison, to absorb rounding
        # differences between engines.
        if series._type is float:
            expected_results = pytest.approx(expected_results, rel=1e-5)

        # Extract the data and check it matches expectations.
        extracted = [(row["patient_id"], row["v"]) for row in engine.extract(dataset)]
        as_dict = dict(extracted)
        assert len(extracted) == len(as_dict), "Duplicate patient IDs found"
        assert as_dict == expected_results

        # Every non-null value must have the series' declared type.
        for patient_id, value in as_dict.items():
            if value is None:
                continue
            assert isinstance(value, series._type), (
                f"Expected {series._type} got {type(value)} in "
                f"result {{{patient_id}: {value}}}"
            )

    def run_test_dump_sql(table_data, series, expected_results, population=None):
        # Only checks that SQL generation (with literal parameters, for
        # debugging) succeeds; nothing is executed against a database.
        dataset = make_dataset(table_data, population)
        dataset.v = series
        assert engine.dump_dataset_sql(dataset)

    if request.param == "execute":
        return run_test_execute
    if request.param == "dump_sql":
        if engine.name == "in_memory":
            pytest.skip("in_memory engine produces no SQL")
        return run_test_dump_sql
    assert False
65
66
67
def make_dataset(table_data, population=None):
    """Create a Dataset with the given population.

    When no population is supplied, default to every patient appearing in
    any of the referenced tables — this reduces noise in the spec tests.
    """
    if population is None:
        population = False
        for table in table_data:
            # Keep the series on the left so its __or__ handles combining
            # with the initial `False`.
            population = table.exists_for_patient() | population
    ds = Dataset()
    ds.define_population(population)
    return ds
77
78
79
def parse_table(schema, s):
    """Parse a string of table data into a list of row dicts.

    See test_conftest.py for examples.
    """

    # First line is the header, second is the separator rule, the rest are rows.
    header, _, *data_lines = s.strip().splitlines()
    col_names = [cell.strip() for cell in header.split("|")]
    # The first column is always the patient ID, whatever the header says.
    col_names[0] = "patient_id"
    # Keyword-expansion (rather than a plain merge) deliberately raises if the
    # schema itself declares a "patient_id" column.
    column_types = dict(patient_id=int, **dict(schema.column_types))
    return [parse_row(column_types, col_names, line) for line in data_lines]
93
94
95
def parse_row(column_types, col_names, line):
    """Parse string containing row data, returning a dict mapping column
    names to parsed values.

    See test_conftest.py for examples.
    """

    # Regex splits on any '|' character, as long as it's not adjacent
    # to another '|' character using look-ahead and look-behind. This
    # is to allow '||' to appear as content within a field, currently
    # just for the all_diagnoses and all_procedures fields in apcs
    #
    # Note: zip truncates to the shorter of (col_names, tokens), so extra
    # cells in a row are silently ignored and missing cells leave columns out.
    return {
        col_name: parse_value(column_types[col_name], token.strip())
        for col_name, token in zip(col_names, re.split(r"(?<!\|)\|(?!\|)", line))
    }
109
110
111
def parse_value(type_, value):
    """Convert a string token to a value of the correct type for its column.

    An empty string indicates a null value.
    """
    if value == "":
        return None

    # ehrql column types wrap a primitive Python type; unwrap it if present.
    if hasattr(type_, "_primitive_type"):
        type_ = type_._primitive_type()

    if type_ is bool:
        # Only "T"/"F" are valid; anything else raises KeyError.
        return {"T": True, "F": False}[value]
    if type_ is datetime.date:
        return datetime.date.fromisoformat(value)
    # Other primitives (int, float, str, ...) parse via their constructor.
    return type_(value)