# tests/unit/test_main.py
import dataclasses
from ehrql.main import get_query_engine, open_output_file


@dataclasses.dataclass
class DummyQueryEngine:
    """Stand-in query engine that simply records how it was constructed.

    Tests inspect ``dsn``, ``backend`` and ``config`` to check what
    ``get_query_engine`` passed in. The annotations are informational only;
    the tests in this module construct it with ``dsn=None``.
    """

    dsn: str
    backend: object
    config: dict


@dataclasses.dataclass
class DefaultQueryEngine:
    """Second recording query engine, distinguishable from DummyQueryEngine.

    Passed as ``default_query_engine_class`` (and as an explicit
    ``query_engine_class`` override). Tests can then tell from the instance
    type which class ``get_query_engine`` chose.
    """

    dsn: str
    backend: object
    config: dict


@dataclasses.dataclass
class DummyBackend:
    """Stand-in backend that records its config and names a query engine class.

    ``query_engine_class`` carries no annotation. ``dataclasses`` therefore
    treats it as a plain class attribute, not as a constructor field, so the
    only constructor argument is ``config``.
    """

    config: dict

    # The tests expect get_query_engine to fall back to this class when no
    # query engine class is supplied explicitly.
    query_engine_class = DummyQueryEngine


def test_get_query_engine_defaults():
    """With no backend and no explicit engine class, the default class is used."""
    engine = get_query_engine(
        dsn=None,
        backend_class=None,
        query_engine_class=None,
        environ={},
        default_query_engine_class=DefaultQueryEngine,
    )

    assert isinstance(engine, DefaultQueryEngine)


def test_get_query_engine_with_query_engine():
    """An explicit query engine class is used even without a backend."""
    engine = get_query_engine(
        dsn=None,
        backend_class=None,
        query_engine_class=DummyQueryEngine,
        environ={},
        default_query_engine_class=None,
    )

    assert isinstance(engine, DummyQueryEngine)
    # No backend class was given, so none should be attached.
    assert engine.backend is None
    assert engine.config == {}


def test_get_query_engine_with_backend():
    """A backend supplies the engine class, and both receive the environ as config."""
    environ = {"foo": "bar"}

    engine = get_query_engine(
        dsn=None,
        backend_class=DummyBackend,
        query_engine_class=None,
        environ=environ,
        default_query_engine_class=None,
    )

    # DummyBackend.query_engine_class decides the engine type.
    assert isinstance(engine, DummyQueryEngine)
    assert isinstance(engine.backend, DummyBackend)
    # Compare against fresh literals rather than the `environ` object itself.
    assert engine.config == {"foo": "bar"}
    assert engine.backend.config == {"foo": "bar"}


def test_get_query_engine_with_backend_and_query_engine():
    """An explicit engine class overrides the backend's own query_engine_class."""
    engine = get_query_engine(
        dsn=None,
        backend_class=DummyBackend,
        query_engine_class=DefaultQueryEngine,
        environ={},
        default_query_engine_class=None,
    )

    # DummyBackend would pick DummyQueryEngine; the explicit class must win.
    assert isinstance(engine, DefaultQueryEngine)
    assert isinstance(engine.backend, DummyBackend)
    assert engine.config == {}


def test_open_output_file(tmp_path):
    """Writing to a path whose parent directory does not exist yet succeeds."""
    # "testdir" does not exist inside the fresh tmp_path.
    output_path = tmp_path / "testdir" / "file.txt"

    with open_output_file(output_path) as out:
        out.write("hello")

    assert output_path.read_text() == "hello"


def test_open_output_file_with_stdout(capsys):
    """Passing None makes open_output_file write to stdout."""
    with open_output_file(None) as out:
        out.write("hello")

    captured = capsys.readouterr()
    assert captured.out == "hello"