a b/tests/unit/backends/test_base.py
1
import datetime
2
3
import pytest
4
import sqlalchemy
5
6
from ehrql.backends.base import (
7
    DefaultSQLBackend,
8
    MappedTable,
9
    QueryTable,
10
    SQLBackend,
11
    ValidationError,
12
)
13
from ehrql.query_engines.base_sql import BaseSQLQueryEngine
14
from ehrql.query_model.nodes import Column, TableSchema
15
from ehrql.tables import PatientFrame, Series, table
16
17
18
class BackendFixture(SQLBackend):
    """Minimal SQL backend shared by the tests in this module.

    It exercises each way of defining a table: a ``MappedTable`` whose column
    names differ from the source (``patients``), a ``MappedTable`` whose names
    match the source (``events``), a literal ``QueryTable``
    (``practice_registrations``) and a ``QueryTable`` generated by a function
    that reads the backend config (``positive_tests``).
    """

    display_name = "Backend Fixture"
    query_engine_class = BaseSQLQueryEngine
    # Used as the patient ID column by tables that don't map `patient_id`
    # themselves (e.g. `events` below)
    patient_join_column = "PatientId"

    # Every column here, including the patient ID, is renamed in the source
    patients = MappedTable(
        source="Patient",
        columns={
            "patient_id": "PatID",
            "date_of_birth": "DateOfBirth",
        },
    )

    events = MappedTable(
        source="events",
        columns={"date": "date"},
    )

    practice_registrations = QueryTable(
        "SELECT patient_id, date_start, date_end FROM some_table"
    )

    @QueryTable.from_function
    def positive_tests(self):
        # The source table name can be overridden through the backend config
        table_name = self.config.get("table_name", "some_table")
        return f"SELECT patient_id, date FROM {table_name}"
46
47
48
def test_backend_registers_tables():
    """Check that every table defined on a backend class is registered."""
    expected = {"patients", "events", "practice_registrations", "positive_tests"}
    registered = set(BackendFixture.tables)
    assert registered == expected
57
58
59
def test_mapped_table_sql_with_modified_names():
    # `patients` maps both patient_id and date_of_birth to differently-named
    # source columns, so the generated SQL must use the source names
    schema = TableSchema(date_of_birth=Column(datetime.date))
    patients = BackendFixture().get_table_expression("patients", schema)
    query = sqlalchemy.select(patients.c.patient_id, patients.c.date_of_birth)
    assert str(query) == (
        'SELECT "Patient"."PatID", "Patient"."DateOfBirth" \nFROM "Patient"'
    )
68
69
70
def test_mapped_table_sql_with_matching_names():
    # `events` doesn't map patient_id, so the backend's `patient_join_column`
    # should be used; `date` maps onto an identically-named source column
    backend = BackendFixture()
    events = backend.get_table_expression(
        "events", TableSchema(date=Column(datetime.date))
    )
    compiled = str(sqlalchemy.select(events.c.patient_id, events.c.date))
    assert compiled == 'SELECT events."PatientId", events.date \nFROM events'
79
80
81
def test_query_table_sql():
    # A QueryTable's SQL should appear as a subquery aliased to the table name
    schema = TableSchema(
        date_start=Column(datetime.date),
        date_end=Column(datetime.date),
    )
    registrations = BackendFixture().get_table_expression(
        "practice_registrations", schema
    )
    query = sqlalchemy.select(
        registrations.c.patient_id, registrations.c.date_start
    )
    expected = (
        "SELECT practice_registrations.patient_id, practice_registrations.date_start \n"
        "FROM (SELECT patient_id, date_start, date_end FROM some_table) AS "
        "practice_registrations"
    )
    assert str(query) == expected
95
96
97
def test_query_table_from_function_sql():
    # The `table_name` config value should be substituted into the SQL that
    # the `positive_tests` function generates
    config = {"table_name": "other_table"}
    schema = TableSchema(date=Column(datetime.date))
    positive_tests = BackendFixture(config=config).get_table_expression(
        "positive_tests", schema
    )
    assert str(positive_tests) == "SELECT patient_id, date FROM other_table"
