Diff of /tests/lib/databases.py [000000] .. [e988c2]

Switch to unified view

a b/tests/lib/databases.py
1
import secrets
2
import time
3
from pathlib import Path
4
5
import sqlalchemy
6
import sqlalchemy.exc
7
from packaging.version import parse as version_parse
8
from requests.exceptions import ConnectionError  # noqa A004
9
from sqlalchemy.dialects import registry
10
from sqlalchemy.orm import sessionmaker
11
from trino.exceptions import TrinoQueryError
12
13
from ehrql.query_engines.in_memory_database import InMemoryDatabase
14
from ehrql.utils.itertools_utils import iter_flatten
15
from tests.lib.orm_utils import SYNTHETIC_PRIMARY_KEY, table_has_one_row_per_patient
16
17
18
MSSQL_SETUP_DIR = Path(__file__).parents[1].absolute() / "support/mssql"
19
TRINO_SETUP_DIR = Path(__file__).parents[1].absolute() / "support/trino"
20
21
22
# Register our modified SQLAlchemy dialects
23
registry.register(
24
    "sqlite.pysqlite.opensafely",
25
    "ehrql.query_engines.sqlite_dialect",
26
    "SQLiteDialect",
27
)
28
29
registry.register(
30
    "trino.opensafely", "ehrql.query_engines.trino_dialect", "TrinoDialect"
31
)
32
33
34
class DbDetails:
35
    def __init__(
36
        self,
37
        protocol,
38
        driver,
39
        host_from_container,
40
        port_from_container,
41
        host_from_host,
42
        port_from_host,
43
        username="",
44
        password="",
45
        db_name="",
46
        query=None,
47
        temp_db=None,
48
        engine_kwargs=None,
49
    ):
50
        self.protocol = protocol
51
        self.driver = driver
52
        self.host_from_container = host_from_container
53
        self.port_from_container = port_from_container
54
        self.host_from_host = host_from_host
55
        self.port_from_host = port_from_host
56
        self.password = password
57
        self.username = username
58
        self.db_name = db_name
59
        self.query = query
60
        self.temp_db = temp_db
61
        self.engine_kwargs = engine_kwargs or {}
62
        self.metadata = None
63
64
    def container_url(self):
65
        return self._url(self.host_from_container, self.port_from_container)
66
67
    def host_url(self):
68
        return self._url(self.host_from_host, self.port_from_host)
69
70
    def engine(self, dialect=None, **kwargs):
71
        url = self._url(
72
            self.host_from_host, self.port_from_host, include_driver=bool(self.driver)
73
        )
74
        engine_url = sqlalchemy.engine.make_url(url)
75
        engine_kwargs = self.engine_kwargs | kwargs
76
        engine = sqlalchemy.create_engine(engine_url, **engine_kwargs)
77
        return engine
78
79
    def _url(self, host, port, include_driver=False):
80
        assert self.username
81
        if self.username and self.password:
82
            auth = f"{self.username}:{self.password}@"
83
        else:
84
            auth = f"{self.username}@"
85
        if include_driver:
86
            protocol = f"{self.protocol}+{self.driver}"
87
        else:
88
            protocol = self.protocol
89
        url = f"{protocol}://{auth}{host}:{port}/{self.db_name}"
90
        return url
91
92
    def setup(self, *input_data, metadata=None):
93
        """
94
        Accepts SQLAlchemy ORM objects (which may be arbitrarily nested within lists and
95
        tuples), creates the necessary tables and inserts them into the database
96
        """
97
        input_data = list(iter_flatten(input_data))
98
        engine = self.engine()
99
        Session = sessionmaker()
100
        Session.configure(bind=engine)
101
        session = Session()
102
103
        if metadata:
104
            pass
105
        elif input_data:
106
            metadata = input_data[0].metadata
107
        else:
108
            assert False, "No source of metadata"
109
        assert all(item.metadata is metadata for item in input_data)
110
111
        self.metadata = metadata
112
        metadata.create_all(engine)
113
        session.bulk_save_objects(
114
            input_data,
115
            return_defaults=False,
116
            update_changed_only=False,
117
            preserve_order=False,
118
        )
