Diff of /tests/lib/databases.py [000000] .. [e988c2]

Switch to side-by-side view

--- a
+++ b/tests/lib/databases.py
@@ -0,0 +1,331 @@
+import secrets
+import time
+from pathlib import Path
+
+import sqlalchemy
+import sqlalchemy.exc
+from packaging.version import parse as version_parse
+from requests.exceptions import ConnectionError  # noqa A004
+from sqlalchemy.dialects import registry
+from sqlalchemy.orm import sessionmaker
+from trino.exceptions import TrinoQueryError
+
+from ehrql.query_engines.in_memory_database import InMemoryDatabase
+from ehrql.utils.itertools_utils import iter_flatten
+from tests.lib.orm_utils import SYNTHETIC_PRIMARY_KEY, table_has_one_row_per_patient
+
+
+MSSQL_SETUP_DIR = Path(__file__).parents[1].absolute() / "support/mssql"
+TRINO_SETUP_DIR = Path(__file__).parents[1].absolute() / "support/trino"
+
+
+# Register our modified SQLAlchemy dialects
+registry.register(
+    "sqlite.pysqlite.opensafely",
+    "ehrql.query_engines.sqlite_dialect",
+    "SQLiteDialect",
+)
+
+registry.register(
+    "trino.opensafely", "ehrql.query_engines.trino_dialect", "TrinoDialect"
+)
+
+
+class DbDetails:
+    def __init__(
+        self,
+        protocol,
+        driver,
+        host_from_container,
+        port_from_container,
+        host_from_host,
+        port_from_host,
+        username="",
+        password="",
+        db_name="",
+        query=None,
+        temp_db=None,
+        engine_kwargs=None,
+    ):
+        self.protocol = protocol
+        self.driver = driver
+        self.host_from_container = host_from_container
+        self.port_from_container = port_from_container
+        self.host_from_host = host_from_host
+        self.port_from_host = port_from_host
+        self.password = password
+        self.username = username
+        self.db_name = db_name
+        self.query = query
+        self.temp_db = temp_db
+        self.engine_kwargs = engine_kwargs or {}
+        self.metadata = None
+
+    def container_url(self):
+        return self._url(self.host_from_container, self.port_from_container)
+
+    def host_url(self):
+        return self._url(self.host_from_host, self.port_from_host)
+
+    def engine(self, dialect=None, **kwargs):
+        url = self._url(
+            self.host_from_host, self.port_from_host, include_driver=bool(self.driver)
+        )
+        engine_url = sqlalchemy.engine.make_url(url)
+        engine_kwargs = self.engine_kwargs | kwargs
+        engine = sqlalchemy.create_engine(engine_url, **engine_kwargs)
+        return engine
+
+    def _url(self, host, port, include_driver=False):
+        assert self.username
+        if self.username and self.password:
+            auth = f"{self.username}:{self.password}@"
+        else:
+            auth = f"{self.username}@"
+        if include_driver:
+            protocol = f"{self.protocol}+{self.driver}"
+        else:
+            protocol = self.protocol
+        url = f"{protocol}://{auth}{host}:{port}/{self.db_name}"
+        return url
+
+    def setup(self, *input_data, metadata=None):
+        """
+        Accepts SQLAlchemy ORM objects (which may be arbitrarily nested within lists and
+        tuples), creates the necessary tables and inserts them into the database
+        """
+        input_data = list(iter_flatten(input_data))
+        engine = self.engine()
+        Session = sessionmaker()
+        Session.configure(bind=engine)
+        session = Session()
+
+        if metadata:
+            pass
+        elif input_data:
+            metadata = input_data[0].metadata
+        else:
+            assert False, "No source of metadata"
+        assert all(item.metadata is metadata for item in input_data)
+
+        self.metadata = metadata
+        metadata.create_all(engine)
+        session.bulk_save_objects(
+            input_data,
+            return_defaults=False,
+            update_changed_only=False,
+            preserve_order=False,
+        )
+        session.commit()
+
+    def teardown(self):
+        if self.metadata is not None:
+            self.metadata.drop_all(self.engine())
+
+
+def wait_for_database(database, timeout=20):
+    engine = database.engine()
+    start = time.time()
+    limit = start + timeout
+    while True:
+        try:
+            with engine.connect() as connection:
+                connection.execute(sqlalchemy.text("SELECT 'hello'"))
+            break
+        except (
+            sqlalchemy.exc.OperationalError,
+            ConnectionRefusedError,
+            ConnectionResetError,
+            BrokenPipeError,
+            ConnectionError,
+            TrinoQueryError,
+            sqlalchemy.exc.DBAPIError,
+        ) as e:  # pragma: no cover
+            if time.time() >= limit:
+                raise Exception(
+                    f"Failed to connect to database after {timeout} seconds: "
+                    f"{engine.url}"
+                ) from e
+            time.sleep(1)
+
+
+def make_mssql_database(containers):
+    password = "Your_password123!"
+
+    container_name = "ehrql-mssql"
+    mssql_port = 1433
+
+    if not containers.is_running(container_name):  # pragma: no cover
+        run_mssql(container_name, containers, password, mssql_port)
+
+    container_ip = containers.get_container_ip(container_name)
+    host_mssql_port = containers.get_mapped_port_for_host(container_name, mssql_port)
+
+    return DbDetails(
+        protocol="mssql",
+        driver="pymssql",
+        host_from_container=container_ip,
+        port_from_container=mssql_port,
+        host_from_host="localhost",
+        port_from_host=host_mssql_port,
+        username="sa",
+        password=password,
+        db_name="test",
+    )
+
+
+def run_mssql(container_name, containers, password, mssql_port):  # pragma: no cover
+    containers.run_bg(
+        name=container_name,
+        # This is *not* the version that TPP run for us in production which, as of
+        # 2024-09-24, is SQL Server 2016 (13.0.5893.48). That version is not available
+        # as a Docker image, so we run the oldest supported version instead. Both the
+        # production server and our test server set the "compatibility level" to the
+        # same value so the same feature set should be supported.
+        image="mcr.microsoft.com/mssql/server:2019-CU28-ubuntu-20.04",
+        volumes={
+            MSSQL_SETUP_DIR: {"bind": "/mssql", "mode": "ro"},
+        },
+        # Choose an arbitrary free port to publish the MSSQL port on
+        ports={mssql_port: None},
+        environment={
+            "MSSQL_SA_PASSWORD": password,
+            "ACCEPT_EULA": "Y",
+            "MSSQL_TCP_PORT": str(mssql_port),
+            # Make all string comparisons case-sensitive across all databases
+            "MSSQL_COLLATION": "SQL_Latin1_General_CP1_CS_AS",
+        },
+        user="root",
+        entrypoint="/mssql/entrypoint.sh",
+        command="/opt/mssql/bin/sqlservr",
+    )
+
+
+class InMemorySQLiteDatabase(DbDetails):
+    def __init__(self):
+        db_name = secrets.token_hex(8)
+        super().__init__(
+            db_name=db_name,
+            protocol="sqlite",
+            driver="pysqlite+opensafely",
+            host_from_container=None,
+            port_from_container=None,
+            host_from_host=None,
+            port_from_host=None,
+        )
+        self._engine = None
+
+    def engine(self, dialect=None, **kwargs):
+        # We need to hold a reference to the engine for the lifetime of this database to stop the contents of the
+        # database from being garbage-collected.
+        if not self._engine:
+            self._engine = super().engine(dialect, **kwargs)
+        return self._engine
+
+    def _url(self, host, port, include_driver=False):
+        if include_driver:
+            protocol = f"{self.protocol}+{self.driver}"
+        else:
+            protocol = self.protocol
+        # https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#uri-connections
+        # https://sqlite.org/inmemorydb.html
+        return f"{protocol}:///file:{self.db_name}?mode=memory&cache=shared&uri=true"
+
+
+class InMemoryPythonDatabase:
+    def __init__(self):
+        self.database = InMemoryDatabase()
+
+    def setup(self, *input_data, metadata=None):
+        """
+        Behaves like `DbDetails.setup` in taking a iterator of ORM instances but
+        translates these into the sort of objects needed by the `InMemoryDatabase`
+        """
+        input_data = list(iter_flatten(input_data))
+
+        if metadata:
+            pass
+        elif input_data:
+            metadata = input_data[0].metadata
+        else:
+            assert False, "No source of metadata"
+        assert all(item.metadata is metadata for item in input_data)
+
+        sqla_table_to_items = {table: [] for table in metadata.sorted_tables}
+        for item in input_data:
+            sqla_table_to_items[item.__table__].append(item)
+
+        for sqla_table, items in sqla_table_to_items.items():
+            columns = [
+                c.name for c in sqla_table.columns if c.name != SYNTHETIC_PRIMARY_KEY
+            ]
+            self.database.add_table(
+                name=sqla_table.name,
+                one_row_per_patient=table_has_one_row_per_patient(sqla_table),
+                columns=columns,
+                rows=[[getattr(item, c) for c in columns] for item in items],
+            )
+
+    def teardown(self):
+        self.database.populate({})
+
+    def host_url(self):
+        # Where other query engines expect a DSN string to connect to the database the
+        # InMemoryQueryEngine expects a reference to the database object itself
+        return self.database
+
+
+def make_trino_database(containers):
+    container_name = "ehrql-trino"
+    trino_port = 8080
+
+    if not containers.is_running(container_name):  # pragma: no cover
+        run_trino(container_name, containers, trino_port)
+
+    container_ip = containers.get_container_ip(container_name)
+    host_trino_port = containers.get_mapped_port_for_host(container_name, trino_port)
+
+    return DbDetails(
+        protocol="trino",
+        driver="opensafely",
+        host_from_container=container_ip,
+        port_from_container=trino_port,
+        host_from_host="localhost",
+        port_from_host=host_trino_port,
+        username="trino",
+        db_name="trino/default",
+        # Disable automatic retries for the test client: it's pointless and creates log
+        # noise
+        engine_kwargs={"connect_args": {"max_attempts": 1}},
+    )
+
+
+def run_trino(container_name, containers, trino_port):  # pragma: no cover
+    # Note, I don't actually know that this is the minimum required version of Docker
+    # Engine. I do know that 20.10.5 is unsupported (because that's what I had
+    # installed) and that 20.10.16 is supported, according to this comment:
+    # https://github.com/adoptium/containers/issues/214#issuecomment-1139464798 which
+    # was linked from this issue: https://github.com/trinodb/trino/issues/14269
+    min_docker_version = "20.10.16"
+    docker_version = containers.get_engine_version()
+    assert version_parse(docker_version) >= version_parse(min_docker_version), (
+        f"The Trino Docker image requires Docker Engine v{min_docker_version}"
+        f" or above but you have v{docker_version}"
+    )
+    containers.run_bg(
+        name=container_name,
+        # This is the version which happened to be current at the time of writing and is
+        # pinned for reproduciblity's sake rather than because there's anything
+        # significant about it
+        image="trinodb/trino:440",
+        volumes={
+            TRINO_SETUP_DIR: {"bind": "/trino", "mode": "ro"},
+            f"{TRINO_SETUP_DIR}/etc": {"bind": "/etc/trino", "mode": "ro"},
+        },
+        # Choose an arbitrary free port to publish the trino port on
+        ports={trino_port: None},
+        environment={},
+        user="root",
+        entrypoint="/trino/entrypoint.sh",
+        command="/usr/lib/trino/bin/run-trino",
+    )