Switch to unified view

a b/tests/unit/test___main__.py
1
from pathlib import Path
2
3
import pytest
4
5
from ehrql.__main__ import (
6
    BACKEND_ALIASES,
7
    QUERY_ENGINE_ALIASES,
8
    ArgumentTypeError,
9
    DefinitionError,
10
    FileValidationError,
11
    backend_from_id,
12
    import_string,
13
    main,
14
    query_engine_from_id,
15
    valid_output_path,
16
)
17
from ehrql.backends.base import SQLBackend
18
from ehrql.query_engines.base import BaseQueryEngine
19
from ehrql.query_engines.base_sql import BaseSQLQueryEngine
20
from ehrql.query_engines.debug import DebugQueryEngine
21
from ehrql.query_engines.in_memory import InMemoryQueryEngine
22
from ehrql.utils.module_utils import get_sibling_subclasses
23
24
25
# We just need any existing file with a ".py" extension for testing purposes. Its
# contents are immaterial, so this test module itself will do.
27
DATASET_DEFINITON_PATH = __file__
28
29
30
def test_no_args(capsys):
    # Running ehrql with an empty argument list should exit and print the usage
    # text on stdout.
    with pytest.raises(SystemExit):
        main([])
    stdout = capsys.readouterr().out
    assert "usage: ehrql" in stdout
36
37
38
def test_generate_dataset(mocker):
    # `generate_dataset` is replaced with a mock, so this only checks that the
    # "generate-dataset" subcommand ends up calling it exactly once.
    generate_dataset_mock = mocker.patch("ehrql.__main__.generate_dataset")
    main(["generate-dataset", DATASET_DEFINITON_PATH])
    generate_dataset_mock.assert_called_once()
47
48
49
def test_generate_dataset_rejects_unknown_extension(capsys):
    # An `--output` path with an unrecognised extension should make the command
    # exit and name the offending format on stderr.
    with pytest.raises(SystemExit):
        main(
            [
                "generate-dataset",
                DATASET_DEFINITON_PATH,
                "--output",
                "out_file.badformat",
            ]
        )
    stderr = capsys.readouterr().err
    assert ".badformat' is not a supported format" in stderr
60
61
62
def test_generate_dataset_with_definition_error(capsys, mocker):
    # When `generate_dataset` raises a DefinitionError the command should exit,
    # showing the error message on stderr without a traceback.
    mocker.patch(
        "ehrql.__main__.generate_dataset",
        side_effect=DefinitionError("Not a good dataset definition"),
    )
    with pytest.raises(SystemExit):
        main(["generate-dataset", DATASET_DEFINITON_PATH])
    stderr = capsys.readouterr().err
    assert "Not a good dataset definition" in stderr
    assert "Traceback" not in stderr
75
76
77
def test_generate_dataset_with_validation_error(capsys, mocker):
    # When `generate_dataset` raises a FileValidationError the command should
    # exit, showing the error message on stderr without a traceback.
    mocker.patch(
        "ehrql.__main__.generate_dataset",
        side_effect=FileValidationError("Your file was bad"),
    )
    with pytest.raises(SystemExit):
        main(["generate-dataset", DATASET_DEFINITON_PATH])
    stderr = capsys.readouterr().err
    assert "Your file was bad" in stderr
    assert "Traceback" not in stderr
90
91
92
def test_dump_dataset_sql(mocker):
    # `dump_dataset_sql` is replaced with a mock, so this only checks that the
    # "dump-dataset-sql" subcommand, given a backend as a dotted path, ends up
    # calling it exactly once.
    dump_sql_mock = mocker.patch("ehrql.__main__.dump_dataset_sql")
    main(
        [
            "dump-dataset-sql",
            "--backend",
            "ehrql.backends.tpp.TPPBackend",
            DATASET_DEFINITON_PATH,
        ]
    )
    dump_sql_mock.assert_called_once()
103
104
105
@pytest.mark.parametrize("output_path", ["dummy_data_path", "dummy_data_path:arrow"])
def test_create_dummy_tables(mocker, output_path):
    # `create_dummy_tables` is replaced with a mock, so this only checks that the
    # "create-dummy-tables" subcommand ends up calling it exactly once, for an
    # output path both with and without a ":arrow" suffix.
    create_tables_mock = mocker.patch("ehrql.__main__.create_dummy_tables")
    main(["create-dummy-tables", DATASET_DEFINITON_PATH, output_path])
    create_tables_mock.assert_called_once()
116
117
118
def test_create_dummy_tables_rejects_unsupported_format(capsys):
    # An output path with an unrecognised ":<format>" suffix should make the
    # command exit and name the offending format on stderr.
    with pytest.raises(SystemExit):
        main(
            [
                "create-dummy-tables",
                DATASET_DEFINITON_PATH,
                "dummy_data_path:invalid",
            ]
        )
    stderr = capsys.readouterr().err
    assert "':invalid' is not a supported format" in stderr
