--- a
+++ b/tests/lib/update_tpp_schema.py
@@ -0,0 +1,209 @@
+import csv
+import keyword
+import re
+import subprocess
+import sys
+from pathlib import Path
+from urllib.parse import urljoin, urlparse
+
+import requests
+
+
+SERVER_URL = "https://jobs.opensafely.org"
+WORKSPACE_NAME = "tpp-database-schema"
+OUTPUTS_INDEX_URL = f"{SERVER_URL}/opensafely-internal/{WORKSPACE_NAME}/outputs/"
+
+SCHEMA_DIR = Path(__file__).parent
+SCHEMA_CSV = SCHEMA_DIR / "tpp_schema.csv"
+SCHEMA_PYTHON = SCHEMA_DIR / "tpp_schema.py"
+DATA_DICTIONARY_CSV = SCHEMA_DIR / "tpp_data_dictionary.csv"
+DECISION_SUPPORT_REF_CSV = SCHEMA_DIR / "tpp_decision_support_reference.csv"
+CATEGORICAL_COLUMNS_CSV = SCHEMA_DIR / "tpp_categorical_columns.csv"
+
+TYPE_MAP = {
+    "bit": (0, lambda _: "t.Boolean"),
+    "tinyint": (0, lambda _: "t.SMALLINT"),
+    "int": (0, lambda _: "t.Integer"),
+    "bigint": (0, lambda _: "t.BIGINT"),
+    "numeric": (0, lambda _: "t.Numeric"),
+    "float": (0.0, lambda _: "t.Float"),
+    "real": (0.0, lambda _: "t.REAL"),
+    "date": ("9999-12-31", lambda _: "t.Date"),
+    "time": ("00:00:00", lambda _: "t.Time"),
+    "datetime": ("9999-12-31T00:00:00", lambda _: "t.DateTime"),
+    "char": ("", lambda col: format_string_type("t.CHAR", col)),
+    "varchar": ("", lambda col: format_string_type("t.VARCHAR", col)),
+    "varbinary": (b"", lambda col: format_binary_type("t.VARBINARY", col)),
+}
+
+
+HEADER = """\
+# This file is auto-generated: DO NOT EDIT IT
+#
+# To rebuild run:
+#
+#   python tests/lib/update_tpp_schema.py build
+#
+
+from sqlalchemy import types as t
+from sqlalchemy.orm import DeclarativeBase, mapped_column
+
+
+class Base(DeclarativeBase):
+    "Common base class to signal that models below belong to the same database"
+
+
+# This table isn't included in the schema definition TPP provide for us because it isn't
+# created or managed by TPP. Instead we create and populate this table ourselves,
+# currently via a command in Cohort Extractor though this may eventually be moved to a
+# new repo:
+# [1]: https://github.com/opensafely-core/cohort-extractor/blob/dd681275/cohortextractor/update_custom_medication_dictionary.py
+class CustomMedicationDictionary(Base):
+    __tablename__ = "CustomMedicationDictionary"
+    # Because we don't have write privileges on the main TPP database schema this table
+    # lives in our "temporary tables" database. To mimic this as closely as possible in
+    # testing we create it in a separate schema from the other tables.
+    __table_args__ = {"schema": "temp_tables.dbo"}
+
+    _pk = mapped_column(t.Integer, primary_key=True)
+
+    DMD_ID = mapped_column(t.VARCHAR(50, collation="Latin1_General_CI_AS"))
+    MultilexDrug_ID = mapped_column(t.VARCHAR(767, collation="Latin1_General_CI_AS"))
+"""
+
+
+def fetch_schema_and_data_dictionary():
+    # There's currently no API to get the latest output from a workspace so we use a
+    # regex to extract output IDs from the workspace's outputs page.
+    index_page = requests.get(OUTPUTS_INDEX_URL)
+    url_path = urlparse(index_page.url).path.rstrip("/")
+    escaped_path = re.escape(url_path)
+    url_re = re.compile(rf"{escaped_path}/(\d+)/")
+    ids = url_re.findall(index_page.text)
+    max_id = max(map(int, ids))
+    # Once we have the ID we can fetch the output manifest using the API
+    outputs_api = f"{SERVER_URL}/api/v2/workspaces/{WORKSPACE_NAME}/snapshots/{max_id}"
+    outputs = requests.get(outputs_api, headers={"Accept": "application/json"}).json()
+    # And that gives us the URLs for the files
+    file_urls = {f["name"]: f["url"] for f in outputs["files"]}
+    rows_url = urljoin(SERVER_URL, file_urls["output/rows.csv"])
+    SCHEMA_CSV.write_text(requests.get(rows_url).text)
+    data_dictionary_url = urljoin(SERVER_URL, file_urls["output/data_dictionary.csv"])
+    DATA_DICTIONARY_CSV.write_text(requests.get(data_dictionary_url).text)
+    decision_support_ref_url = urljoin(
+        SERVER_URL, file_urls["output/decision_support_value_reference.csv"]
+    )
+    DECISION_SUPPORT_REF_CSV.write_text(requests.get(decision_support_ref_url).text)
+    categorical_columns_url = urljoin(
+        SERVER_URL, file_urls["output/results_categorical_columns.csv"]
+    )
+    CATEGORICAL_COLUMNS_CSV.write_text(requests.get(categorical_columns_url).text)
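+
+
+# Illustrative sketch, not part of the original script: the snapshot-ID discovery in
+# fetch_schema_and_data_dictionary() boils down to "find every <outputs-path>/<number>/
+# link on the index page and take the largest number". A minimal standalone version of
+# that idea; the helper name, path and HTML in the doctest are made up:
+def _example_latest_snapshot_id(html, outputs_path):
+    """
+    >>> page = '<a href="/ws/outputs/41/"></a> <a href="/ws/outputs/42/"></a>'
+    >>> _example_latest_snapshot_id(page, "/ws/outputs")
+    42
+    """
+    ids = re.findall(rf"{re.escape(outputs_path)}/(\d+)/", html)
+    return max(map(int, ids)) if ids else None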
+
+
+def build_schema():
+    lines = []
+    for table, columns in read_schema().items():
+        lines.extend(["", ""])
+        lines.append(f"class {class_name_for_table(table)}(Base):")
+        lines.append(f"    __tablename__ = {table!r}")
+        lines.append("    _pk = mapped_column(t.Integer, primary_key=True)")
+        lines.append("")
+        for column in columns:
+            attr_name = attr_name_for_column(column["ColumnName"])
+            lines.append(f"    {attr_name} = {definition_for_column(column)}")
+    write_schema(lines)
+
+
+def read_schema():
+    with SCHEMA_CSV.open(newline="") as f:
+        schema = list(csv.DictReader(f))
+    by_table = {}
+    for item in schema:
+        by_table.setdefault(item["TableName"], []).append(item)
+    # We don't include the schema information table in the schema information because
+    # a) where would this madness end?
+    # b) it contains some weird types like `sysname` that we don't want to have to
+    #    worry about.
+    del by_table["OpenSAFELYSchemaInformation"]
+    # Sort tables and columns into consistent order
+    return {name: sort_columns(columns) for name, columns in sorted(by_table.items())}
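+
+
+# Illustrative sketch, not part of the original script: the shape build_schema()
+# expects back from read_schema() for a single table. The table name and rows below
+# are invented; real values come from tpp_schema.csv. Note that sort_columns()
+# (defined below) keeps Patient_ID first, so generated classes always lead with the
+# patient identifier.
+def _example_read_schema_result():
+    rows = [
+        {
+            "TableName": "ExampleEvents",
+            "ColumnName": "ConsultationDate",
+            "ColumnType": "datetime",
+            "MaxLength": "8",
+            "CollationName": "NULL",
+            "IsNullable": "True",
+        },
+        {
+            "TableName": "ExampleEvents",
+            "ColumnName": "Patient_ID",
+            "ColumnType": "bigint",
+            "MaxLength": "8",
+            "CollationName": "NULL",
+            "IsNullable": "True",
+        },
+    ]
+    return {"ExampleEvents": sort_columns(rows)}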
+
+
+def write_schema(lines):
+    code = "\n".join([HEADER] + lines)
+    code = ruff_format(code)
+    SCHEMA_PYTHON.write_text(code)
+
+
+def ruff_format(code):
+    process = subprocess.run(
+        [sys.executable, "-m", "ruff", "format", "-"],
+        check=True,
+        text=True,
+        capture_output=True,
+        input=code,
+    )
+    return process.stdout
+
+
+def sort_columns(columns):
+    # Assert column names are unique
+    assert len({c["ColumnName"] for c in columns}) == len(columns)
+    # Sort columns lexically except keep `Patient_ID` first
+    return sorted(
+        columns,
+        key=lambda c: (c["ColumnName"] != "Patient_ID", c["ColumnName"]),
+    )
+
+
+def class_name_for_table(name):
+    assert is_valid(name), name
+    return name
+
+
+def attr_name_for_column(name):
+    name = name.replace(".", "_")
+    if name == "class":
+        name = "class_"
+    assert is_valid(name), name
+    return name
+
+
+def definition_for_column(column):
+    default_value, type_formatter = TYPE_MAP[column["ColumnType"]]
+    args = [type_formatter(column)]
+    if column["IsNullable"] == "False":
+        args.append(f"nullable=False, default={default_value!r}")
+    else:
+        assert column["IsNullable"] == "True", f"Bad `IsNullable` value in {column!r}"
+    # If the name isn't a valid Python attribute then we need to supply it explicitly as
+    # the first argument
+    name = column["ColumnName"]
+    if attr_name_for_column(name) != name:
+        args.insert(0, repr(name))
+    return f"mapped_column({', '.join(args)})"
+
+
+def format_string_type(type_name, column):
+    length = column["MaxLength"]
+    collation = column["CollationName"]
+    return f"{type_name}({length}, collation={collation!r})"
+
+
+def format_binary_type(type_name, column):
+    length = column["MaxLength"]
+    return f"{type_name}({length})"
+
+
+def is_valid(name):
+    return name.isidentifier() and not keyword.iskeyword(name)
+
+
+if __name__ == "__main__":
+    command = sys.argv[1] if len(sys.argv) > 1 else None
+    if command == "fetch":
+        fetch_schema_and_data_dictionary()
+    elif command == "build":
+        build_schema()
+    else:
+        raise RuntimeError(f"Unknown command: {command}; valid commands: fetch, build")
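+
+
+# Illustrative example, not part of the original script: roughly what `build` writes
+# into tpp_schema.py for the made-up ExampleEvents table sketched above, plus an
+# invented non-nullable varchar "Code" column (length and collation invented too).
+# Patient_ID comes first because of sort_columns(); the rest are lexical:
+#
+#     class ExampleEvents(Base):
+#         __tablename__ = "ExampleEvents"
+#         _pk = mapped_column(t.Integer, primary_key=True)
+#
+#         Patient_ID = mapped_column(t.BIGINT)
+#         Code = mapped_column(
+#             t.VARCHAR(18, collation="Latin1_General_CI_AS"),
+#             nullable=False,
+#             default="",
+#         )
+#         ConsultationDate = mapped_column(t.DateTime)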