Switch to side-by-side view

--- a
+++ b/tests/lib/gentest_example_simplify.py
@@ -0,0 +1,240 @@
+"""
+Attempts to de-duplicate query model structures by extracting repeated elements into
+variables.
+
+Usage looks like:
+
+    * Copy the query model example (the `population`, `variable` and `data` arguments)
+      into a file. Just copy the arguments as-is: don't worry about indentation,
+      trailing commas or missing imports.
+
+    * Run `python -m tests.lib.gentest_example_simplify --output PATH_TO_FILE` to print
+      the simplified code without modifying the file.
+
+    * If the output looks vaguely sensible run the command again without the `--output`
+      option to update the original file in-place.
+
+    * Table and column definitions should be automatically extracted, but other kinds of
+      repeated structure will need to be extracted by hand. To do this: copy the
+      structure, assign it to a variable, and re-run the above command.
+"""
+
+import argparse
+import ast
+import dataclasses
+import pathlib
+import re
+import subprocess
+import sys
+import typing
+from collections import defaultdict
+from functools import singledispatchmethod
+
+import ehrql.query_model.nodes
+from ehrql.query_model.nodes import (
+    InlinePatientTable,
+    Node,
+    SelectColumn,
+    SelectPatientTable,
+    SelectTable,
+)
+
+
+# Node types which define a table; `QueryModelRepr` automatically extracts instances of
+# these (and columns selected directly from them) into named references
+TABLE_TYPES = SelectTable | SelectPatientTable | InlinePatientTable
+
+
+# Names of the top-level variables which `simplify` pops from the executed example and
+# re-emits in simplified form; `fix_up_module` also strips indentation before them
+VARIABLE_NAMES = ["dataset", "data"]
+
+
+def main(filename, output=False):  # pragma: no cover
+    """
+    Simplify the example stored in `filename` (a path object).
+
+    If `output` is true the simplified code is returned and the file is left untouched;
+    otherwise the file is overwritten with the simplified code and an empty string is
+    returned.
+    """
+    simplified = simplify(filename.read_text())
+    if output:
+        return simplified
+    filename.write_text(simplified)
+    return ""
+
+
+def simplify(contents):
+    """
+    Return a formatted, de-duplicated version of the example module source `contents`.
+
+    The source is executed, the variables named in `VARIABLE_NAMES` are re-rendered
+    using `QueryModelRepr`, and the result is assembled as: imports, then extracted
+    references, then the variables themselves.
+    """
+    contents = fix_up_module(contents)
+    namespace = {}
+    # NOTE: this executes the file's contents, so only run it on files you trust
+    exec(contents, namespace)
+    # Pull the target variables out of the namespace so that what remains is just the
+    # set of names which can be used as references
+    variables = {}
+    for name in VARIABLE_NAMES:
+        if name in namespace:
+            variables[name] = fix_accidental_tuple(namespace.pop(name))
+    qm_repr = QueryModelRepr(namespace)
+    # Render the variables first: doing so populates `qm_repr.reference_reprs`
+    rendered = [(name, qm_repr(value)) for name, value in variables.items()]
+    sections = [get_imports(contents)]
+    sections += [
+        f"{name} = {reference_repr}"
+        for name, reference_repr in qm_repr.reference_reprs.items()
+    ]
+    sections += [f"{name} = {variable_repr}" for name, variable_repr in rendered]
+    return ruff_format("\n\n".join(sections))
+
+
+def ruff_format(code):
+    """
+    Return `code` formatted by `ruff format`, run as a module of the current
+    interpreter with the code passed on stdin.
+    """
+    command = [sys.executable, "-m", "ruff", "format", "-"]
+    # `check=True` means a non-zero exit from ruff raises `CalledProcessError`
+    completed = subprocess.run(
+        command,
+        input=code,
+        capture_output=True,
+        text=True,
+        check=True,
+    )
+    return completed.stdout
+
+
+def fix_accidental_tuple(value):
+    # A trailing comma left over from copying the Hypothesis example turns the value
+    # into a one-element tuple. None of the values this is applied to are ever meant to
+    # be tuples, so unwrapping any one-element tuple can't produce a false positive.
+    is_accidental = isinstance(value, tuple) and len(value) == 1
+    return value[0] if is_accidental else value
+
+
+def fix_up_module(contents):
+    """
+    Apply some basic fixes to the module source to make it importable.
+
+    Source which already contains the word `import` is assumed to have been fixed up
+    previously and is returned unchanged.
+    """
+    if re.search(r"\bimport\b", contents) is not None:
+        return contents
+    name_pattern = "|".join(re.escape(name) for name in VARIABLE_NAMES)
+    # Remove any whitespace before the variable assignments and normalise the spacing
+    # around the `=`
+    dedented = re.sub(
+        rf"^\s+({name_pattern})\s*=\s*",
+        r"\1 = ",
+        contents,
+        flags=re.MULTILINE,
+    )
+    # Prepend imports (many of these will be unnecessary but that's fine)
+    node_names = ", ".join(ehrql.query_model.nodes.__all__)
+    header_lines = [
+        "import datetime",
+        "from tests.generative.test_query_model import data_setup, schema",
+        f"from ehrql.query_model.nodes import ({node_names})",
+    ]
+    return "\n".join(header_lines) + "\n" + dedented
+
+
+class QueryModelRepr:
+    """
+    Callable which renders query model values as Python source, replacing any value
+    which is bound to a name in `namespace` with a reference to that name.
+
+    Table definitions, and columns selected directly from tables, are given names
+    automatically. After rendering, `reference_reprs` maps each name which was
+    referenced to the source for the value it refers to.
+    """
+
+    def __init__(self, namespace):
+        # Create an inverse mapping which maps each (hashable) value in the namespace to
+        # the first name to which it's bound
+        self.valuespace = {}
+        for key, value in namespace.items():
+            if not key.startswith("__") and self._is_hashable(value):
+                self.valuespace.setdefault(value, key)
+        # Dict to record the repr of every value we use in `valuespace`
+        self.reference_reprs = {}
+        # Hands out sequential numbers (starting from zero) to inline tables, one per
+        # distinct table, the first time each is looked up
+        self.inline_table_number = defaultdict(iter(range(2**32)).__next__)
+
+    @staticmethod
+    def _is_hashable(value):
+        # An `isinstance(value, typing.Hashable)` check only tests for the presence of a
+        # `__hash__` method, so it accepts e.g. a tuple containing a list which then
+        # blows up when used as a dict key. Actually hashing the value is reliable.
+        try:
+            hash(value)
+        except TypeError:
+            return False
+        return True
+
+    def __call__(self, value):
+        return self.repr(value)
+
+    def repr(self, value):  # noqa: A003
+        """Return source for `value`, using a reference name wherever one applies."""
+        # If the value is already in the provided namespace then just use its name as
+        # the repr
+        if self._is_hashable(value) and value in self.valuespace:
+            name = self.valuespace[value]
+            # Record the original repr of the value being referenced
+            if name not in self.reference_reprs:
+                self.reference_reprs[name] = self.repr_value(value)
+            return name
+        # Automatically create references for table definitions to avoid repeating them
+        elif isinstance(value, TABLE_TYPES):
+            self.valuespace[value] = self.table_name(value)
+            return self.repr(value)
+        # Automatically create references where columns are selected directly from
+        # tables
+        elif isinstance(value, SelectColumn) and isinstance(value.source, TABLE_TYPES):
+            self.valuespace[value] = f"{self.table_name(value.source)}_{value.name}"
+            return self.repr(value)
+        else:
+            return self.repr_value(value)
+
+    def table_name(self, value):
+        """Return the variable name to use for the table definition `value`."""
+        # Inline tables get a generated, numbered name; other tables use their own name
+        if isinstance(value, InlinePatientTable):
+            return f"inline_{self.inline_table_number[value]}"
+        else:
+            return value.name
+
+    @singledispatchmethod
+    def repr_value(self, value):
+        """Return source for `value` itself (never a reference), dispatching on type."""
+        return repr(value)
+
+    @repr_value.register(type)
+    def repr_type(self, value):
+        # Builtin types are always in scope under their bare name, whereas a
+        # `builtins.` prefix would only work if the generated module happened to import
+        # `builtins`
+        if value.__module__ == "builtins":
+            return value.__qualname__
+        return f"{value.__module__}.{value.__qualname__}"
+
+    @repr_value.register(Node)
+    def repr_node(self, value):
+        args = []
+        kwargs = {}
+        fields = dataclasses.fields(value)
+        # Single argument nodes use positional arguments for brevity
+        if len(fields) == 1:
+            args = [getattr(value, fields[0].name)]
+        else:
+            kwargs = {field.name: getattr(value, field.name) for field in fields}
+        return self.repr_init(value, args, kwargs)
+
+    @repr_value.register(list)
+    def repr_list(self, value):
+        elements = [self.repr(v) for v in value]
+        return f"[{', '.join(elements)}]"
+
+    @repr_value.register(dict)
+    def repr_dict(self, value):
+        elements = [f"{self.repr(k)}: {self.repr(v)}" for k, v in value.items()]
+        return f"{{{', '.join(elements)}}}"
+
+    @repr_value.register(frozenset)
+    def repr_frozenset(self, value):
+        elements = [self.repr(v) for v in value]
+        return f"frozenset({{{','.join(elements)}}})"
+
+    @repr_value.register(tuple)
+    def repr_tuple(self, value):
+        elements = [self.repr(v) for v in value]
+        # A one-element tuple needs its trailing comma to remain a tuple
+        if len(elements) == 1:
+            return f"({elements[0]},)"
+        else:
+            return f"({','.join(elements)})"
+
+    def repr_init(self, obj, args, kwargs):
+        """Return source for a call to `obj`'s class with the given arguments."""
+        all_args = [self.repr(arg) for arg in args]
+        all_args.extend(f"{key}={self.repr(value)}" for key, value in kwargs.items())
+        name = obj.__class__.__qualname__
+        return f"{name}({', '.join(all_args)})"
+
+
+def get_imports(contents):
+    """
+    Return the top-level import statements found in the source `contents`, one per
+    line and in their original order.
+    """
+    import_nodes = (ast.Import, ast.ImportFrom)
+    statements = [
+        ast.unparse(node)
+        for node in ast.parse(contents).body
+        if isinstance(node, import_nodes)
+    ]
+    return "\n".join(statements)
+
+
+if __name__ == "__main__":
+    # Command-line entry point: the module docstring doubles as the help text, so use
+    # the raw formatter to preserve its layout
+    cli = argparse.ArgumentParser(
+        formatter_class=argparse.RawTextHelpFormatter,
+        description=__doc__,
+    )
+    cli.add_argument(
+        "filename",
+        help="Path to a Python file containing a generative test example",
+        type=pathlib.Path,
+    )
+    cli.add_argument(
+        "-o",
+        "--output",
+        action="store_true",
+        dest="output",
+        help="Write file to stdout instead of updating it in-place",
+    )
+    options = cli.parse_args()
+    result = main(filename=options.filename, output=options.output)
+    # `main` returns an empty string when it has updated the file in-place
+    print(result, end="")