[e988c2]: / tests / lib / gentest_example_simplify.py

Download this file

241 lines (202 with data), 8.0 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""
Attempts to de-duplicate query model structures by extracting repeated elements into
variables.
Usage looks like:
* Copy the query model example (the `dataset` and `data` arguments)
into a file. Just copy the arguments as-is: don't worry about indentation,
trailing commas or missing imports.
* Run `python -m tests.lib.gentest_example_simplify --output PATH_TO_FILE` to print
the simplified code to stdout.
* If the output looks vaguely sensible run the command again without the `--output`
option to update the original file in-place.
* Table and column definitions should be automatically extracted, but other kinds of
repeated structure will need to be extracted by hand. To do this: copy the
structure, assign it to a variable, and re-run the above command.
"""
import argparse
import ast
import dataclasses
import pathlib
import re
import subprocess
import sys
import typing
from collections import defaultdict
from functools import singledispatchmethod
import ehrql.query_model.nodes
from ehrql.query_model.nodes import (
InlinePatientTable,
Node,
SelectColumn,
SelectPatientTable,
SelectTable,
)
# Node types which are treated as table definitions: `QueryModelRepr` automatically
# assigns these (and any columns selected directly from them) to named variables
TABLE_TYPES = SelectTable | SelectPatientTable | InlinePatientTable
# Names of the top-level variables in the input file whose values get simplified;
# every other name defined in the file is available to be used as a reference
VARIABLE_NAMES = ["dataset", "data"]
def main(filename, output=False):  # pragma: no cover
    """
    Simplify the query model example stored in the file at `filename`.

    `filename` must be a path object providing `read_text` and `write_text`. When
    `output` is true the simplified code is returned and the file is left alone;
    otherwise the file is overwritten with the simplified code and an empty string
    is returned.
    """
    simplified = simplify(filename.read_text())
    if output:
        return simplified
    filename.write_text(simplified)
    return ""
def simplify(contents):
    """
    Return a simplified, formatted version of the example source in `contents`.

    The source is fixed up so that it can be executed, then executed to obtain the
    query model values. The values bound to the names in `VARIABLE_NAMES` are
    re-rendered with repeated structures replaced by references to variables, and
    the result is assembled into a module and formatted with ruff.
    """
    contents = fix_up_module(contents)
    namespace = {}
    exec(contents, namespace)

    # Pull the variables of interest out of the namespace; whatever is left behind
    # is what the reprs are allowed to reference by name
    variables = {}
    for name in VARIABLE_NAMES:
        if name in namespace:
            variables[name] = fix_accidental_tuple(namespace.pop(name))

    qm_repr = QueryModelRepr(namespace)
    # These must be rendered before `reference_reprs` is read below because
    # rendering is what populates it
    variable_reprs = {name: qm_repr(value) for name, value in variables.items()}

    sections = [get_imports(contents)]
    for name, reference_repr in qm_repr.reference_reprs.items():
        sections.append(f"{name} = {reference_repr}")
    sections.extend(
        f"{name} = {variable_repr}" for name, variable_repr in variable_reprs.items()
    )
    return ruff_format("\n\n".join(sections))
def ruff_format(code):
    """
    Format `code` by piping it through `ruff format` and return the result.

    Ruff is run as a module under the current interpreter, reading the code from
    stdin. Raises `subprocess.CalledProcessError` if ruff exits with a non-zero
    status.
    """
    command = [sys.executable, "-m", "ruff", "format", "-"]
    completed = subprocess.run(
        command,
        input=code,
        capture_output=True,
        text=True,
        check=True,
    )
    return completed.stdout
def fix_accidental_tuple(value):
    """
    Unwrap `value` if it is a tuple containing exactly one element.

    Copying a Hypothesis example can leave a trailing comma behind which turns the
    value into an accidental one-element tuple. The values this is applied to are
    never intended to be tuples so there's no need to worry about false positives.
    Anything else is returned unchanged.
    """
    is_singleton_tuple = isinstance(value, tuple) and len(value) == 1
    return value[0] if is_singleton_tuple else value
def fix_up_module(contents):
    """
    Apply some basic fixes to the module source to make it importable.

    Source which already contains the word `import` is assumed to have been fixed
    up previously and is returned untouched. Otherwise leading whitespace is
    removed from the assignments to the names in `VARIABLE_NAMES` and a standard
    set of imports is placed in front of the source.
    """
    already_fixed_up = re.search(r"\bimport\b", contents) is not None
    if already_fixed_up:
        return contents

    # Strip leading indentation from the variable assignments
    names = "|".join(re.escape(name) for name in VARIABLE_NAMES)
    assignment = re.compile(rf"^\s+({names})\s*=\s*", flags=re.MULTILINE)
    dedented = assignment.sub(r"\1 = ", contents)

    # Add imports (many of these will be unnecessary but that's fine)
    header = "\n".join(
        [
            "import datetime",
            "from tests.generative.test_query_model import data_setup, schema",
            f"from ehrql.query_model.nodes import ({', '.join(ehrql.query_model.nodes.__all__)})",
        ]
    )
    return header + "\n" + dedented
class QueryModelRepr:
    """
    Build Python source representations of query model values, replacing repeated
    structures with references to named variables.

    Calling an instance with a value returns its source representation. Any value
    which is already bound to a name in the supplied namespace is rendered as that
    name. Table definitions, and columns selected directly from tables, have names
    generated for them automatically. The definition of every name which gets used
    is recorded in `reference_reprs` (in order of first use) so that the caller can
    emit the corresponding assignments.
    """

    def __init__(self, namespace):
        # Create an inverse mapping which maps each (hashable) value in the namespace to
        # the first name to which it's bound
        self.valuespace = {}
        for key, value in namespace.items():
            if not key.startswith("__") and self._is_hashable(value):
                self.valuespace.setdefault(value, key)
        # Dict to record the repr of every value we use in `valuespace`
        self.reference_reprs = {}
        # Hands out consecutive integers (starting at zero) to inline tables in the
        # order they are first seen
        self.inline_table_number = defaultdict(iter(range(2**32)).__next__)

    def __call__(self, value):
        return self.repr(value)

    @staticmethod
    def _is_hashable(value):
        """
        Return whether `value` can actually be hashed.

        The `typing.Hashable` check only tests for the presence of a `__hash__`
        method, so containers like tuples pass it even when they hold unhashable
        members. We use it as a cheap first filter and then confirm by hashing.
        """
        if not isinstance(value, typing.Hashable):
            return False
        try:
            hash(value)
        except TypeError:
            return False
        return True

    def repr(self, value):  # noqa: A003
        """
        Return the source representation of `value`, using a variable name in place
        of the value wherever one is available or can be generated.
        """
        # If the value is already in the provided namespace then just use its name as
        # the repr
        if self._is_hashable(value) and value in self.valuespace:
            name = self.valuespace[value]
            # Record the original repr of the value being referenced
            if name not in self.reference_reprs:
                self.reference_reprs[name] = self.repr_value(value)
            return name
        # Automatically create references for table definitions to avoid repeating them
        elif isinstance(value, TABLE_TYPES):
            self.valuespace[value] = self.table_name(value)
            # Now that the value has a name the recursive call takes the first branch
            return self.repr(value)
        # Automatically create references where columns are selected directly from
        # tables
        elif isinstance(value, SelectColumn) and isinstance(value.source, TABLE_TYPES):
            self.valuespace[value] = f"{self.table_name(value.source)}_{value.name}"
            return self.repr(value)
        else:
            return self.repr_value(value)

    def table_name(self, value):
        """
        Return the variable name to use for the table definition `value`.

        Inline tables are named by the order in which they were first seen; any other
        table uses its own `name` attribute.
        """
        if isinstance(value, InlinePatientTable):
            return f"inline_{self.inline_table_number[value]}"
        else:
            return value.name

    @singledispatchmethod
    def repr_value(self, value):
        """
        Return the full source representation of `value` itself (never a reference
        to it), dispatching on its type. Falls back to the builtin `repr`.
        """
        return repr(value)

    @repr_value.register(type)
    def repr_type(self, value):
        # Builtin types are available by their bare name. Qualifying them with the
        # module would produce e.g. `builtins.int`, which is a `NameError` in the
        # generated code because `builtins` is never imported there.
        if value.__module__ == "builtins":
            return value.__qualname__
        return f"{value.__module__}.{value.__qualname__}"

    @repr_value.register(Node)
    def repr_node(self, value):
        args = []
        kwargs = {}
        fields = dataclasses.fields(value)
        # Single argument nodes use positional arguments for brevity
        if len(fields) == 1:
            args = [getattr(value, fields[0].name)]
        else:
            kwargs = {field.name: getattr(value, field.name) for field in fields}
        return self.repr_init(value, args, kwargs)

    @repr_value.register(list)
    def repr_list(self, value):
        elements = [self.repr(v) for v in value]
        return f"[{', '.join(elements)}]"

    @repr_value.register(dict)
    def repr_dict(self, value):
        elements = [f"{self.repr(k)}: {self.repr(v)}" for k, v in value.items()]
        return f"{{{', '.join(elements)}}}"

    @repr_value.register(frozenset)
    def repr_frozenset(self, value):
        elements = [self.repr(v) for v in value]
        return f"frozenset({{{','.join(elements)}}})"

    @repr_value.register(tuple)
    def repr_tuple(self, value):
        elements = [self.repr(v) for v in value]
        # A one-element tuple needs its trailing comma to remain a tuple
        if len(elements) == 1:
            return f"({elements[0]},)"
        else:
            return f"({','.join(elements)})"

    def repr_init(self, obj, args, kwargs):
        """
        Return source for a call to the class of `obj` with the supplied positional
        and keyword arguments, each of which is rendered using `repr`.
        """
        all_args = [self.repr(arg) for arg in args]
        all_args.extend(f"{key}={self.repr(value)}" for key, value in kwargs.items())
        name = obj.__class__.__qualname__
        return f"{name}({', '.join(all_args)})"
def get_imports(contents):
    """
    Return the top-level import statements found in the source `contents`, one per
    line and in their original order.

    Statements are regenerated from the parsed syntax tree rather than copied
    verbatim, and imports nested inside other statements are not included.
    """
    import_statement_types = (ast.Import, ast.ImportFrom)
    return "\n".join(
        ast.unparse(statement)
        for statement in ast.parse(contents).body
        if isinstance(statement, import_statement_types)
    )
if __name__ == "__main__":
    # Command line entry point. The module docstring doubles as the help text, so
    # the raw formatter is used to keep its line breaks intact.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description=__doc__,
    )
    cli.add_argument(
        "filename",
        help="Path to a Python file containing a generative test example",
        type=pathlib.Path,
    )
    cli.add_argument(
        "-o",
        "--output",
        action="store_true",
        dest="output",
        help="Write file to stdout instead of updating it in-place",
    )
    # The parsed option names match the parameter names of `main`
    options = cli.parse_args()
    print(main(**vars(options)), end="")