[e988c2]: tests/conftest.py


import os
import random
import subprocess
import threading
from pathlib import Path

import pytest
from hypothesis.internal.reflection import extract_lambda_source

import ehrql
import ehrql.__main__
from ehrql import query_language as ql
from ehrql.main import get_sql_strings
from ehrql.query_engines.in_memory import InMemoryQueryEngine
from ehrql.query_engines.mssql import MSSQLQueryEngine
from ehrql.query_engines.sqlite import SQLiteQueryEngine
from ehrql.query_engines.trino import TrinoQueryEngine
from ehrql.query_model import nodes as qm
from tests.lib.orm_utils import make_orm_models

from .lib.databases import (
    InMemoryPythonDatabase,
    InMemorySQLiteDatabase,
    make_mssql_database,
    make_trino_database,
    wait_for_database,
)
from .lib.docker import Containers


def pytest_collection_modifyitems(session, config, items):  # pragma: no cover
    """If running with pytest-xdist, add a group identifier to each test item, based
    on which database is used by the test.

    This lets us use pytest-xdist to distribute tests across three processes, leading
    to a moderate speed-up, via `pytest -n3`.

    The "proper" way to distribute tests with pytest-xdist is by adding the
    xdist_group mark. However, this is very hard to do dynamically (because of our use
    of request.getfixturevalue) so it is less invasive to add a group identifier here,
    during test collection. Later, pytest-xdist will use the group identifier to
    distribute tests to workers.
    """
    if "PYTEST_XDIST_WORKER" not in os.environ:
        # Modifying test item identifiers makes it harder to copy and paste
        # identifiers from failing outputs, so it only makes sense to do so if we're
        # running tests with pytest-xdist.
        return

    slow_database_names = ["mssql", "trino"]
    for item in items:
        group = "other"
        if "engine" in item.fixturenames:
            database_name = item.callspec.params["engine"]
            if database_name in slow_database_names:
                group = database_name
        else:
            found_database_in_fixtures = False
            for database_name in slow_database_names:
                if any(
                    database_name in fixture_name
                    for fixture_name in item.fixturenames
                ):
                    group = database_name
                    # Check that tests do not use multiple fixtures for slow
                    # databases.
                    assert not found_database_in_fixtures
                    found_database_in_fixtures = True
        item._nodeid = f"{item.nodeid}@{group}"
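

# Note (illustrative, not from the original source): with the grouping above, a test
# that uses the mssql engine ends up with a node id looking something like
# "tests/spec/test_foo.py::test_bar[mssql]@mssql", while tests that touch no slow
# database get the "@other" suffix; pytest-xdist then keeps each group on one worker.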


# Fail the build if we see any warnings.
def pytest_terminal_summary(terminalreporter, exitstatus, config):
    if terminalreporter.stats.get("warnings"):  # pragma: no cover
        print("ERROR: warnings detected")
        if terminalreporter._session.exitstatus == 0:
            terminalreporter._session.exitstatus = 13


def pytest_make_parametrize_id(config, val):
    # Where we use lambdas as test parameters, having the source as the parameter ID
    # makes it quicker to identify specific test cases in the output
    if callable(val) and val.__name__ == "<lambda>":
        return extract_lambda_source(val).removeprefix("lambda: ")
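

# For example (illustrative): a parameter written as `lambda: 1 + 1` gets the
# parameter ID "1 + 1" in test output, rather than an opaque "<lambda>" label.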


@pytest.fixture(scope="session")
def show_delayed_warning(request):
    """
    Some fixtures can take a long time to execute the first time they're run (e.g.
    they might need to pull down a large Docker image) but pytest's output capturing
    means that the user has no idea what's happening. This fixture allows us to "poke
    through" the output capturing and display a message to the user, but only if the
    task has already taken more than N seconds.
    """

    def show_warning(message):  # pragma: no cover
        capturemanager = request.config.pluginmanager.getplugin("capturemanager")
        # No need to display anything if output is not being captured
        if capturemanager.is_capturing():
            with capturemanager.global_and_fixture_disabled():
                print(f"\n => {message} ...")

    return lambda delay, message: ContextTimer(delay, show_warning, args=[message])


# Timer which starts/cancels itself when entering/exiting a context block
class ContextTimer(threading.Timer):
    def __enter__(self):
        self.start()

    def __exit__(self, *_):
        self.cancel()
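

# Usage sketch (illustrative, not from the original source): the callback only fires
# if the block outlives the delay, because exiting the block cancels the timer.
#
#     with ContextTimer(3, print, args=["still working"]):
#         do_slow_thing()  # hypothetical; "still working" prints only if this > 3s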


@pytest.fixture(scope="session")
def containers():
    yield Containers()


# Database fixtures
#
# These fixtures come in pairs. For each database, there is a session-scoped fixture,
# which performs any setup, and there is a function-scoped fixture, which reuses the
# fixture returned by the session-scoped fixture.
#
# In most cases, we will want the function-scoped fixture, as this allows post-test
# teardown. However, the generative tests require a session-scoped fixture.


@pytest.fixture(scope="session")
def in_memory_sqlite_database_with_session_scope():
    return InMemorySQLiteDatabase()


@pytest.fixture(scope="function")
def in_memory_sqlite_database(in_memory_sqlite_database_with_session_scope):
    database = in_memory_sqlite_database_with_session_scope
    yield database
    database.teardown()


@pytest.fixture(scope="session")
def mssql_database_with_session_scope(containers, show_delayed_warning):
    with show_delayed_warning(
        3, "Starting MSSQL Docker image (will download image on first run)"
    ):
        database = make_mssql_database(containers)
        wait_for_database(database)
    return database


@pytest.fixture(scope="function")
def mssql_database(mssql_database_with_session_scope):
    database = mssql_database_with_session_scope
    yield database
    database.teardown()


@pytest.fixture(scope="session")
def trino_database_with_session_scope(containers, show_delayed_warning):
    with show_delayed_warning(
        3, "Starting Trino Docker image (will download image on first run)"
    ):
        database = make_trino_database(containers)
        wait_for_database(database)
    return database


@pytest.fixture(scope="function")
def trino_database(trino_database_with_session_scope):
    database = trino_database_with_session_scope
    yield database
    database.teardown()


class QueryEngineFixture:
    def __init__(self, name, database, query_engine_class):
        self.name = name
        self.database = database
        self.query_engine_class = query_engine_class

    def setup(self, *items, metadata=None):
        return self.database.setup(*items, metadata=metadata)

    def teardown(self):
        return self.database.teardown()

    def populate(self, *args):
        return self.setup(make_orm_models(*args))

    def query_engine(self, dsn=False, **engine_kwargs):
        if dsn is False:
            dsn = self.database.host_url()
        return self.query_engine_class(dsn, **engine_kwargs)

    def get_results_tables(self, dataset, **engine_kwargs):
        if isinstance(dataset, ql.Dataset):
            dataset = dataset._compile()
        assert isinstance(dataset, qm.Dataset)
        query_engine = self.query_engine(**engine_kwargs)
        results_tables = query_engine.get_results_tables(dataset)
        # We don't explicitly order the results, and not all databases naturally
        # return them in the same order
        return [
            [row._asdict() for row in sort_table(table)] for table in results_tables
        ]

    def extract(self, dataset, **engine_kwargs):
        return self.get_results_tables(dataset, **engine_kwargs)[0]

    def dump_dataset_sql(self, dataset, **engine_kwargs):
        assert isinstance(dataset, ql.Dataset)
        dataset_qm = dataset._compile()
        query_engine = self.query_engine(dsn=None, **engine_kwargs)
        return get_sql_strings(query_engine, dataset_qm)

    def sqlalchemy_engine(self):
        return self.query_engine().engine


def sort_table(table):
    # Python won't naturally compare None with other values, but we need to sort
    # tables containing None values, so we treat None as smaller than all other
    # values
    return sorted(table, key=sort_key_with_nones)


def sort_key_with_nones(row):
    return [(v is not None, v) for v in row]
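

# Worked example (illustrative): for the row (None, 3) the key is
# [(False, None), (True, 3)]. Since False sorts before True, rows with None in a
# column sort ahead of rows with real values in that column, and the None value
# itself is never compared against a non-None value.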


QUERY_ENGINE_NAMES = ("in_memory", "sqlite", "mssql", "trino")


def engine_factory(request, engine_name, with_session_scope=False):
    if engine_name == "in_memory":
        return QueryEngineFixture(
            engine_name, InMemoryPythonDatabase(), InMemoryQueryEngine
        )

    if engine_name == "sqlite":
        database_fixture_name = "in_memory_sqlite_database"
        query_engine_class = SQLiteQueryEngine
    elif engine_name == "mssql":
        database_fixture_name = "mssql_database"
        query_engine_class = MSSQLQueryEngine
    elif engine_name == "trino":
        database_fixture_name = "trino_database"
        query_engine_class = TrinoQueryEngine
    else:
        assert False

    if with_session_scope:
        database_fixture_name = f"{database_fixture_name}_with_session_scope"

    # We dynamically request fixtures rather than making them arguments in the usual
    # way so that we only start the database containers we actually need for the test
    # run
    database = request.getfixturevalue(database_fixture_name)
    return QueryEngineFixture(engine_name, database, query_engine_class)


@pytest.fixture(params=QUERY_ENGINE_NAMES)
def engine(request):
    return engine_factory(request, request.param)


@pytest.fixture
def mssql_engine(request):
    return engine_factory(request, "mssql")


@pytest.fixture
def trino_engine(request):
    return engine_factory(request, "trino")


@pytest.fixture
def in_memory_engine(request):
    return engine_factory(request, "in_memory")


@pytest.fixture
def sqlite_engine(request):
    return engine_factory(request, "sqlite")
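

# For example (illustrative): a test that takes the `engine` fixture runs once per
# entry in QUERY_ENGINE_NAMES (four parametrized runs), while taking `sqlite_engine`
# pins the test to a single engine and starts no database containers it doesn't need.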


@pytest.fixture(scope="session")
def ehrql_image(show_delayed_warning):
    project_dir = Path(ehrql.__file__).parents[1]
    # Note: different name from the production image to avoid confusion
    image = "ehrql-dev"
    # We deliberately shell out to the docker client here rather than use the
    # docker-py library, to avoid possible differences in the build process
    # (docker-py doesn't seem to be particularly actively maintained)
    with show_delayed_warning(3, f"Building {image} Docker image"):
        subprocess.run(
            ["docker", "build", project_dir, "-t", image],
            check=True,
            env=dict(os.environ, DOCKER_BUILDKIT="1"),
        )
    return f"{image}:latest"


@pytest.fixture(autouse=True)
def random_should_not_be_used():
    """Asserts that every test leaves the global random number generator unchanged.

    We want all of our use of randomness to come from seeded random number
    generators, so it should not depend on the global random number generator in any
    way. If this is failing, please find all uses of the random module and replace
    them with a Random instance, ideally created from a predictable seed.
    """
    prev_state = random.getstate()
    yield
    assert random.getstate() == prev_state, (
        "Global random number generator was used in test."
    )
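

# For example (illustrative): instead of calling `random.randint(0, 10)` inside a
# test, use a local generator such as `random.Random(42).randint(0, 10)` so the
# module-level state checked above is never touched.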


@pytest.fixture
def call_cli(capsys):
    """
    Wrapper around the CLI entrypoint to make it easier to call from tests
    """

    def call(*args, environ=None):
        # Convert any Path instances to strings
        args = [str(arg) if isinstance(arg, Path) else arg for arg in args]
        ehrql.__main__.main(args, environ=environ)
        return capsys.readouterr()

    # Allow reading captured output even when call throws an exception
    call.readouterr = capsys.readouterr
    return call
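

# Usage sketch (illustrative; the subcommand and path are assumptions, not taken from
# this file):
#
#     captured = call_cli("generate-dataset", dataset_definition_path)
#     print(captured.out)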


@pytest.fixture
def call_cli_docker(containers, ehrql_image):
    """
    As above, but invoke the CLI via the Docker image
    """

    def call(*args, environ=None, workspace=None):
        args = [
            # Make any paths relative to the workspace directory so they still point
            # to the right place inside Docker. If you supply path arguments and no
            # workspace this will error, as it should. Likewise if you supply paths
            # outside of the workspace.
            str(arg.relative_to(workspace)) if isinstance(arg, Path) else str(arg)
            for arg in args
        ]
        if workspace is not None:
            # Because the files in these directories will need to be readable by
            # low-privilege, isolated processes we can't use the standard restrictive
            # permissions for temporary directories
            workspace.chmod(0o755)
            volumes = {workspace: {"bind": "/workspace", "mode": "rw"}}
        else:
            volumes = {}
        return containers.run_captured(
            ehrql_image,
            command=args,
            volumes=volumes,
            environment=environ or {},
        )

    return call
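

# Usage sketch (illustrative; the subcommand and file name are assumptions, and this
# presumes the container resolves relative paths against the mounted /workspace):
#
#     call_cli_docker(
#         "generate-dataset",
#         tmp_path / "dataset_definition.py",  # rewritten to "dataset_definition.py"
#         workspace=tmp_path,
#     )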