128
129
130
def test_generate_measures(mocker):
    # `generate_measures` is replaced with a mock, so this only checks that the
    # "generate-measures" subcommand ends up calling it exactly once.
    generate_measures_mock = mocker.patch("ehrql.__main__.generate_measures")
    main(["generate-measures", DATASET_DEFINITON_PATH])
    generate_measures_mock.assert_called_once()
139
140
141
def test_existing_python_file_missing_file(capsys, tmp_path):
    # A definition path that points at a file which was never created should make
    # the command exit with a "does not exist" message on stderr.
    missing_file = tmp_path / "dataset.py"
    with pytest.raises(SystemExit):
        main(["generate-dataset", str(missing_file)])
    stderr = capsys.readouterr().err
    assert "dataset.py does not exist" in stderr
153
154
155
def test_existing_python_file_unpythonic_file(capsys, tmp_path):
    # A definition path that exists but lacks a ".py" extension should make the
    # command exit with a "is not a Python file" message on stderr.
    non_python_file = tmp_path / "dataset.cpp"
    non_python_file.touch()
    with pytest.raises(SystemExit):
        main(["generate-dataset", str(non_python_file)])
    stderr = capsys.readouterr().err
    assert "dataset.cpp is not a Python file" in stderr
168
169
170
def test_existing_directory_missing_directory(capsys, tmp_path):
    # Passing `--dummy-tables` a directory that does not exist should make the
    # command exit with a "does not exist" message on stderr.
    definition_file = tmp_path / "dataset.py"
    definition_file.touch()
    with pytest.raises(SystemExit):
        main(
            [
                "generate-dataset",
                str(definition_file),
                "--dummy-tables",
                "non-existent-directory",
            ]
        )
    stderr = capsys.readouterr().err
    assert "non-existent-directory does not exist" in stderr
183
184
185
def test_existing_directory_not_a_directory(capsys, tmp_path):
    # Passing `--dummy-tables` a path to a regular file should make the command
    # exit with a "is not a directory" message on stderr.
    definition_file = tmp_path / "dataset.py"
    definition_file.touch()
    plain_file = tmp_path / "not-a-directory.file"
    plain_file.touch()
    with pytest.raises(SystemExit):
        main(
            [
                "generate-dataset",
                str(definition_file),
                "--dummy-tables",
                str(plain_file),
            ]
        )
    stderr = capsys.readouterr().err
    assert "not-a-directory.file is not a directory" in stderr
200
201
202
def test_existing_file_missing_file(capsys, tmp_path):
    # Passing `--dummy-data-file` a file that does not exist should make the
    # command exit with a "does not exist" message on stderr.
    definition_file = tmp_path / "dataset.py"
    definition_file.touch()
    with pytest.raises(SystemExit):
        main(
            [
                "generate-dataset",
                str(definition_file),
                "--dummy-data-file",
                "non-existent-file",
            ]
        )
    stderr = capsys.readouterr().err
    assert "non-existent-file does not exist" in stderr
215
216
217
def test_existing_file_not_a_file(capsys, tmp_path):
    # Passing `--dummy-data-file` a path to a directory should make the command
    # exit with a "is not a file" message on stderr.
    definition_file = tmp_path / "dataset.py"
    definition_file.touch()
    some_directory = tmp_path / "not-a-file"
    some_directory.mkdir()
    with pytest.raises(SystemExit):
        main(
            [
                "generate-dataset",
                str(definition_file),
                "--dummy-data-file",
                str(some_directory),
            ]
        )
    stderr = capsys.readouterr().err
    assert "not-a-file is not a file" in stderr
232
233
234
def test_import_string():
    # A full dotted path should resolve to the very object it names
    resolved = import_string("ehrql.__main__.main")
    assert resolved is main
236
237
238
def test_import_string_not_a_dotted_path():
    # A bare module name with no dots is rejected
    expected_error = "must be a full dotted path"
    with pytest.raises(ArgumentTypeError, match=expected_error):
        import_string("urllib")
241
242
243
def test_import_string_no_such_module():
    # A dotted path whose module part cannot be imported is rejected
    expected_error = "could not import module"
    with pytest.raises(ArgumentTypeError, match=expected_error):
        import_string("urllib.this_is_not_a_module.Foo")
246
247
248
def test_import_string_no_such_attribute():
    # A dotted path naming an importable module but a missing attribute is
    # rejected. `match` is treated as a regular expression, so the dot in the
    # module name is escaped to make it match a literal "." rather than any
    # character.
    with pytest.raises(ArgumentTypeError, match=r"'urllib\.parse' has no attribute"):
        import_string("urllib.parse.ThisIsNotAClass")
251
252
253
class DummyQueryEngine:
    # Minimal stand-in class that `test_query_engine_from_id` refers to by dotted
    # path. NOTE(review): presumably defining `get_results_tables` is what lets it
    # pass as a query engine — confirm against `query_engine_from_id`.
    def get_results_tables(self):
        raise NotImplementedError