104
105
106
def test_default_backend_sql():
    # The default backend should use the requested table and column names
    # unchanged in the generated SQL
    schema = TableSchema(i=Column(int), b=Column(bool))
    backend = DefaultSQLBackend(query_engine_class=BaseSQLQueryEngine)
    some_table = backend.get_table_expression("some_table", schema)
    columns = (some_table.c.patient_id, some_table.c.i, some_table.c.b)
    assert str(sqlalchemy.select(*columns)) == (
        "SELECT some_table.patient_id, some_table.i, some_table.b \nFROM some_table"
    )
115
116
117
# A class serves purely as a convenient namespace for the table definitions
# below (`types.SimpleNamespace` would also work)
class Schema:
    """Minimal table namespace that the test backends declare they implement."""

    @table
    class patients(PatientFrame):
        # Backends implementing this schema must provide this column
        date_of_birth = Series(datetime.date)
122
123
124
def test_backend_definition_is_allowed_extra_tables_and_columns():
    """A backend may define tables and columns beyond those its schema requires.

    `Schema.patients` only declares `date_of_birth`, so `sex` is an extra
    column and `events` an entirely extra table; neither should fail validation.
    """

    class BackendFixture(SQLBackend):
        display_name = "Backend Fixture"
        query_engine_class = BaseSQLQueryEngine
        patient_join_column = "patient_id"
        implements = [Schema]

        patients = MappedTable(
            source="patient",
            columns=dict(date_of_birth="DoB", sex="sex"),
        )
        # The extra table reads from its own source, as in the module-level
        # fixture backend
        events = MappedTable(
            source="events",
            columns=dict(date="date", code="code"),
        )

    # Defining the class without a ValidationError is the behaviour under test
    assert BackendFixture
141
142
143
def test_backend_definition_accepts_query_table():
    """A QueryTable whose SQL references the required columns satisfies a schema table."""

    class BackendFixture(SQLBackend):
        implements = [Schema]
        patient_join_column = "patient_id"
        query_engine_class = BaseSQLQueryEngine
        display_name = "Backend Fixture"

        # `date_of_birth` is produced by aliasing a cast source column
        patients = QueryTable(
            "SELECT patient_id, CAST(DoB AS date) AS date_of_birth FROM patients",
        )

    # Reaching this point means the class definition passed validation
    assert BackendFixture
155
156
157
def test_backend_definition_fails_if_missing_tables():
    """A backend that omits a table from a schema it implements is rejected."""
    with pytest.raises(ValidationError, match="does not implement table"):

        # `Schema.patients` is deliberately left undefined here
        class BackendFixture(SQLBackend):
            display_name = "Backend Fixture"
            query_engine_class = BaseSQLQueryEngine
            patient_join_column = "patient_id"
            implements = [Schema]

            events = MappedTable(
                source="events",
                columns=dict(date="date", code="code"),
            )
170
171
172
def test_backend_definition_fails_if_missing_column():
    """A schema table mapped without all of its columns is rejected."""

    def define_backend():
        # `patients` maps only `sex`, omitting the schema's `date_of_birth`
        class BackendFixture(SQLBackend):
            display_name = "Backend Fixture"
            query_engine_class = BaseSQLQueryEngine
            patient_join_column = "patient_id"
            implements = [Schema]

            patients = MappedTable(
                source="patient",
                columns=dict(sex="sex"),
            )

        return BackendFixture

    with pytest.raises(ValidationError, match="missing columns"):
        define_backend()
185
186
187
def test_backend_definition_fails_if_query_table_missing_columns():
    """A QueryTable whose SQL doesn't reference a required column is rejected."""
    expected_error = "SQL does not reference columns"
    with pytest.raises(ValidationError, match=expected_error):

        # The query selects `not_date_of_birth` rather than `date_of_birth`
        class BackendFixture(SQLBackend):
            display_name = "Backend Fixture"
            query_engine_class = BaseSQLQueryEngine
            patient_join_column = "patient_id"
            implements = [Schema]

            patients = QueryTable(
                "SELECT patient_id, not_date_of_birth FROM patients",
            )