--- a +++ b/tests/integration/backends/helpers.py @@ -0,0 +1,88 @@ +import sqlalchemy + +from ehrql.query_language import get_tables_from_namespace +from ehrql.tables import emis, tpp +from ehrql.tables.raw import emis as emis_raw +from ehrql.tables.raw import tpp as tpp_raw + + +REGISTERED_TABLES = set() + + +# This slightly odd way of supplying the table object to the test function makes the +# tests introspectable in such a way that we can confirm that every table in the module +# is covered by a test +def register_test_for(table): + def annotate_test_function(fn): + REGISTERED_TABLES.add(table) + fn._table = table + return fn + + return annotate_test_function + + +def assert_types_correct(columns_with_types, database): + mismatched = [ + f"{table}.{column} expects {column_type!r} but got {column_args!r}" + for table, column, column_type, column_args in columns_with_types + if not types_compatible(database, column_type, column_args) + ] + nl = "\n" + assert not mismatched, ( + f"Mismatch between columns returned by backend queries" + f" queries and those expected:\n{nl.join(mismatched)}\n\n" + ) + + +def types_compatible(database, column_type, column_args): + """ + Is this given SQLAlchemy type instance compatible with the supplied dictionary of + column arguments? + """ + # It seems we use this sometimes for the patient ID column where we don't care what + # type it is + if isinstance(column_type, sqlalchemy.sql.sqltypes.NullType): + return True + elif isinstance(column_type, sqlalchemy.Boolean): + # MSSQL doesn't have a boolean type so we expect an int here + if database.protocol == "mssql": + return column_args["type"] == "int" + # Current no non-mssql backends (i.e. emis) have boolean column types + return column_args["type"] == "boolean" # pragma: no cover + elif isinstance(column_type, sqlalchemy.Integer): + return column_args["type"] in ("int", "integer", "bigint") + elif isinstance(column_type, sqlalchemy.Float): + return column_args["type"] == "real" + elif isinstance(column_type, sqlalchemy.Date): + return column_args["type"] == "date" + elif isinstance(column_type, sqlalchemy.String): + return ( + column_args["type"].startswith("varchar") + and column_args.get("collation") == column_type.collation + ) + else: + assert False, f"Unhandled type: {column_type}" + + +def get_all_backend_columns(backend): + for _, table in get_all_tables(backend): + qm_table = table._qm_node + table_expr = backend.get_table_expression(qm_table.name, qm_table.schema) + yield qm_table.name, table_expr.columns + + +def get_all_tables(backend): + table_modules_by_backend = {"emis": [emis, emis_raw], "tpp": [tpp, tpp_raw]} + modules = table_modules_by_backend[backend.display_name.lower()] + for module in modules: + for name, table in get_tables_from_namespace(module): + yield f"{module.__name__}.{name}", table + + +def assert_tests_exhaustive(backend): + missing = [ + name + for name, table in get_all_tables(backend) + if table not in REGISTERED_TABLES + ] + assert not missing, f"No tests for tables: {', '.join(missing)}"