119
        session.commit()
120
121
    def teardown(self):
122
        if self.metadata is not None:
123
            self.metadata.drop_all(self.engine())
124
125
126
def wait_for_database(database, timeout=20):
127
    engine = database.engine()
128
    start = time.time()
129
    limit = start + timeout
130
    while True:
131
        try:
132
            with engine.connect() as connection:
133
                connection.execute(sqlalchemy.text("SELECT 'hello'"))
134
            break
135
        except (
136
            sqlalchemy.exc.OperationalError,
137
            ConnectionRefusedError,
138
            ConnectionResetError,
139
            BrokenPipeError,
140
            ConnectionError,
141
            TrinoQueryError,
142
            sqlalchemy.exc.DBAPIError,
143
        ) as e:  # pragma: no cover
144
            if time.time() >= limit:
145
                raise Exception(
146
                    f"Failed to connect to database after {timeout} seconds: "
147
                    f"{engine.url}"
148
                ) from e
149
            time.sleep(1)
150
151
152
def make_mssql_database(containers):
153
    password = "Your_password123!"
154
155
    container_name = "ehrql-mssql"
156
    mssql_port = 1433
157
158
    if not containers.is_running(container_name):  # pragma: no cover
159
        run_mssql(container_name, containers, password, mssql_port)
160
161
    container_ip = containers.get_container_ip(container_name)
162
    host_mssql_port = containers.get_mapped_port_for_host(container_name, mssql_port)
163
164
    return DbDetails(
165
        protocol="mssql",
166
        driver="pymssql",
167
        host_from_container=container_ip,
168
        port_from_container=mssql_port,
169
        host_from_host="localhost",
170
        port_from_host=host_mssql_port,
171
        username="sa",
172
        password=password,
173
        db_name="test",
174
    )
175
176
177
def run_mssql(container_name, containers, password, mssql_port):  # pragma: no cover
178
    containers.run_bg(
179
        name=container_name,
180
        # This is *not* the version that TPP run for us in production which, as of
181
        # 2024-09-24, is SQL Server 2016 (13.0.5893.48). That version is not available
182
        # as a Docker image, so we run the oldest supported version instead. Both the
183
        # production server and our test server set the "compatibility level" to the
184
        # same value so the same feature set should be supported.
185
        image="mcr.microsoft.com/mssql/server:2019-CU28-ubuntu-20.04",
186
        volumes={
187
            MSSQL_SETUP_DIR: {"bind": "/mssql", "mode": "ro"},
188
        },
189
        # Choose an arbitrary free port to publish the MSSQL port on
190
        ports={mssql_port: None},
191
        environment={
192
            "MSSQL_SA_PASSWORD": password,
193
            "ACCEPT_EULA": "Y",
194
            "MSSQL_TCP_PORT": str(mssql_port),
195
            # Make all string comparisons case-sensitive across all databases
196
            "MSSQL_COLLATION": "SQL_Latin1_General_CP1_CS_AS",
197
        },
198
        user="root",
199
        entrypoint="/mssql/entrypoint.sh",
200
        command="/opt/mssql/bin/sqlservr",
201
    )
202
203
204
class InMemorySQLiteDatabase(DbDetails):
205
    def __init__(self):
206
        db_name = secrets.token_hex(8)
207
        super().__init__(
208
            db_name=db_name,
209
            protocol="sqlite",
210
            driver="pysqlite+opensafely",
211
            host_from_container=None,
212
            port_from_container=None,
213
            host_from_host=None,
214
            port_from_host=None,
215
        )
216
        self._engine = None
217
218
    def engine(self, dialect=None, **kwargs):
219
        # We need to hold a reference to the engine for the lifetime of this database to stop the contents of the
220
        # database from being garbage-collected.
221
        if not self._engine:
222
            self._engine = super().engine(dialect, **kwargs)
223
        return self._engine
224
225
    def _url(self, host, port, include_driver=False):
226
        if include_driver:
227
            protocol = f"{self.protocol}+{self.driver}"
228
        else:
229
            protocol = self.protocol
230
        # https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#uri-connections
