# [e988c2]: tests/unit/test_main.py
# 88 lines (68 with data), 2.3 kB
import dataclasses
from ehrql.main import get_query_engine, open_output_file
@dataclasses.dataclass()
class DummyQueryEngine:
    """Stand-in query engine that simply records its constructor arguments.

    The tests use it to check which class ``get_query_engine`` picked and
    what it passed in.
    """

    dsn: str  # connection string forwarded by get_query_engine
    backend: object  # backend instance; the tests expect None when no backend class is given
    config: dict  # the tests expect this to equal the ``environ`` given to get_query_engine


@dataclasses.dataclass()
class DefaultQueryEngine:
    """Stand-in for the fallback query engine class.

    It is deliberately a separate class from ``DummyQueryEngine`` (not a
    subclass), so ``isinstance`` checks in the tests tell the two apart.
    """

    dsn: str  # connection string forwarded by get_query_engine
    backend: object  # backend instance, or None when no backend class is given
    config: dict  # configuration mapping passed through by get_query_engine


@dataclasses.dataclass()
class DummyBackend:
    """Stand-in backend that nominates ``DummyQueryEngine`` as its engine."""

    config: dict  # the tests expect this to equal the ``environ`` given to get_query_engine

    # No type annotation, so dataclasses treats this as a plain class
    # attribute, not an __init__ field. The tests expect get_query_engine to
    # fall back to it when no explicit query engine class is supplied.
    query_engine_class = DummyQueryEngine


def test_get_query_engine_defaults():
    """With no backend and no explicit query engine class, the default class is used."""
    query_engine = get_query_engine(
        dsn=None,
        backend_class=None,
        query_engine_class=None,
        environ={},
        default_query_engine_class=DefaultQueryEngine,
    )
    assert isinstance(query_engine, DefaultQueryEngine)
    # Mirror the checks made for the explicit-class case: no backend is
    # constructed and the (empty) environ is passed through as config.
    assert query_engine.backend is None
    assert query_engine.config == {}


def test_get_query_engine_with_query_engine():
    """An explicitly supplied query engine class is used, with no backend."""
    engine = get_query_engine(
        query_engine_class=DummyQueryEngine,
        default_query_engine_class=None,
        backend_class=None,
        environ={},
        dsn=None,
    )

    assert isinstance(engine, DummyQueryEngine)
    assert engine.backend is None
    assert engine.config == {}


def test_get_query_engine_with_backend():
    """Without an explicit class, the backend's query engine class is used.

    The environ must reach both the backend and the query engine as config.
    """
    environ = {"foo": "bar"}
    engine = get_query_engine(
        dsn=None,
        backend_class=DummyBackend,
        query_engine_class=None,
        environ=environ,
        default_query_engine_class=None,
    )

    assert isinstance(engine, DummyQueryEngine)
    backend = engine.backend
    assert isinstance(backend, DummyBackend)
    assert engine.config == {"foo": "bar"}
    assert backend.config == {"foo": "bar"}


def test_get_query_engine_with_backend_and_query_engine():
    """An explicit query engine class overrides the backend's choice.

    The backend is still constructed and attached to the engine.
    """
    query_engine = get_query_engine(
        dsn=None,
        backend_class=DummyBackend,
        query_engine_class=DefaultQueryEngine,
        environ={},
        default_query_engine_class=None,
    )
    assert isinstance(query_engine, DefaultQueryEngine)
    assert isinstance(query_engine.backend, DummyBackend)
    assert query_engine.config == {}
    # As in the backend-only case, the backend should also receive the environ.
    assert query_engine.backend.config == {}


def test_open_output_file(tmp_path):
    """Writing to a path whose parent directory does not exist yet succeeds."""
    target = tmp_path.joinpath("testdir", "file.txt")
    with open_output_file(target) as handle:
        handle.write("hello")
    contents = target.read_text()
    assert contents == "hello"


def test_open_output_file_with_stdout(capsys):
    """Passing None instead of a path sends the output to stdout."""
    with open_output_file(None) as stream:
        stream.write("hello")
    captured = capsys.readouterr()
    assert captured.out == "hello"