Diff of /tests/conftest.py [000000] .. [e988c2]

import os
import random
import subprocess
import threading
from pathlib import Path

import pytest
from hypothesis.internal.reflection import extract_lambda_source

import ehrql
import ehrql.__main__
from ehrql import query_language as ql
from ehrql.main import get_sql_strings
from ehrql.query_engines.in_memory import InMemoryQueryEngine
from ehrql.query_engines.mssql import MSSQLQueryEngine
from ehrql.query_engines.sqlite import SQLiteQueryEngine
from ehrql.query_engines.trino import TrinoQueryEngine
from ehrql.query_model import nodes as qm
from tests.lib.orm_utils import make_orm_models

from .lib.databases import (
    InMemoryPythonDatabase,
    InMemorySQLiteDatabase,
    make_mssql_database,
    make_trino_database,
    wait_for_database,
)
from .lib.docker import Containers


def pytest_collection_modifyitems(session, config, items):  # pragma: no cover
    """If running with pytest-xdist, add a group identifier to each test item, based on
    which database is used by the test.

    This lets us use pytest-xdist to distribute tests across three processes leading to
    a moderate speed-up, via `pytest -n3`.

    The "proper" way to distribute tests with pytest-xdist is by adding the xdist_group
    mark.  However, this is very hard to do dynamically (because of our use of
    request.getfixturevalue) so it is less invasive to add a group identifier here,
    during test collection.  Later, pytest-xdist will use the group identifier to
    distribute tests to workers.
    """

    if "PYTEST_XDIST_WORKER" not in os.environ:
        # Modifying test item identifiers makes it harder to copy and paste identifiers
        # from failing outputs, so it only makes sense to do so if we're running tests
        # with pytest-xdist.
        return

    slow_database_names = ["mssql", "trino"]

    for item in items:
        group = "other"

        if "engine" in item.fixturenames:
            database_name = item.callspec.params["engine"]
            if database_name in slow_database_names:
                group = database_name

        else:
            found_database_in_fixtures = False
            for database_name in slow_database_names:
                if any(
                    database_name in fixture_name for fixture_name in item.fixturenames
                ):
                    group = database_name
                    # Check that tests do not use multiple fixtures for slow databases.
                    assert not found_database_in_fixtures
                    found_database_in_fixtures = True

        item._nodeid = f"{item.nodeid}@{group}"


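# For comparison, the "proper" static approach mentioned in the docstring above would
# look something like this (hypothetical test shown):
#
#     @pytest.mark.xdist_group(name="mssql")
#     def test_something(mssql_database): ...
#
# pytest-xdist's `--dist loadgroup` mode appears to implement that mark by appending
# "@<name>" to the node ID, which is what we emulate here.

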
# Fail the build if we see any warnings.
def pytest_terminal_summary(terminalreporter, exitstatus, config):
    if terminalreporter.stats.get("warnings"):  # pragma: no cover
        print("ERROR: warnings detected")
        if terminalreporter._session.exitstatus == 0:
            terminalreporter._session.exitstatus = 13


def pytest_make_parametrize_id(config, val):
    # Where we use lambdas as test parameters, having the source as the parameter ID
    # makes it quicker to identify specific test cases in the output
    if callable(val) and val.__name__ == "<lambda>":
        return extract_lambda_source(val).removeprefix("lambda: ")


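# For example, a test parametrized with lambdas (hypothetical):
#
#     @pytest.mark.parametrize("expr", [lambda: 1 + 1, lambda: 2 * 2])
#     def test_expr(expr): ...
#
# gets the IDs "1 + 1" and "2 * 2" rather than the default "expr0" and "expr1".

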
@pytest.fixture(scope="session")
def show_delayed_warning(request):
    """
    Some fixtures can take a long time to execute the first time they're run (e.g. they
    might need to pull down a large Docker image) but pytest's output capturing means
    that the user has no idea what's happening. This fixture allows us to "poke through"
    the output capturing and display a message to the user, but only if the task has
    already taken more than N seconds.
    """

    def show_warning(message):  # pragma: no cover
        capturemanager = request.config.pluginmanager.getplugin("capturemanager")
        # No need to display anything if output is not being captured
        if capturemanager.is_capturing():
            with capturemanager.global_and_fixture_disabled():
                print(f"\n => {message} ...")

    return lambda delay, message: ContextTimer(delay, show_warning, args=[message])


# Timer which starts/cancels itself when entering/exiting a context block
class ContextTimer(threading.Timer):
    def __enter__(self):
        self.start()

    def __exit__(self, *_):
        self.cancel()


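# A minimal usage sketch (hypothetical callback; see show_delayed_warning above for
# the real use): the callback fires only if the block takes longer than the interval,
# otherwise it is cancelled on exit.
#
#     with ContextTimer(3, print, args=["still working"]):
#         do_something_slow()

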
@pytest.fixture(scope="session")
def containers():
    yield Containers()


# Database fixtures

# These fixtures come in pairs.  For each database, there is a session-scoped fixture,
# which performs any setup, and there is a function-scoped fixture, which reuses the
# fixture returned by the session-scoped fixture.
#
# In most cases, we will want the function-scoped fixture, as this allows post-test
# teardown.  However, the generative tests require a session-scoped fixture.


@pytest.fixture(scope="session")
def in_memory_sqlite_database_with_session_scope():
    return InMemorySQLiteDatabase()


@pytest.fixture(scope="function")
def in_memory_sqlite_database(in_memory_sqlite_database_with_session_scope):
    database = in_memory_sqlite_database_with_session_scope
    yield database
    database.teardown()


@pytest.fixture(scope="session")
def mssql_database_with_session_scope(containers, show_delayed_warning):
    with show_delayed_warning(
        3, "Starting MSSQL Docker image (will download image on first run)"
    ):
        database = make_mssql_database(containers)
        wait_for_database(database)
    return database


@pytest.fixture(scope="function")
def mssql_database(mssql_database_with_session_scope):
    database = mssql_database_with_session_scope
    yield database
    database.teardown()


@pytest.fixture(scope="session")
def trino_database_with_session_scope(containers, show_delayed_warning):
    with show_delayed_warning(
        3, "Starting Trino Docker image (will download image on first run)"
    ):
        database = make_trino_database(containers)
        wait_for_database(database)
    return database


@pytest.fixture(scope="function")
def trino_database(trino_database_with_session_scope):
    database = trino_database_with_session_scope
    yield database
    database.teardown()


class QueryEngineFixture:
    def __init__(self, name, database, query_engine_class):
        self.name = name
        self.database = database
        self.query_engine_class = query_engine_class

    def setup(self, *items, metadata=None):
        return self.database.setup(*items, metadata=metadata)

    def teardown(self):
        return self.database.teardown()

    def populate(self, *args):
        return self.setup(make_orm_models(*args))

    def query_engine(self, dsn=False, **engine_kwargs):
        if dsn is False:
            dsn = self.database.host_url()
        return self.query_engine_class(dsn, **engine_kwargs)

    def get_results_tables(self, dataset, **engine_kwargs):
        if isinstance(dataset, ql.Dataset):
            dataset = dataset._compile()
        assert isinstance(dataset, qm.Dataset)
        query_engine = self.query_engine(**engine_kwargs)
        results_tables = query_engine.get_results_tables(dataset)
        # We don't explicitly order the results, and not all databases naturally
        # return rows in the same order
        return [
            [row._asdict() for row in sort_table(table)] for table in results_tables
        ]

    def extract(self, dataset, **engine_kwargs):
        return self.get_results_tables(dataset, **engine_kwargs)[0]

    def dump_dataset_sql(self, dataset, **engine_kwargs):
        assert isinstance(dataset, ql.Dataset)
        dataset_qm = dataset._compile()
        query_engine = self.query_engine(dsn=None, **engine_kwargs)
        return get_sql_strings(query_engine, dataset_qm)

    def sqlalchemy_engine(self):
        return self.query_engine().engine


def sort_table(table):
    # Python won't naturally compare None with other values, but we need to sort tables
    # containing None values, so we treat None as smaller than all other values
    return sorted(table, key=sort_key_with_nones)


def sort_key_with_nones(row):
    return [(v is not None, v) for v in row]


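# For example (hypothetical row values), tuples compare element-wise and
# False < True, so None sorts before everything else:
#
#     sort_key_with_nones((None, 3))  # -> [(False, None), (True, 3)]
#     sort_key_with_nones((7, None))  # -> [(True, 7), (False, None)]

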
QUERY_ENGINE_NAMES = ("in_memory", "sqlite", "mssql", "trino")


def engine_factory(request, engine_name, with_session_scope=False):
    if engine_name == "in_memory":
        return QueryEngineFixture(
            engine_name, InMemoryPythonDatabase(), InMemoryQueryEngine
        )

    if engine_name == "sqlite":
        database_fixture_name = "in_memory_sqlite_database"
        query_engine_class = SQLiteQueryEngine
    elif engine_name == "mssql":
        database_fixture_name = "mssql_database"
        query_engine_class = MSSQLQueryEngine
    elif engine_name == "trino":
        database_fixture_name = "trino_database"
        query_engine_class = TrinoQueryEngine
    else:
        assert False

    if with_session_scope:
        database_fixture_name = f"{database_fixture_name}_with_session_scope"

    # We dynamically request fixtures rather than making them arguments in the usual way
    # so that we only start the database containers we actually need for the test run
    database = request.getfixturevalue(database_fixture_name)

    return QueryEngineFixture(engine_name, database, query_engine_class)


@pytest.fixture(params=QUERY_ENGINE_NAMES)
def engine(request):
    return engine_factory(request, request.param)


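# A test written against the `engine` fixture above runs once per query engine.
# A hypothetical sketch (populate() just forwards its arguments to make_orm_models):
#
#     def test_something(engine):
#         engine.populate(...)
#         assert engine.extract(dataset) == [{...}]

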
@pytest.fixture
def mssql_engine(request):
    return engine_factory(request, "mssql")


@pytest.fixture
def trino_engine(request):
    return engine_factory(request, "trino")


@pytest.fixture
def in_memory_engine(request):
    return engine_factory(request, "in_memory")


@pytest.fixture
def sqlite_engine(request):
    return engine_factory(request, "sqlite")


@pytest.fixture(scope="session")
def ehrql_image(show_delayed_warning):
    project_dir = Path(ehrql.__file__).parents[1]
    # Note different name from production image to avoid confusion
    image = "ehrql-dev"
    # We're deliberately choosing to shell out to the docker client here rather than
    # use the docker-py library, to avoid possible differences in the build process
    # (docker-py doesn't seem to be particularly actively maintained)
    with show_delayed_warning(3, f"Building {image} Docker image"):
        subprocess.run(
            ["docker", "build", project_dir, "-t", image],
            check=True,
            env=dict(os.environ, DOCKER_BUILDKIT="1"),
        )
    return f"{image}:latest"


@pytest.fixture(autouse=True)
def random_should_not_be_used():
    """Assert that every test leaves the global random number generator unchanged.

    We want all of our use of randomness to be based on seeded random number
    generators, so tests should not depend on the global random number generator in
    any way. If this is failing, please find all uses of the random module and
    replace them with a Random instance, ideally created from a predictable seed.
    """
    prev_state = random.getstate()
    yield
    assert random.getstate() == prev_state, (
        "Global random number generator was used in test."
    )


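# For example, instead of calling `random.randint(0, 10)` in a test, create a seeded
# instance (hypothetical seed shown):
#
#     rng = random.Random(12345)
#     rng.randint(0, 10)

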
@pytest.fixture
def call_cli(capsys):
    """
    Wrapper around the CLI entrypoint to make it easier to call from tests
    """

    def call(*args, environ=None):
        # Convert any Path instances to strings
        args = [str(arg) if isinstance(arg, Path) else arg for arg in args]
        ehrql.__main__.main(args, environ=environ)
        return capsys.readouterr()

    # Allow reading captured output even when `call` raises an exception
    call.readouterr = capsys.readouterr

    return call


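# A usage sketch (hypothetical subcommand and path; any Path argument is converted
# to a string before being passed through):
#
#     output = call_cli("dump-dataset-sql", dataset_definition_path)
#     assert "SELECT" in output.out

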
@pytest.fixture
def call_cli_docker(containers, ehrql_image):
    """
    As above, but invoke the CLI via the Docker image
    """

    def call(*args, environ=None, workspace=None):
        args = [
            # Make any paths relative to the workspace directory so they still point to
            # the right place inside Docker. If you supply path arguments and no
            # workspace this will error, as it should. Likewise if you supply paths
            # outside of the workspace.
            str(arg.relative_to(workspace)) if isinstance(arg, Path) else str(arg)
            for arg in args
        ]
        if workspace is not None:
            # Because the files in these directories will need to be readable by
            # low-privilege, isolated processes we can't use the standard restrictive
            # permissions for temporary directories
            workspace.chmod(0o755)
            volumes = {workspace: {"bind": "/workspace", "mode": "rw"}}
        else:
            volumes = {}
        return containers.run_captured(
            ehrql_image,
            command=args,
            volumes=volumes,
            environment=environ or {},
        )

    return call