231
        # https://sqlite.org/inmemorydb.html
232
        return f"{protocol}:///file:{self.db_name}?mode=memory&cache=shared&uri=true"
233
234
235
class InMemoryPythonDatabase:
236
    def __init__(self):
237
        self.database = InMemoryDatabase()
238
239
    def setup(self, *input_data, metadata=None):
240
        """
241
        Behaves like `DbDetails.setup` in taking a iterator of ORM instances but
242
        translates these into the sort of objects needed by the `InMemoryDatabase`
243
        """
244
        input_data = list(iter_flatten(input_data))
245
246
        if metadata:
247
            pass
248
        elif input_data:
249
            metadata = input_data[0].metadata
250
        else:
251
            assert False, "No source of metadata"
252
        assert all(item.metadata is metadata for item in input_data)
253
254
        sqla_table_to_items = {table: [] for table in metadata.sorted_tables}
255
        for item in input_data:
256
            sqla_table_to_items[item.__table__].append(item)
257
258
        for sqla_table, items in sqla_table_to_items.items():
259
            columns = [
260
                c.name for c in sqla_table.columns if c.name != SYNTHETIC_PRIMARY_KEY
261
            ]
262
            self.database.add_table(
263
                name=sqla_table.name,
264
                one_row_per_patient=table_has_one_row_per_patient(sqla_table),
265
                columns=columns,
266
                rows=[[getattr(item, c) for c in columns] for item in items],
267
            )
268
269
    def teardown(self):
270
        self.database.populate({})
271
272
    def host_url(self):
273
        # Where other query engines expect a DSN string to connect to the database the
274
        # InMemoryQueryEngine expects a reference to the database object itself
275
        return self.database
276
277
278
def make_trino_database(containers):
279
    container_name = "ehrql-trino"
280
    trino_port = 8080
281
282
    if not containers.is_running(container_name):  # pragma: no cover
283
        run_trino(container_name, containers, trino_port)
284
285
    container_ip = containers.get_container_ip(container_name)
286
    host_trino_port = containers.get_mapped_port_for_host(container_name, trino_port)
287
288
    return DbDetails(
289
        protocol="trino",
290
        driver="opensafely",
291
        host_from_container=container_ip,
292
        port_from_container=trino_port,
293
        host_from_host="localhost",
294
        port_from_host=host_trino_port,
295
        username="trino",
296
        db_name="trino/default",
297
        # Disable automatic retries for the test client: it's pointless and creates log
298
        # noise
299
        engine_kwargs={"connect_args": {"max_attempts": 1}},
300
    )
301
302
303
def run_trino(container_name, containers, trino_port):  # pragma: no cover
304
    # Note, I don't actually know that this is the minimum required version of Docker
305
    # Engine. I do know that 20.10.5 is unsupported (because that's what I had
306
    # installed) and that 20.10.16 is supported, according to this comment:
307
    # https://github.com/adoptium/containers/issues/214#issuecomment-1139464798 which
308
    # was linked from this issue: https://github.com/trinodb/trino/issues/14269
309
    min_docker_version = "20.10.16"
310
    docker_version = containers.get_engine_version()
311
    assert version_parse(docker_version) >= version_parse(min_docker_version), (
312
        f"The Trino Docker image requires Docker Engine v{min_docker_version}"
313
        f" or above but you have v{docker_version}"
314
    )
315
    containers.run_bg(
316
        name=container_name,
317
        # This is the version which happened to be current at the time of writing and is
318
        # pinned for reproduciblity's sake rather than because there's anything
319
        # significant about it
320
        image="trinodb/trino:440",
321
        volumes={
322
            TRINO_SETUP_DIR: {"bind": "/trino", "mode": "ro"},
323
            f"{TRINO_SETUP_DIR}/etc": {"bind": "/etc/trino", "mode": "ro"},
324
        },
325
        # Choose an arbitrary free port to publish the trino port on
326
        ports={trino_port: None},
327
        environment={},
328
        user="root",
329
        entrypoint="/trino/entrypoint.sh",
330
        command="/usr/lib/trino/bin/run-trino",
331
    )