a b/tests/integration/backends/conftest.py
1
import pytest
2
import sqlalchemy
3
4
from ehrql.backends.emis import EMISBackend
5
from ehrql.backends.tpp import TPPBackend
6
7
8
def _get_select_all_query(request, backend):
9
    try:
10
        ql_table = request.function._table
11
    except AttributeError:  # pragma: no cover
12
        raise RuntimeError(
13
            f"Function '{request.function.__name__}' needs the "
14
            f"`@register_test_for(table)` decorator applied"
15
        )
16
17
    qm_table = ql_table._qm_node
18
    sql_table = backend.get_table_expression(qm_table.name, qm_table.schema)
19
    columns = [
20
        # Using `type_coerce(..., None)` like this strips the type information from the
21
        # SQLAlchemy column meaning we get back the type that the column actually is in
22
        # database, not the type we've told SQLAlchemy it is.
23
        sqlalchemy.type_coerce(column, None).label(column.key)
24
        for column in sql_table.columns
25
    ]
26
    return sqlalchemy.select(*columns)
27
28
29
def _select_all_fn(select_all_query, database):
30
    def _select_all(*input_data):
31
        database.setup(*input_data)
32
        with database.engine().connect() as connection:
33
            results = connection.execute(select_all_query)
34
            return sorted(
35
                [row._asdict() for row in results], key=lambda x: x["patient_id"]
36
            )
37
38
    return _select_all
39
40
41
@pytest.fixture
def select_all_emis(request, trino_database):
    """Provide a select-all helper for the EMIS backend.

    Builds the select-all query for the table attached to the requesting
    test using an ``EMISBackend`` instance, and binds it to the
    ``trino_database`` fixture.
    """
    select_all_query = _get_select_all_query(request, EMISBackend())
    return _select_all_fn(select_all_query, trino_database)
45
46
47
@pytest.fixture
def select_all_tpp(request, mssql_database):
    """Provide a select-all helper for the TPP backend.

    Builds the select-all query for the table attached to the requesting
    test using a ``TPPBackend`` configured with a ``TEMP_DATABASE_NAME`` of
    ``"temp_tables"``, and binds it to the ``mssql_database`` fixture.
    """
    backend = TPPBackend(config={"TEMP_DATABASE_NAME": "temp_tables"})
    select_all_query = _get_select_all_query(request, backend)
    return _select_all_fn(select_all_query, mssql_database)