tests/lib/update_tpp_schema.py
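"""
Maintain the local copy of the TPP database schema used by the tests.

Usage:

    python tests/lib/update_tpp_schema.py fetch
    python tests/lib/update_tpp_schema.py build

`fetch` downloads the latest published schema and reference CSVs from the
`tpp-database-schema` workspace on jobs.opensafely.org; `build` regenerates
`tpp_schema.py` (SQLAlchemy model definitions) from the downloaded
`tpp_schema.csv`.
"""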

import csv
import keyword
import re
import subprocess
import sys
from pathlib import Path
from urllib.parse import urljoin, urlparse

import requests

SERVER_URL = "https://jobs.opensafely.org"
WORKSPACE_NAME = "tpp-database-schema"
OUTPUTS_INDEX_URL = f"{SERVER_URL}/opensafely-internal/{WORKSPACE_NAME}/outputs/"

SCHEMA_DIR = Path(__file__).parent
SCHEMA_CSV = SCHEMA_DIR / "tpp_schema.csv"
SCHEMA_PYTHON = SCHEMA_DIR / "tpp_schema.py"
DATA_DICTIONARY_CSV = SCHEMA_DIR / "tpp_data_dictionary.csv"
DECISION_SUPPORT_REF_CSV = SCHEMA_DIR / "tpp_decision_support_reference.csv"
CATEGORICAL_COLUMNS_CSV = SCHEMA_DIR / "tpp_categorical_columns.csv"
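
# Maps each SQL Server column type to a pair of:
#  * the default value supplied for non-nullable columns of that type
#  * a function which formats the corresponding SQLAlchemy type for a given column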
TYPE_MAP = {
    "bit": (0, lambda _: "t.Boolean"),
    "tinyint": (0, lambda _: "t.SMALLINT"),
    "int": (0, lambda _: "t.Integer"),
    "bigint": (0, lambda _: "t.BIGINT"),
    "numeric": (0, lambda _: "t.Numeric"),
    "float": (0.0, lambda _: "t.Float"),
    "real": (0.0, lambda _: "t.REAL"),
    "date": ("9999-12-31", lambda _: "t.Date"),
    "time": ("00:00:00", lambda _: "t.Time"),
    "datetime": ("9999-12-31T00:00:00", lambda _: "t.DateTime"),
    "char": ("", lambda col: format_string_type("t.CHAR", col)),
    "varchar": ("", lambda col: format_string_type("t.VARCHAR", col)),
    "varbinary": (b"", lambda col: format_binary_type("t.VARBINARY", col)),
}

HEADER = """\
# This file is auto-generated: DO NOT EDIT IT
#
# To rebuild run:
#
#   python tests/lib/update_tpp_schema.py build
#

from sqlalchemy import types as t
from sqlalchemy.orm import DeclarativeBase, mapped_column


class Base(DeclarativeBase):
    "Common base class to signal that models below belong to the same database"


# This table isn't included in the schema definition TPP provide for us because it isn't
# created or managed by TPP. Instead we create and populate this table ourselves,
# currently via a command in Cohort Extractor though this may eventually be moved to a
# new repo:
# https://github.com/opensafely-core/cohort-extractor/blob/dd681275/cohortextractor/update_custom_medication_dictionary.py
class CustomMedicationDictionary(Base):
    __tablename__ = "CustomMedicationDictionary"
    # Because we don't have write privileges on the main TPP database schema this table
    # lives in our "temporary tables" database. To mimic this as closely as possible in
    # testing we create it in a separate schema from the other tables.
    __table_args__ = {"schema": "temp_tables.dbo"}

    _pk = mapped_column(t.Integer, primary_key=True)

    DMD_ID = mapped_column(t.VARCHAR(50, collation="Latin1_General_CI_AS"))
    MultilexDrug_ID = mapped_column(t.VARCHAR(767, collation="Latin1_General_CI_AS"))
"""


def fetch_schema_and_data_dictionary():
    # There's currently no API to get the latest output from a workspace so we use a
    # regex to extract output IDs from the workspace's outputs page.
    index_page = requests.get(OUTPUTS_INDEX_URL)
    url_path = urlparse(index_page.url).path.rstrip("/")
    escaped_path = re.escape(url_path)
    url_re = re.compile(rf"{escaped_path}/(\d+)/")
    ids = url_re.findall(index_page.text)
    max_id = max(map(int, ids))
    # Once we have the ID we can fetch the output manifest using the API
    outputs_api = f"{SERVER_URL}/api/v2/workspaces/{WORKSPACE_NAME}/snapshots/{max_id}"
    outputs = requests.get(outputs_api, headers={"Accept": "application/json"}).json()
    # And that gives us the URLs for the files
    file_urls = {f["name"]: f["url"] for f in outputs["files"]}
    rows_url = urljoin(SERVER_URL, file_urls["output/rows.csv"])
    SCHEMA_CSV.write_text(requests.get(rows_url).text)
    data_dictionary_url = urljoin(SERVER_URL, file_urls["output/data_dictionary.csv"])
    DATA_DICTIONARY_CSV.write_text(requests.get(data_dictionary_url).text)
    decision_support_ref_url = urljoin(
        SERVER_URL, file_urls["output/decision_support_value_reference.csv"]
    )
    DECISION_SUPPORT_REF_CSV.write_text(requests.get(decision_support_ref_url).text)
    categorical_columns_url = urljoin(
        SERVER_URL, file_urls["output/results_categorical_columns.csv"]
    )
    CATEGORICAL_COLUMNS_CSV.write_text(requests.get(categorical_columns_url).text)
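

# Generate one SQLAlchemy class per table in the schema CSV. Each generated class
# ends up looking something like this (illustrative table and column names only):
#
#     class SomeTable(Base):
#         __tablename__ = "SomeTable"
#         _pk = mapped_column(t.Integer, primary_key=True)
#
#         SomeColumn = mapped_column(t.Integer)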
def build_schema():
    lines = []
    for table, columns in read_schema().items():
        lines.extend(["", ""])
        lines.append(f"class {class_name_for_table(table)}(Base):")
        lines.append(f"    __tablename__ = {table!r}")
        lines.append("    _pk = mapped_column(t.Integer, primary_key=True)")
        lines.append("")
        for column in columns:
            attr_name = attr_name_for_column(column["ColumnName"])
            lines.append(f"    {attr_name} = {definition_for_column(column)}")
    write_schema(lines)


def read_schema():
    with SCHEMA_CSV.open(newline="") as f:
        schema = list(csv.DictReader(f))
    by_table = {}
    for item in schema:
        by_table.setdefault(item["TableName"], []).append(item)
    # We don't include the schema information table in the schema information because
    # a) where would this madness end?
    # b) it contains some weird types like `sysname` that we don't want to have to
    #    worry about.
    del by_table["OpenSAFELYSchemaInformation"]
    # Sort tables and columns into consistent order
    return {name: sort_columns(columns) for name, columns in sorted(by_table.items())}


def write_schema(lines):
    code = "\n".join([HEADER] + lines)
    code = ruff_format(code)
    SCHEMA_PYTHON.write_text(code)
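

# Run the generated source through ruff's formatter so the emitted file is
# consistently formatted regardless of how the code above assembles it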
def ruff_format(code):
    process = subprocess.run(
        [sys.executable, "-m", "ruff", "format", "-"],
        check=True,
        text=True,
        capture_output=True,
        input=code,
    )
    return process.stdout


def sort_columns(columns):
    # Assert column names are unique
    assert len({c["ColumnName"] for c in columns}) == len(columns)
    # Sort columns lexically except keep `Patient_ID` first
    return sorted(
        columns,
        key=lambda c: (c["ColumnName"] != "Patient_ID", c["ColumnName"]),
    )


def class_name_for_table(name):
    assert is_valid(name), name
    return name
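

# Column names may contain "." (not valid in a Python identifier) or be the keyword
# "class"; rewrite those, and assert the result is a usable attribute name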
def attr_name_for_column(name):
    name = name.replace(".", "_")
    if name == "class":
        name = "class_"
    assert is_valid(name), name
    return name
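

# Build the `mapped_column(...)` expression for a column, adding an explicit default
# for non-nullable columns and passing the original column name as the first argument
# where it differs from the Python attribute name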
def definition_for_column(column):
    default_value, type_formatter = TYPE_MAP[column["ColumnType"]]
    args = [type_formatter(column)]
    if column["IsNullable"] == "False":
        args.append(f"nullable=False, default={default_value!r}")
    else:
        assert column["IsNullable"] == "True", f"Bad `IsNullable` value in {column!r}"
    # If the name isn't a valid Python attribute then we need to supply it explicitly as
    # the first argument
    name = column["ColumnName"]
    if attr_name_for_column(name) != name:
        args.insert(0, repr(name))
    return f"mapped_column({', '.join(args)})"


def format_string_type(type_name, column):
    length = column["MaxLength"]
    collation = column["CollationName"]
    return f"{type_name}({length}, collation={collation!r})"


def format_binary_type(type_name, column):
    length = column["MaxLength"]
    return f"{type_name}({length})"


def is_valid(name):
    return name.isidentifier() and not keyword.iskeyword(name)


if __name__ == "__main__":
    command = sys.argv[1] if len(sys.argv) > 1 else None
    if command == "fetch":
        fetch_schema_and_data_dictionary()
    elif command == "build":
        build_schema()
    else:
        raise RuntimeError(f"Unknown command: {command}; valid commands: fetch, build")