|
a |
|
b/tests/lib/databases.py |
|
|
1 |
import secrets |
|
|
2 |
import time |
|
|
3 |
from pathlib import Path |
|
|
4 |
|
|
|
5 |
import sqlalchemy |
|
|
6 |
import sqlalchemy.exc |
|
|
7 |
from packaging.version import parse as version_parse |
|
|
8 |
from requests.exceptions import ConnectionError # noqa A004 |
|
|
9 |
from sqlalchemy.dialects import registry |
|
|
10 |
from sqlalchemy.orm import sessionmaker |
|
|
11 |
from trino.exceptions import TrinoQueryError |
|
|
12 |
|
|
|
13 |
from ehrql.query_engines.in_memory_database import InMemoryDatabase |
|
|
14 |
from ehrql.utils.itertools_utils import iter_flatten |
|
|
15 |
from tests.lib.orm_utils import SYNTHETIC_PRIMARY_KEY, table_has_one_row_per_patient |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
MSSQL_SETUP_DIR = Path(__file__).parents[1].absolute() / "support/mssql" |
|
|
19 |
TRINO_SETUP_DIR = Path(__file__).parents[1].absolute() / "support/trino" |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
# Register our modified SQLAlchemy dialects |
|
|
23 |
registry.register( |
|
|
24 |
"sqlite.pysqlite.opensafely", |
|
|
25 |
"ehrql.query_engines.sqlite_dialect", |
|
|
26 |
"SQLiteDialect", |
|
|
27 |
) |
|
|
28 |
|
|
|
29 |
registry.register( |
|
|
30 |
"trino.opensafely", "ehrql.query_engines.trino_dialect", "TrinoDialect" |
|
|
31 |
) |
|
|
32 |
|
|
|
33 |
|
|
|
34 |
class DbDetails: |
|
|
35 |
def __init__( |
|
|
36 |
self, |
|
|
37 |
protocol, |
|
|
38 |
driver, |
|
|
39 |
host_from_container, |
|
|
40 |
port_from_container, |
|
|
41 |
host_from_host, |
|
|
42 |
port_from_host, |
|
|
43 |
username="", |
|
|
44 |
password="", |
|
|
45 |
db_name="", |
|
|
46 |
query=None, |
|
|
47 |
temp_db=None, |
|
|
48 |
engine_kwargs=None, |
|
|
49 |
): |
|
|
50 |
self.protocol = protocol |
|
|
51 |
self.driver = driver |
|
|
52 |
self.host_from_container = host_from_container |
|
|
53 |
self.port_from_container = port_from_container |
|
|
54 |
self.host_from_host = host_from_host |
|
|
55 |
self.port_from_host = port_from_host |
|
|
56 |
self.password = password |
|
|
57 |
self.username = username |
|
|
58 |
self.db_name = db_name |
|
|
59 |
self.query = query |
|
|
60 |
self.temp_db = temp_db |
|
|
61 |
self.engine_kwargs = engine_kwargs or {} |
|
|
62 |
self.metadata = None |
|
|
63 |
|
|
|
64 |
def container_url(self): |
|
|
65 |
return self._url(self.host_from_container, self.port_from_container) |
|
|
66 |
|
|
|
67 |
def host_url(self): |
|
|
68 |
return self._url(self.host_from_host, self.port_from_host) |
|
|
69 |
|
|
|
70 |
def engine(self, dialect=None, **kwargs): |
|
|
71 |
url = self._url( |
|
|
72 |
self.host_from_host, self.port_from_host, include_driver=bool(self.driver) |
|
|
73 |
) |
|
|
74 |
engine_url = sqlalchemy.engine.make_url(url) |
|
|
75 |
engine_kwargs = self.engine_kwargs | kwargs |
|
|
76 |
engine = sqlalchemy.create_engine(engine_url, **engine_kwargs) |
|
|
77 |
return engine |
|
|
78 |
|
|
|
79 |
def _url(self, host, port, include_driver=False): |
|
|
80 |
assert self.username |
|
|
81 |
if self.username and self.password: |
|
|
82 |
auth = f"{self.username}:{self.password}@" |
|
|
83 |
else: |
|
|
84 |
auth = f"{self.username}@" |
|
|
85 |
if include_driver: |
|
|
86 |
protocol = f"{self.protocol}+{self.driver}" |
|
|
87 |
else: |
|
|
88 |
protocol = self.protocol |
|
|
89 |
url = f"{protocol}://{auth}{host}:{port}/{self.db_name}" |
|
|
90 |
return url |
|
|
91 |
|
|
|
92 |
def setup(self, *input_data, metadata=None): |
|
|
93 |
""" |
|
|
94 |
Accepts SQLAlchemy ORM objects (which may be arbitrarily nested within lists and |
|
|
95 |
tuples), creates the necessary tables and inserts them into the database |
|
|
96 |
""" |
|
|
97 |
input_data = list(iter_flatten(input_data)) |
|
|
98 |
engine = self.engine() |
|
|
99 |
Session = sessionmaker() |
|
|
100 |
Session.configure(bind=engine) |
|
|
101 |
session = Session() |
|
|
102 |
|
|
|
103 |
if metadata: |
|
|
104 |
pass |
|
|
105 |
elif input_data: |
|
|
106 |
metadata = input_data[0].metadata |
|
|
107 |
else: |
|
|
108 |
assert False, "No source of metadata" |
|
|
109 |
assert all(item.metadata is metadata for item in input_data) |
|
|
110 |
|
|
|
111 |
self.metadata = metadata |
|
|
112 |
metadata.create_all(engine) |
|
|
113 |
session.bulk_save_objects( |
|
|
114 |
input_data, |
|
|
115 |
return_defaults=False, |
|
|
116 |
update_changed_only=False, |
|
|
117 |
preserve_order=False, |
|
|
118 |
) |
|
|
119 |
session.commit() |
|
|
120 |
|
|
|
121 |
def teardown(self): |
|
|
122 |
if self.metadata is not None: |
|
|
123 |
self.metadata.drop_all(self.engine()) |
|
|
124 |
|
|
|
125 |
|
|
|
126 |
def wait_for_database(database, timeout=20): |
|
|
127 |
engine = database.engine() |
|
|
128 |
start = time.time() |
|
|
129 |
limit = start + timeout |
|
|
130 |
while True: |
|
|
131 |
try: |
|
|
132 |
with engine.connect() as connection: |
|
|
133 |
connection.execute(sqlalchemy.text("SELECT 'hello'")) |
|
|
134 |
break |
|
|
135 |
except ( |
|
|
136 |
sqlalchemy.exc.OperationalError, |
|
|
137 |
ConnectionRefusedError, |
|
|
138 |
ConnectionResetError, |
|
|
139 |
BrokenPipeError, |
|
|
140 |
ConnectionError, |
|
|
141 |
TrinoQueryError, |
|
|
142 |
sqlalchemy.exc.DBAPIError, |
|
|
143 |
) as e: # pragma: no cover |
|
|
144 |
if time.time() >= limit: |
|
|
145 |
raise Exception( |
|
|
146 |
f"Failed to connect to database after {timeout} seconds: " |
|
|
147 |
f"{engine.url}" |
|
|
148 |
) from e |
|
|
149 |
time.sleep(1) |
|
|
150 |
|
|
|
151 |
|
|
|
152 |
def make_mssql_database(containers): |
|
|
153 |
password = "Your_password123!" |
|
|
154 |
|
|
|
155 |
container_name = "ehrql-mssql" |
|
|
156 |
mssql_port = 1433 |
|
|
157 |
|
|
|
158 |
if not containers.is_running(container_name): # pragma: no cover |
|
|
159 |
run_mssql(container_name, containers, password, mssql_port) |
|
|
160 |
|
|
|
161 |
container_ip = containers.get_container_ip(container_name) |
|
|
162 |
host_mssql_port = containers.get_mapped_port_for_host(container_name, mssql_port) |
|
|
163 |
|
|
|
164 |
return DbDetails( |
|
|
165 |
protocol="mssql", |
|
|
166 |
driver="pymssql", |
|
|
167 |
host_from_container=container_ip, |
|
|
168 |
port_from_container=mssql_port, |
|
|
169 |
host_from_host="localhost", |
|
|
170 |
port_from_host=host_mssql_port, |
|
|
171 |
username="sa", |
|
|
172 |
password=password, |
|
|
173 |
db_name="test", |
|
|
174 |
) |
|
|
175 |
|
|
|
176 |
|
|
|
177 |
def run_mssql(container_name, containers, password, mssql_port): # pragma: no cover |
|
|
178 |
containers.run_bg( |
|
|
179 |
name=container_name, |
|
|
180 |
# This is *not* the version that TPP run for us in production which, as of |
|
|
181 |
# 2024-09-24, is SQL Server 2016 (13.0.5893.48). That version is not available |
|
|
182 |
# as a Docker image, so we run the oldest supported version instead. Both the |
|
|
183 |
# production server and our test server set the "compatibility level" to the |
|
|
184 |
# same value so the same feature set should be supported. |
|
|
185 |
image="mcr.microsoft.com/mssql/server:2019-CU28-ubuntu-20.04", |
|
|
186 |
volumes={ |
|
|
187 |
MSSQL_SETUP_DIR: {"bind": "/mssql", "mode": "ro"}, |
|
|
188 |
}, |
|
|
189 |
# Choose an arbitrary free port to publish the MSSQL port on |
|
|
190 |
ports={mssql_port: None}, |
|
|
191 |
environment={ |
|
|
192 |
"MSSQL_SA_PASSWORD": password, |
|
|
193 |
"ACCEPT_EULA": "Y", |
|
|
194 |
"MSSQL_TCP_PORT": str(mssql_port), |
|
|
195 |
# Make all string comparisons case-sensitive across all databases |
|
|
196 |
"MSSQL_COLLATION": "SQL_Latin1_General_CP1_CS_AS", |
|
|
197 |
}, |
|
|
198 |
user="root", |
|
|
199 |
entrypoint="/mssql/entrypoint.sh", |
|
|
200 |
command="/opt/mssql/bin/sqlservr", |
|
|
201 |
) |
|
|
202 |
|
|
|
203 |
|
|
|
204 |
class InMemorySQLiteDatabase(DbDetails): |
|
|
205 |
def __init__(self): |
|
|
206 |
db_name = secrets.token_hex(8) |
|
|
207 |
super().__init__( |
|
|
208 |
db_name=db_name, |
|
|
209 |
protocol="sqlite", |
|
|
210 |
driver="pysqlite+opensafely", |
|
|
211 |
host_from_container=None, |
|
|
212 |
port_from_container=None, |
|
|
213 |
host_from_host=None, |
|
|
214 |
port_from_host=None, |
|
|
215 |
) |
|
|
216 |
self._engine = None |
|
|
217 |
|
|
|
218 |
def engine(self, dialect=None, **kwargs): |
|
|
219 |
# We need to hold a reference to the engine for the lifetime of this database to stop the contents of the |
|
|
220 |
# database from being garbage-collected. |
|
|
221 |
if not self._engine: |
|
|
222 |
self._engine = super().engine(dialect, **kwargs) |
|
|
223 |
return self._engine |
|
|
224 |
|
|
|
225 |
def _url(self, host, port, include_driver=False): |
|
|
226 |
if include_driver: |
|
|
227 |
protocol = f"{self.protocol}+{self.driver}" |
|
|
228 |
else: |
|
|
229 |
protocol = self.protocol |
|
|
230 |
# https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#uri-connections |
|
|
231 |
# https://sqlite.org/inmemorydb.html |
|
|
232 |
return f"{protocol}:///file:{self.db_name}?mode=memory&cache=shared&uri=true" |
|
|
233 |
|
|
|
234 |
|
|
|
235 |
class InMemoryPythonDatabase: |
|
|
236 |
def __init__(self): |
|
|
237 |
self.database = InMemoryDatabase() |
|
|
238 |
|
|
|
239 |
def setup(self, *input_data, metadata=None): |
|
|
240 |
""" |
|
|
241 |
Behaves like `DbDetails.setup` in taking a iterator of ORM instances but |
|
|
242 |
translates these into the sort of objects needed by the `InMemoryDatabase` |
|
|
243 |
""" |
|
|
244 |
input_data = list(iter_flatten(input_data)) |
|
|
245 |
|
|
|
246 |
if metadata: |
|
|
247 |
pass |
|
|
248 |
elif input_data: |
|
|
249 |
metadata = input_data[0].metadata |
|
|
250 |
else: |
|
|
251 |
assert False, "No source of metadata" |
|
|
252 |
assert all(item.metadata is metadata for item in input_data) |
|
|
253 |
|
|
|
254 |
sqla_table_to_items = {table: [] for table in metadata.sorted_tables} |
|
|
255 |
for item in input_data: |
|
|
256 |
sqla_table_to_items[item.__table__].append(item) |
|
|
257 |
|
|
|
258 |
for sqla_table, items in sqla_table_to_items.items(): |
|
|
259 |
columns = [ |
|
|
260 |
c.name for c in sqla_table.columns if c.name != SYNTHETIC_PRIMARY_KEY |
|
|
261 |
] |
|
|
262 |
self.database.add_table( |
|
|
263 |
name=sqla_table.name, |
|
|
264 |
one_row_per_patient=table_has_one_row_per_patient(sqla_table), |
|
|
265 |
columns=columns, |
|
|
266 |
rows=[[getattr(item, c) for c in columns] for item in items], |
|
|
267 |
) |
|
|
268 |
|
|
|
269 |
def teardown(self): |
|
|
270 |
self.database.populate({}) |
|
|
271 |
|
|
|
272 |
def host_url(self): |
|
|
273 |
# Where other query engines expect a DSN string to connect to the database the |
|
|
274 |
# InMemoryQueryEngine expects a reference to the database object itself |
|
|
275 |
return self.database |
|
|
276 |
|
|
|
277 |
|
|
|
278 |
def make_trino_database(containers): |
|
|
279 |
container_name = "ehrql-trino" |
|
|
280 |
trino_port = 8080 |
|
|
281 |
|
|
|
282 |
if not containers.is_running(container_name): # pragma: no cover |
|
|
283 |
run_trino(container_name, containers, trino_port) |
|
|
284 |
|
|
|
285 |
container_ip = containers.get_container_ip(container_name) |
|
|
286 |
host_trino_port = containers.get_mapped_port_for_host(container_name, trino_port) |
|
|
287 |
|
|
|
288 |
return DbDetails( |
|
|
289 |
protocol="trino", |
|
|
290 |
driver="opensafely", |
|
|
291 |
host_from_container=container_ip, |
|
|
292 |
port_from_container=trino_port, |
|
|
293 |
host_from_host="localhost", |
|
|
294 |
port_from_host=host_trino_port, |
|
|
295 |
username="trino", |
|
|
296 |
db_name="trino/default", |
|
|
297 |
# Disable automatic retries for the test client: it's pointless and creates log |
|
|
298 |
# noise |
|
|
299 |
engine_kwargs={"connect_args": {"max_attempts": 1}}, |
|
|
300 |
) |
|
|
301 |
|
|
|
302 |
|
|
|
303 |
def run_trino(container_name, containers, trino_port): # pragma: no cover |
|
|
304 |
# Note, I don't actually know that this is the minimum required version of Docker |
|
|
305 |
# Engine. I do know that 20.10.5 is unsupported (because that's what I had |
|
|
306 |
# installed) and that 20.10.16 is supported, according to this comment: |
|
|
307 |
# https://github.com/adoptium/containers/issues/214#issuecomment-1139464798 which |
|
|
308 |
# was linked from this issue: https://github.com/trinodb/trino/issues/14269 |
|
|
309 |
min_docker_version = "20.10.16" |
|
|
310 |
docker_version = containers.get_engine_version() |
|
|
311 |
assert version_parse(docker_version) >= version_parse(min_docker_version), ( |
|
|
312 |
f"The Trino Docker image requires Docker Engine v{min_docker_version}" |
|
|
313 |
f" or above but you have v{docker_version}" |
|
|
314 |
) |
|
|
315 |
containers.run_bg( |
|
|
316 |
name=container_name, |
|
|
317 |
# This is the version which happened to be current at the time of writing and is |
|
|
318 |
# pinned for reproduciblity's sake rather than because there's anything |
|
|
319 |
# significant about it |
|
|
320 |
image="trinodb/trino:440", |
|
|
321 |
volumes={ |
|
|
322 |
TRINO_SETUP_DIR: {"bind": "/trino", "mode": "ro"}, |
|
|
323 |
f"{TRINO_SETUP_DIR}/etc": {"bind": "/etc/trino", "mode": "ro"}, |
|
|
324 |
}, |
|
|
325 |
# Choose an arbitrary free port to publish the trino port on |
|
|
326 |
ports={trino_port: None}, |
|
|
327 |
environment={}, |
|
|
328 |
user="root", |
|
|
329 |
entrypoint="/trino/entrypoint.sh", |
|
|
330 |
command="/usr/lib/trino/bin/run-trino", |
|
|
331 |
) |