[e988c2]: / tests / unit / file_formats / test_main.py

Download this file

117 lines (92 with data), 3.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import re
from pathlib import Path

import pytest

from ehrql.file_formats.arrow import ArrowRowsReader
from ehrql.file_formats.csv import CSVGZRowsReader, CSVRowsReader
from ehrql.file_formats.main import (
    FileValidationError,
    get_extension_from_directory,
    get_file_extension,
    get_table_filename,
    read_rows,
    split_directory_and_extension,
)
@pytest.mark.parametrize(
    "filename,extension",
    [
        (Path("a/b.c/file.txt"), ".txt"),
        (Path("a/b.c/file.txt.foo"), ".foo"),
        (Path("a/b.c/file.txt.gz"), ".txt.gz"),
        (Path("a/b.c/file"), ""),
    ],
)
def test_get_file_extension(filename, extension):
    """Check which trailing suffix(es) of the filename are reported as its extension.

    Per the cases above: a plain suffix yields itself, ``.gz`` keeps the
    suffix in front of it, and the dotted parent directory ``b.c`` never
    contributes anything (a suffix-less file yields ``""``).
    """
    detected = get_file_extension(filename)
    assert detected == extension
def test_read_rows_rejects_unsupported_file_types():
    """``read_rows`` raises ``FileValidationError`` for an extension it does not support."""
    # ``match`` is applied as a regex search, so the dot must be escaped;
    # unescaped it would accept any character in that position.
    with pytest.raises(FileValidationError, match=r"Unsupported file type: \.xyz"):
        read_rows(Path("some_file.xyz"), {})
def test_read_rows_raises_error_for_missing_files():
    """``read_rows`` raises ``FileValidationError`` naming the nonexistent file's path."""
    missing_file = Path(__file__).parent / "no_such_file.csv"
    # The path is embedded in a regex pattern, so escape it: otherwise dots
    # match any character and, on Windows, backslash separators are parsed as
    # regex escapes (some invalid, e.g. ``\u`` in ``\unit``, others wrong).
    expected = re.escape(f"Missing file: {missing_file}")
    with pytest.raises(FileValidationError, match=expected):
        read_rows(missing_file, {})
@pytest.mark.parametrize(
    "reader_class",
    [
        CSVRowsReader,
        CSVGZRowsReader,
        ArrowRowsReader,
    ],
)
def test_rows_reader_constructor_rejects_non_path(reader_class):
    """Each rows-reader class rejects a plain string path at construction time."""
    # ``match`` is a regex search; escape the dot so it is matched literally.
    with pytest.raises(
        FileValidationError, match=r"must be a pathlib\.Path instance"
    ):
        reader_class("some/string/path", {})
def test_get_extension_from_directory(tmp_path):
    """Directory holding ``.csv.gz`` files (plus a stray README) reports ``.csv.gz``."""
    directory = tmp_path / "some_dir"
    directory.mkdir()
    for name in ("file_a.csv.gz", "file_b.csv.gz", "README.txt"):
        (directory / name).touch()

    detected = get_extension_from_directory(directory)
    assert detected == ".csv.gz"
def test_get_extension_from_directory_missing(tmp_path):
    """A path that does not exist is reported as a missing directory."""
    nonexistent = tmp_path / "no_such_dir"
    with pytest.raises(FileValidationError, match="Missing directory"):
        get_extension_from_directory(nonexistent)
def test_get_extension_from_directory_with_wrong_type(tmp_path):
    """A regular file given where a directory is expected is rejected."""
    plain_file = tmp_path / "not_a_dir"
    plain_file.touch()
    with pytest.raises(FileValidationError, match="Not a directory"):
        get_extension_from_directory(plain_file)
def test_get_extension_from_directory_without_supported_extensions(tmp_path):
    """A directory containing only ``.jpg``/``.docx`` files is an error."""
    directory = tmp_path / "some_dir"
    directory.mkdir()
    for name in ("file_a.jpg", "file_b.docx"):
        (directory / name).touch()

    with pytest.raises(FileValidationError, match="No supported file formats found"):
        get_extension_from_directory(directory)
def test_get_extension_from_directory_with_ambiguous_extensions(tmp_path):
    """Mixing ``.csv`` and ``.arrow`` files in one directory is an error."""
    directory = tmp_path / "some_dir"
    directory.mkdir()
    for name in ("file_a.csv", "file_b.arrow"):
        (directory / name).touch()

    # Files are created csv-first, yet the message is expected to list
    # ``.arrow`` before ``.csv`` (alphabetical rather than creation order).
    expected_message = r"Found multiple file formats \(\.arrow, \.csv\)"
    with pytest.raises(FileValidationError, match=expected_message):
        get_extension_from_directory(directory)
def test_get_table_filename_escapes_problematic_characters():
    """Slashes and spaces in a table name are percent-encoded in its filename."""
    result = get_table_filename(Path("parent"), "bad/ table /name/", ".csv")
    expected = Path("parent/bad%2F%20table%20%2Fname%2F.csv")
    assert result == expected
@pytest.mark.parametrize(
    "filename,expected_dir,expected_ext",
    [
        ("some/dir:csv", "some/dir", ".csv"),
        ("some/dir", "some/dir", ""),
        ("some/dir/:csv", "some/dir", ".csv"),
    ],
)
def test_split_directory_and_extension(filename, expected_dir, expected_ext):
    """``dir:ext`` splits into ``(dir, ".ext")``; a bare path has no extension.

    A trailing slash before the colon is dropped from the directory part.
    """
    directory, extension = split_directory_and_extension(Path(filename))
    assert (directory, extension) == (Path(expected_dir), expected_ext)