Diff of /tests/unit/test_tables.py [000000] .. [e988c2]

Switch to side-by-side view

--- a
+++ b/tests/unit/test_tables.py
@@ -0,0 +1,82 @@
+import re
+
+import pytest
+
+import ehrql.tables
+from ehrql.query_language import get_tables_from_namespace
+from ehrql.tables import Constraint, tpp
+from ehrql.utils.module_utils import get_submodules
+
+
+@pytest.mark.parametrize("module", list(get_submodules(ehrql.tables)))
+def test___all__(module):
+    table_names = {name for name, _ in get_tables_from_namespace(module)}
+    if not table_names:
+        pytest.skip(f"{module.__name__} has no tables")
+    assert module.__all__ == sorted(module.__all__)
+    assert table_names == set(module.__all__)
+
+
+valid_examples_for_regex_constraints = [
+    (tpp.addresses, "msoa_code", "E02012345"),
+    (tpp.ec, "sus_hrg_code", "AA00A"),
+    (tpp.opa, "hrg_code", "AA00A"),
+    (tpp.practice_registrations, "practice_stp", "E54000012"),
+    (tpp.wl_clockstops, "activity_treatment_function_code", "AB1"),
+    (tpp.wl_openpathways, "activity_treatment_function_code", "AB1"),
+    (tpp.wl_openpathways, "source_of_referral", "A1"),
+]
+
+invalid_examples_for_regex_constraints = [
+    (tpp.addresses, "msoa_code", "X02012345"),
+    (tpp.ec, "sus_hrg_code", "AA000A"),
+    (tpp.opa, "hrg_code", "AA000A"),
+    (tpp.practice_registrations, "practice_stp", "X54000012"),
+    (tpp.wl_clockstops, "activity_treatment_function_code", "AB10"),
+    (tpp.wl_openpathways, "activity_treatment_function_code", "AB10"),
+    (tpp.wl_openpathways, "source_of_referral", "A10"),
+]
+
+
+@pytest.mark.parametrize(
+    "table,column_name,example", valid_examples_for_regex_constraints
+)
+def test_regex_constraints_match_valid_examples(table, column_name, example):
+    regex = get_regex_constraint(table, column_name)
+    assert bool(re.fullmatch(regex, example))
+
+
+@pytest.mark.parametrize(
+    "table,column_name,example", invalid_examples_for_regex_constraints
+)
+def test_regex_constraints_do_not_match_invalid_examples(table, column_name, example):
+    regex = get_regex_constraint(table, column_name)
+    assert not bool(re.fullmatch(regex, example))
+
+
+@pytest.mark.parametrize(
+    "examples",
+    [valid_examples_for_regex_constraints, invalid_examples_for_regex_constraints],
+)
+def test_all_regex_constraints_are_tested(examples):
+    tables = {
+        table
+        for module in get_submodules(ehrql.tables)
+        for _, table in get_tables_from_namespace(module)
+    }
+    tables_column_names = {
+        (table, column_name)
+        for table in tables
+        for column_name in table._qm_node.schema.column_names
+        if get_regex_constraint(table, column_name)
+    }
+
+    assert tables_column_names == {
+        (table, column_name) for table, column_name, _ in examples
+    }
+
+
+def get_regex_constraint(table, column_name):
+    schema = table._qm_node.schema
+    constraint = schema.get_column_constraint_by_type(column_name, Constraint.Regex)
+    return getattr(constraint, "regex", None)