Switch to side-by-side view

--- a
+++ b/tests/generative/variable_strategies.py
@@ -0,0 +1,663 @@
+import datetime
+from os import environ
+
+import hypothesis as hyp
+import hypothesis.strategies as st
+from hypothesis.control import current_build_context
+
+from ehrql.query_model.nodes import (
+    AggregateByPatient,
+    Case,
+    Dataset,
+    Filter,
+    Function,
+    InlinePatientTable,
+    PickOneRowPerPatient,
+    Position,
+    SelectColumn,
+    SelectPatientTable,
+    SelectTable,
+    SeriesCollectionFrame,
+    Sort,
+    Value,
+)
+from ehrql.query_model.population_validation import (
+    ValidationError,
+    validate_population_definition,
+)
+
+from .generic_strategies import usually
+from .ignored_errors import get_ignored_error_type
+
+
+# Max depth
+#
+# There are various points at which we generate deeply recursive data
+# which hits Hypothesis's recursion limits, and we need to stop going deeper
+# at this point and force generating a terminating node.
+#
+# Otherwise, the generated graph can continue forever, and will eventually hit the
+# hypothesis limit (100) and will be abandoned. This results in too many invalid examples,
+# which triggers the too-many-filters healthcheck.
+#
+# If the max limit is set high - e.g. if we always let it go to 100 and then return our
+# default terminating node, generating the examples takes a really long time.  Setting it
+# too low means that hypothesis takes too long to shrink examples.
+#
+# The default is therefore set, somewhat arbitrarily, to 15.
+
+MAX_DEPTH = int(environ.get("GENTEST_MAX_DEPTH", 15))
+
+
+def depth_exceeded():
+    ctx = current_build_context()
+    return ctx.data.depth > MAX_DEPTH
+
+
+@st.composite
+def _should_stop(draw):
+    """Returns True if we need to stop and generate a terminating node."""
+
+    # Generally speaking we want this to return False unless it needs
+    # to return True. This need can either come from the fact that
+    # we've exceeded the maximum depth, or because the shrinker told
+    # us to.
+    #
+    # In the former case, we still need to draw a variable that says
+    # we should, because this gives us the shrinker the opportunity to
+    # set that decision to false, which makes us no longer dependent on
+    # hitting the maximum depth to generate a terminating node here.
+
+    should_continue = draw(usually)
+
+    if depth_exceeded():
+        should_continue = False
+
+    return not should_continue
+
+
+should_stop = _should_stop()
+
+
+@st.composite
+def depth_bounded_one_of(draw, *options):
+    """Equivalent to `one_of` but if we've got too deep always uses the first option."""
+    assert options
+
+    # Similar to how `should_stop` works, we always draw the choice, but if
+    # we've exceeded the current maximum depth, we pretend that we got a zero
+    # even if we didn't. When the shrinker runs it will change this to zero
+    # for real, and then we no longer need to hit maximum depth for this branch
+    # to trigger.
+    i = draw(st.integers(0, len(options) - 1))
+    if depth_exceeded():
+        i = 0
+    return draw(options[i])
+
+
+# This module defines a set of recursive Hypothesis strategies for generating query model graphs.
+#
+# There are a few points where we deliberate order the types that we choose from, with the
+# "simplest" first (by some subjective measure). This is to enable Hypothesis to more effectively
+# explore the query space and to "shrink" examples when it finds errors. These points are commented
+# below.
+#
+# We use several Hypothesis combinators for defining our strategies. Most (`one_of`, `just`,
+# `sampled_from`) are fairly self-explanatory. A couple are worth clarifying.
+#     * `st.builds()` is used to construct objects, it takes the class and strategies
+#       corresponding to the constructor arguments.
+#     * `@st.composite` allows us to define a strategy by composing other strategies with
+#       arbitrary Python code; it adds a `draw` argument which is part of the machinery that
+#       enables this composition but which doesn't form part of the signature of the resulting
+#       strategy function.
+
+
+def dataset(patient_tables, event_tables, schema, value_strategies):
+    # Every inner-function here returns a Hypothesis strategy for creating the thing it is named
+    # for, not the thing itself.
+    #
+    # Several of these strategy functions ignore one or more of their arguments in order to make
+    # them uniform with other functions that return the same sort of strategy. Such ignored
+    # arguments are named with a leading underscore.
+
+    # Series strategies
+    #
+    # Whenever a series is needed, we call series() passing the type of the series and frame that
+    # it should be built on (these are either constrained by the context in which the series is to
+    # be used or chosen arbitrarily by the caller).
+    #
+    # This strategy then chooses an arbitrary concrete series that respects the constraints imposed
+    # by the passed type and frame.
+    #
+    # A note on frames and domains:
+    #
+    #     When we pass `frame` as an argument to a series strategy function, the intended semantics
+    #     are always "construct a series that is _consistent_ with this frame". It's always
+    #     permitted to return a one-row-per-patient series, because such series can always be
+    #     composed a many-rows-per-patient series; so there are series strategy functions that,
+    #     always or sometimes, ignore the frame argument.
+
+    COMPARABLE_TYPES = [t for t in value_strategies.keys() if t is not bool]
+
+    @st.composite
+    def series(draw, type_, frame):
+        if draw(should_stop):  # pragma: no cover
+            return draw(select_column(type_, frame))
+
+        class DomainConstraint:
+            PATIENT = (True,)
+            NON_PATIENT = (False,)
+            ANY = (True, False)
+
+        # Order matters: "simpler" first (see header comment)
+        series_constraints = {
+            select_column: (value_strategies.keys(), DomainConstraint.ANY),
+            exists: ({bool}, DomainConstraint.PATIENT),
+            count: ({int}, DomainConstraint.PATIENT),
+            count_distinct: ({int}, DomainConstraint.PATIENT),
+            min_: (COMPARABLE_TYPES, DomainConstraint.PATIENT),
+            max_: (COMPARABLE_TYPES, DomainConstraint.PATIENT),
+            sum_: ({int, float}, DomainConstraint.PATIENT),
+            mean: ({float}, DomainConstraint.PATIENT),
+            is_null: ({bool}, DomainConstraint.ANY),
+            not_: ({bool}, DomainConstraint.ANY),
+            year_from_date: ({int}, DomainConstraint.ANY),
+            month_from_date: ({int}, DomainConstraint.ANY),
+            day_from_date: ({int}, DomainConstraint.ANY),
+            to_first_of_year: ({datetime.date}, DomainConstraint.ANY),
+            to_first_of_month: ({datetime.date}, DomainConstraint.ANY),
+            cast_to_float: ({float}, DomainConstraint.ANY),
+            cast_to_int: ({int}, DomainConstraint.ANY),
+            negate: ({int, float}, DomainConstraint.ANY),
+            eq: ({bool}, DomainConstraint.ANY),
+            ne: ({bool}, DomainConstraint.ANY),
+            string_contains: ({bool}, DomainConstraint.ANY),
+            in_: ({bool}, DomainConstraint.ANY),
+            and_: ({bool}, DomainConstraint.ANY),
+            or_: ({bool}, DomainConstraint.ANY),
+            lt: ({bool}, DomainConstraint.ANY),
+            gt: ({bool}, DomainConstraint.ANY),
+            le: ({bool}, DomainConstraint.ANY),
+            ge: ({bool}, DomainConstraint.ANY),
+            add: ({int, float}, DomainConstraint.ANY),
+            subtract: ({int, float}, DomainConstraint.ANY),
+            multiply: ({int, float}, DomainConstraint.ANY),
+            truediv: ({float}, DomainConstraint.ANY),
+            floordiv: ({int}, DomainConstraint.ANY),
+            date_add_years: ({datetime.date}, DomainConstraint.ANY),
+            date_add_months: ({datetime.date}, DomainConstraint.ANY),
+            date_add_days: ({datetime.date}, DomainConstraint.ANY),
+            date_difference_in_years: ({int}, DomainConstraint.ANY),
+            date_difference_in_months: ({int}, DomainConstraint.ANY),
+            date_difference_in_days: ({int}, DomainConstraint.ANY),
+            count_episodes: ({int}, DomainConstraint.PATIENT),
+            case: ({int, float, bool, datetime.date}, DomainConstraint.ANY),
+            maximum_of: (COMPARABLE_TYPES, DomainConstraint.ANY),
+            minimum_of: (COMPARABLE_TYPES, DomainConstraint.ANY),
+        }
+        series_types = series_constraints.keys()
+
+        def constraints_match(series_type):
+            type_constraint, domain_constraint = series_constraints[series_type]
+            return (
+                type_ in type_constraint
+                and is_one_row_per_patient_frame(frame) in domain_constraint
+            )
+
+        possible_series = [s for s in series_types if constraints_match(s)]
+        assert possible_series, f"No series matches {type_}, {type(frame)}"
+
+        series_strategy = draw(st.sampled_from(possible_series))
+        return draw(series_strategy(type_, frame))
+
+    def value(type_, _frame):
+        return st.builds(Value, value_strategies[type_])
+
+    def select_column(type_, frame):
+        column_names = [n for n, t in schema.column_types if t == type_]
+        return st.builds(SelectColumn, st.just(frame), st.sampled_from(column_names))
+
+    def exists(_type, _frame):
+        return st.builds(AggregateByPatient.Exists, any_frame())
+
+    def count(_type, _frame):
+        return st.builds(AggregateByPatient.Count, any_frame())
+
+    @st.composite
+    def count_distinct(draw, _type, _frame):
+        type_ = draw(any_type())
+        frame = draw(many_rows_per_patient_frame())
+        return AggregateByPatient.CountDistinct(draw(series(type_, frame)))
+
+    @st.composite
+    def count_episodes(draw, _type, _frame):
+        frame = draw(many_rows_per_patient_frame())
+        date_series = draw(series(datetime.date, frame))
+        maximum_gap_days = draw(st.integers(1, 5))
+        return AggregateByPatient.CountEpisodes(date_series, maximum_gap_days)
+
+    def min_(type_, _frame):
+        return aggregation_operation(type_, AggregateByPatient.Min)
+
+    def max_(type_, _frame):
+        return aggregation_operation(type_, AggregateByPatient.Max)
+
+    def sum_(type_, _frame):
+        return aggregation_operation(type_, AggregateByPatient.Sum)
+
+    def combine_as_set(type_, _frame):
+        return aggregation_operation(type_, AggregateByPatient.CombineAsSet)
+
+    @st.composite
+    def mean(draw, _type, _frame):
+        type_ = draw(any_numeric_type())
+        frame = draw(many_rows_per_patient_frame())
+        return AggregateByPatient.Mean(draw(series(type_, frame)))
+
+    @st.composite
+    def aggregation_operation(draw, type_, aggregation):
+        # An aggregation operation that returns a patient series but takes a
+        # series drawn from a many-rows-per-patient frame
+        frame = draw(many_rows_per_patient_frame())
+        return aggregation(draw(series(type_, frame)))
+
+    @st.composite
+    def is_null(draw, _type, frame):
+        type_ = draw(any_type())
+        return Function.IsNull(draw(series(type_, frame)))
+
+    def not_(type_, frame):
+        return st.builds(Function.Not, series(type_, frame))
+
+    def year_from_date(_type, frame):
+        return st.builds(Function.YearFromDate, series(datetime.date, frame))
+
+    def month_from_date(_type, frame):
+        return st.builds(Function.MonthFromDate, series(datetime.date, frame))
+
+    def day_from_date(_type, frame):
+        return st.builds(Function.DayFromDate, series(datetime.date, frame))
+
+    def to_first_of_year(_type, frame):
+        return st.builds(Function.ToFirstOfYear, series(datetime.date, frame))
+
+    def to_first_of_month(_type, frame):
+        return st.builds(Function.ToFirstOfMonth, series(datetime.date, frame))
+
+    @st.composite
+    def cast_to_float(draw, _type, frame):
+        type_ = draw(any_numeric_type())
+        return Function.CastToFloat(draw(series(type_, frame)))
+
+    @st.composite
+    def cast_to_int(draw, type_, frame):
+        type_ = draw(any_numeric_type())
+        return Function.CastToInt(draw(series(type_, frame)))
+
+    def negate(type_, frame):
+        return st.builds(Function.Negate, series(type_, frame))
+
+    @st.composite
+    def eq(draw, _type, frame):
+        type_ = draw(any_type())
+        return draw(binary_operation(type_, frame, Function.EQ))
+
+    @st.composite
+    def ne(draw, _type, frame):
+        type_ = draw(any_type())
+        return draw(binary_operation(type_, frame, Function.NE))
+
+    def string_contains(_type, frame):
+        return binary_operation(str, frame, Function.StringContains)
+
+    @st.composite
+    def in_(draw, _type, frame):
+        type_ = draw(any_type())
+        if not draw(st.booleans()):
+            rhs = Value(
+                frozenset(
+                    draw(st.sets(value_strategies[type_], min_size=0, max_size=5))
+                )
+            )
+        else:
+            rhs = draw(combine_as_set(type_, frame))
+        return Function.In(draw(series(type_, frame)), rhs)
+
+    def and_(type_, frame):
+        return binary_operation(type_, frame, Function.And, allow_value=False)
+
+    def or_(type_, frame):
+        return binary_operation(type_, frame, Function.Or, allow_value=False)
+
+    @st.composite
+    def lt(draw, _type, frame):
+        type_ = draw(any_comparable_type())
+        return draw(binary_operation(type_, frame, Function.LT))
+
+    @st.composite
+    def gt(draw, _type, frame):
+        type_ = draw(any_comparable_type())
+        return draw(binary_operation(type_, frame, Function.GT))
+
+    @st.composite
+    def le(draw, _type, frame):
+        type_ = draw(any_comparable_type())
+        return draw(binary_operation(type_, frame, Function.LE))
+
+    @st.composite
+    def ge(draw, _type, frame):
+        type_ = draw(any_comparable_type())
+        return draw(binary_operation(type_, frame, Function.GE))
+
+    def add(type_, frame):
+        return binary_operation(type_, frame, Function.Add)
+
+    def subtract(type_, frame):
+        return binary_operation(type_, frame, Function.Subtract)
+
+    def multiply(type_, frame):
+        return binary_operation(type_, frame, Function.Multiply)
+
+    def truediv(type_, frame):
+        return binary_operation(type_, frame, Function.TrueDivide)
+
+    def floordiv(type_, frame):
+        return binary_operation(type_, frame, Function.FloorDivide)
+
+    def date_add_years(type_, frame):
+        return binary_operation_with_types(type_, int, frame, Function.DateAddYears)
+
+    def date_add_months(type_, frame):
+        return binary_operation_with_types(type_, int, frame, Function.DateAddMonths)
+
+    def date_add_days(type_, frame):
+        return binary_operation_with_types(type_, int, frame, Function.DateAddDays)
+
+    def date_difference_in_years(type_, frame):
+        return binary_operation(datetime.date, frame, Function.DateDifferenceInYears)
+
+    def date_difference_in_months(type_, frame):
+        return binary_operation(datetime.date, frame, Function.DateDifferenceInMonths)
+
+    def date_difference_in_days(type_, frame):
+        return binary_operation(datetime.date, frame, Function.DateDifferenceInDays)
+
+    @st.composite
+    def case(draw, type_, frame):
+        # case takes a mapping argument which is a dict where:
+        #   - keys are a bool series
+        #   - values are either a series or Value of `type_` or None
+        # It also takes a default, which can be None or a Value or series of `type_`
+        key_st = series(bool, frame)
+        value_st = st.one_of(st.none(), value(type_, frame), series(type_, frame))
+        mapping_st = st.dictionaries(key_st, value_st, min_size=1, max_size=3)
+        default_st = st.one_of(st.none(), value(type_, frame), series(type_, frame))
+        mapping = draw(mapping_st)
+        default = draw(default_st)
+        # A valid Case needs at least one non-NULL value or a default
+        hyp.assume(not all(v is None for v in [default, *mapping.values()]))
+        return Case(mapping, default)
+
+    def binary_operation(type_, frame, operator_func, allow_value=True):
+        # A strategy for operations that take lhs and rhs arguments of the
+        # same type
+        return binary_operation_with_types(
+            type_, type_, frame, operator_func, allow_value=allow_value
+        )
+
+    @st.composite
+    def binary_operation_with_types(
+        draw, lhs_type, rhs_type, frame, operator_func, allow_value=True
+    ):
+        # A strategy for operations that take lhs and rhs arguments with specified lhs
+        # and rhs types (which may be different)
+
+        # A binary operation has 2 inputs, which are
+        # 1) A series drawn from the specified frame
+        # 2) one of:
+        #    a) A series drawn from the specified frame
+        #    b) A series drawn from any one-row-per-patient-frame
+        #    c) A series that is a Value
+        #       For certain operations, Value is not allowed;  Specifically, for boolean operations
+        #       i.e. and/or which take two boolean series as inputs, we exclude operations that would
+        #       use True/False constant Values.  These are unlikely to be seen in the wild, and cause
+        #       particularly nonsensical Case statements in generative test examples.
+
+        # first pick an "other" input series (i.e. #2 above), either a value series (if allowed)
+        # or a series drawn from a frame
+        series_options = [value, series] if allow_value else [series]
+        other_series = draw(st.sampled_from(series_options))
+        # Now pick a frame for the series to be drawn from
+        # The other frame will either be a new one-row-per-patient-frame or this frame
+        # (Note if the other_series is a value, the frame will be ignored)
+        other_frame = draw(st.one_of(one_row_per_patient_frame(), st.just(frame)))
+
+        # Pick the order of the lhs and rhs inputs built from the two frames and
+        # associated strategies
+        lhs_frame, lhs_input, rhs_frame, rhs_input = draw(
+            st.sampled_from(
+                [
+                    (frame, series, other_frame, other_series),
+                    (other_frame, other_series, frame, series),
+                ]
+            )
+        )
+        lhs = draw(lhs_input(lhs_type, lhs_frame))
+        rhs = draw(rhs_input(rhs_type, rhs_frame))
+
+        return operator_func(lhs, rhs)
+
+    @st.composite
+    def nary_operation_with_types(draw, frame, operator_func, series_type):
+        # A strategy for operations that take _n_ arguments which are expected to be
+        # the same type
+
+        # Decide how many arguments we want – we're intending to test the logic of the
+        # query engines, not their scaling properties so we don't need too many
+        num_args = draw(st.integers(1, 4))
+        # Pick out some arguments (identified by index) to be drawn from other frames
+        other_frame_args = draw(
+            st.lists(
+                # Draw a list of argument indices
+                st.integers(0, num_args - 1),
+                # Always leaving at least one argument to be drawn from the original
+                # frame
+                max_size=num_args - 1,
+                unique=True,
+            )
+        )
+        args = []
+        # Clauses below arranged in order of simplicity (as Hypothesis sees it)
+        for i in range(num_args):
+            if i not in other_frame_args:
+                arg = draw(series(series_type, frame))
+            else:
+                # If it's not drawn from the supplied frame then it should be either a
+                # value or a one-row-per-patient series
+                if not draw(st.booleans()):
+                    arg = draw(value(series_type, None))
+                else:
+                    arg = draw(series(series_type, draw(one_row_per_patient_frame())))
+            args.append(arg)
+        return operator_func(tuple(args))
+
+    def maximum_of(type_, frame):
+        return nary_operation_with_types(frame, Function.MaximumOf, type_)
+
+    def minimum_of(type_, frame):
+        return nary_operation_with_types(frame, Function.MinimumOf, type_)
+
+    def any_type():
+        return st.sampled_from(list(value_strategies.keys()))
+
+    def any_numeric_type():
+        return st.sampled_from([int, float])
+
+    def any_comparable_type():
+        return st.sampled_from(COMPARABLE_TYPES)
+
+    # Frame strategies
+    #
+    # The main concern when choosing a frame is whether it has one or many rows per patient. Some
+    # callers require one or the other, some don't mind; so we provide strategies for each case.
+    # And sometimes callers need _either_ the frame they have in their hand _or_ an arbitrary
+    # patient frame, so we provide a strategy for that too.
+    #
+    # At variance with the general approach here, many-rows-per-patient frames are created by
+    # imperatively building stacks of filters on top of select nodes, rather than relying on
+    # recursion, because it enormously simplifies the logic needed to keep filter conditions
+    # consistent with the source.
+    def any_frame():
+        # Order matters: "simpler" first (see header comment)
+        return st.one_of(
+            one_row_per_patient_frame(),
+            many_rows_per_patient_frame(),
+        )
+
+    def one_row_per_patient_frame():
+        return depth_bounded_one_of(
+            select_patient_table(),
+            pick_one_row_per_patient_frame(),
+            inline_patient_table(),
+        )
+
+    def many_rows_per_patient_frame():
+        return depth_bounded_one_of(select_table(), filtered_table())
+
+    @st.composite
+    def filtered_table(draw):
+        source = draw(select_table())
+        for _ in range(draw(st.integers(min_value=1, max_value=6))):
+            source = draw(filter_(source))
+        return source
+
+    @st.composite
+    def sorted_frame(draw):
+        # Decide how many Sorts and Filters (if any) we're going to apply
+        operations = draw(
+            st.lists(st.sampled_from([sort, filter_]), min_size=1, max_size=9).filter(
+                lambda ls: (1 <= ls.count(sort) <= 3) and (ls.count(filter_) <= 6)
+            )
+        )
+        # Pick a table and apply the operations
+        source = draw(select_table())
+        for operation in operations:
+            source = draw(operation(source))
+        return source
+
+    @st.composite
+    def pick_one_row_per_patient_frame(draw):
+        source = draw(sorted_frame())
+        sort_order = draw(st.sampled_from([Position.FIRST, Position.LAST]))
+        return PickOneRowPerPatient(source, sort_order)
+
+    def select_table():
+        return st.builds(SelectTable, st.sampled_from(event_tables), st.just(schema))
+
+    def select_patient_table():
+        return st.builds(
+            SelectPatientTable, st.sampled_from(patient_tables), st.just(schema)
+        )
+
+    @st.composite
+    def inline_patient_table(draw):
+        return InlinePatientTable(
+            rows=tuple(
+                draw(
+                    st.lists(
+                        st.tuples(
+                            st.integers(1, 10),
+                            *[
+                                value_strategies[type_]
+                                for name, type_ in schema.column_types
+                            ],
+                        ),
+                        unique_by=lambda r: r[0],
+                    ),
+                )
+            ),
+            schema=schema,
+        )
+
+    @st.composite
+    def filter_(draw, source):
+        condition = draw(series(bool, draw(ancestor_of(source))))
+        return Filter(source, condition)
+
+    @st.composite
+    def sort(draw, source):
+        type_ = draw(any_comparable_type())
+        sort_by = draw(series(type_, draw(ancestor_of(source))))
+        return Sort(source, sort_by)
+
+    @st.composite
+    def ancestor_of(draw, frame):
+        for _ in range(draw(st.integers(min_value=0, max_value=3))):
+            if hasattr(frame, "source"):
+                frame = frame.source
+            else:
+                break
+        return frame
+
+    # Variable strategy
+    #
+    # Puts everything above together to create a variable.
+    @st.composite
+    def valid_patient_variable(draw):
+        type_ = draw(any_type())
+        frame = draw(one_row_per_patient_frame())
+        return draw(series(type_, frame))
+
+    @st.composite
+    def valid_event_series(draw):
+        type_ = draw(any_type())
+        frame = draw(many_rows_per_patient_frame())
+        return draw(series(type_, frame))
+
+    # A population definition is a boolean-typed variable that meets some additional
+    # criteria enforced by the query model
+    @st.composite
+    def valid_population(draw):
+        frame = draw(one_row_per_patient_frame())
+        population = draw(series(bool, frame))
+        hyp.assume(is_valid_population(population))
+        return population
+
+    return st.builds(
+        make_dataset,
+        valid_population(),
+        valid_patient_variable(),
+        # Event series is optional
+        st.one_of(st.none(), valid_event_series()),
+    )
+
+
+def make_dataset(population, patient_variable, event_series):
+    return Dataset(
+        population=population,
+        variables={"v": patient_variable},
+        events=(
+            {
+                "event_table": SeriesCollectionFrame({"e": event_series}),
+            }
+            if event_series is not None
+            else {}
+        ),
+        measures=None,
+    )
+
+
+def is_valid_population(series):
+    try:
+        validate_population_definition(series)
+        return True
+    except ValidationError:
+        return False
+    except Exception as e:  # pragma: no cover
+        if get_ignored_error_type(e):
+            return False
+        raise
+
+
+def is_one_row_per_patient_frame(frame):
+    return isinstance(frame, SelectPatientTable | PickOneRowPerPatient)