256
257
258
def test_query_engine_from_id():
    # The dotted path of a class defined in this module should resolve to that
    # same class
    dotted_path = ".".join(
        [DummyQueryEngine.__module__, DummyQueryEngine.__name__]
    )
    resolved = query_engine_from_id(dotted_path)
    assert resolved is DummyQueryEngine
261
262
263
def test_query_engine_from_id_missing_alias():
    # An undotted name that is not a known alias is rejected
    expected_error = "must be one of"
    with pytest.raises(ArgumentTypeError, match=expected_error):
        query_engine_from_id("missing")
266
267
268
def test_query_engine_from_id_wrong_type():
    # An importable object which isn't a query engine is rejected
    expected_error = "is not a valid query engine"
    with pytest.raises(ArgumentTypeError, match=expected_error):
        query_engine_from_id("pathlib.Path")
271
272
273
class DummyBackend:
    # Minimal stand-in class that `test_backend_from_id` refers to by dotted path.
    # NOTE(review): presumably defining `get_table_expression` is what lets it
    # pass as a backend — confirm against `backend_from_id`.
    def get_table_expression(self):
        raise NotImplementedError
276
277
278
def test_backend_from_id():
    # The dotted path of a class defined in this module should resolve to that
    # same class
    dotted_path = ".".join([DummyBackend.__module__, DummyBackend.__name__])
    resolved = backend_from_id(dotted_path)
    assert resolved is DummyBackend
281
282
283
def test_backend_from_id_missing_alias():
    # An undotted name that is not a known alias is rejected
    expected_error = "must be one of"
    with pytest.raises(ArgumentTypeError, match=expected_error):
        backend_from_id("missing")
286
287
288
def test_backend_from_id_wrong_type():
    # An importable object which isn't a backend is rejected
    expected_error = "is not a valid backend"
    with pytest.raises(ArgumentTypeError, match=expected_error):
        backend_from_id("pathlib.Path")
291
292
293
@pytest.mark.parametrize("alias", ["expectations", "test"])
def test_backend_from_id_special_case_aliases(alias):
    # These two aliases are expected to resolve to `None` rather than to a class
    resolved = backend_from_id(alias)
    assert resolved is None
296
297
298
def test_all_query_engine_aliases_are_importable():
    # Every registered alias should resolve to something truthy
    for alias in QUERY_ENGINE_ALIASES:
        resolved = query_engine_from_id(alias)
        assert resolved
301
302
303
def test_all_backend_aliases_are_importable():
    # Every registered alias should resolve to something truthy
    for alias in BACKEND_ALIASES:
        resolved = backend_from_id(alias)
        assert resolved
306
307
308
def test_all_query_engines_have_an_alias():
    # These classes are skipped: they are not required to have an alias
    exempt = (BaseSQLQueryEngine, InMemoryQueryEngine, DebugQueryEngine)
    aliased_paths = QUERY_ENGINE_ALIASES.values()
    for engine_cls in get_sibling_subclasses(BaseQueryEngine):
        if engine_cls not in exempt:
            name = f"{engine_cls.__module__}.{engine_cls.__name__}"
            assert name in aliased_paths, f"No alias defined for '{name}'"
318
319
320
def test_all_backends_have_an_alias():
    # Each class returned for SQLBackend must appear, by dotted path, among the
    # alias targets
    aliased_paths = BACKEND_ALIASES.values()
    for backend_cls in get_sibling_subclasses(SQLBackend):
        name = f"{backend_cls.__module__}.{backend_cls.__name__}"
        assert name in aliased_paths, f"No alias defined for '{name}'"
324
325
326
def test_all_backend_aliases_match_display_names():
    # An alias must equal the lower-cased `display_name` of the backend it
    # resolves to
    for alias in BACKEND_ALIASES:
        backend = backend_from_id(alias)
        assert backend.display_name.lower() == alias
329
330
331
@pytest.mark.parametrize(
    "path",
    [
        "some/path/file.csv",
        "some/path/dir:csv",
        "some/path/dir/:csv",
        "some/path/dir.foo:csv",
    ],
)
def test_valid_output_path(path):
    # Accepted output paths, whether "file.<ext>" or "dir:<format>", should come
    # back equal to a `Path` built from the same string
    expected = Path(path)
    result = valid_output_path(path)
    assert result == expected
342
343
344
@pytest.mark.parametrize(
    "path, message",
    [
        ("no/extension", "No file format supplied"),
        ("some/path.badfile", "'.badfile' is not a supported format"),
        ("some/path:baddir", "':baddir' is not a supported format"),
        ("some/path/:baddir", "':baddir' is not a supported format"),
    ],
)
def test_valid_output_path_errors(path, message):
    # Each rejected path should raise an ArgumentTypeError whose text contains the
    # expected message. The messages contain "." which `pytest.raises(match=...)`
    # would interpret as a regex wildcard, so compare them as literal substrings
    # instead.
    with pytest.raises(ArgumentTypeError) as excinfo:
        valid_output_path(path)
    assert message in str(excinfo.value)