Switch to side-by-side view

--- /dev/null
+++ b/ehrql/query_engines/sqlite.py
@@ -0,0 +1,169 @@
+import sqlalchemy
+from sqlalchemy.sql.functions import Function as SQLFunction
+
+from ehrql.query_engines.base_sql import BaseSQLQueryEngine, get_cyclic_coalescence
+from ehrql.query_engines.sqlite_dialect import SQLiteDialect
+from ehrql.utils.itertools_utils import iter_flatten
+from ehrql.utils.math_utils import get_grouping_level_as_int
+from ehrql.utils.sequence_utils import ordered_set
+
+
class SQLiteQueryEngine(BaseSQLQueryEngine):
    """
    SQL query engine targeting SQLite: translates ehrQL date arithmetic,
    horizontal aggregations and measure queries into SQLite-compatible SQL
    (JULIANDAY/STRFTIME/DATE functions, and UNION ALL in place of
    GROUPING SETS).
    """

    # SQLAlchemy dialect used to compile queries for SQLite
    sqlalchemy_dialect = SQLiteDialect
+
+    def date_difference_in_days(self, end, start):
+        start_day = SQLFunction("JULIANDAY", start)
+        end_day = SQLFunction("JULIANDAY", end)
+        return sqlalchemy.cast(end_day - start_day, sqlalchemy.Integer)
+
+    def get_date_part(self, date, part):
+        format_str = {"YEAR": "%Y", "MONTH": "%m", "DAY": "%d"}[part]
+        part_as_str = SQLFunction("STRFTIME", format_str, date)
+        return sqlalchemy.cast(part_as_str, sqlalchemy.Integer)
+
    def date_add_days(self, date, num_days):
        """Return a SQL expression for `date` plus `num_days` days."""
        return self.date_add("days", date, num_days)
+
+    def date_add_months(self, date, num_months):
+        new_date = self.date_add("months", date, num_months)
+        # In cases of day-of-month overflow, SQLite *usually* rolls over to the first of
+        # the next month as want it to.
+        # It does this by performing a normalisation when a date arithmetic operation results
+        # in a date with an invalid number of days (e.g. 30 Feb 2000); namely it
+        # rolls over to the next month by incrementing the month and subtracting
+        # the number of days in the month. For all months except February, the invalid day is only
+        # ever at most 1 day off (i.e. the 31st, for a month with only 30 days), and so this
+        # normalisation results in a rollover to the first of the next month. However for February,
+        # the invalid day can be 1 to 3 days off, which means date rollovers can result in the 1st,
+        # 2nd or 3rd March.
+        #
+        # The SQLite docs (https://sqlite.org/lang_datefunc.html) state:
+        # "Note that "±NNN months" works by rendering the original date into the YYYY-MM-DD
+        # format, adding the ±NNN to the MM month value, then normalizing the result. Thus,
+        # for example, the date 2001-03-31 modified by '+1 month' initially yields 2001-04-31,
+        # but April only has 30 days so the date is normalized to 2001-05-01"
+        #
+        # i.e. 2001-03-31 +1 month results in 2001-04-31, but as the number of days is invalid,
+        # SQLite rolls over to the next month by incrementing the month by 1, and subtracting
+        # the number of days in April (30) from the days, resulting in 2001-05-01
+        #
+        # In the case of February, a calculation that results in a date with an invalid number of
+        # days follows the same normalisation method:
+        # 2000-02-30: increment month by 1, subtract 29 (leap year) days -> 2000-03-01
+        # 2001-02-30: increment month by 1, subtract 28 days -> 2000-03-02
+        # 2000-02-31: increment month by 1, subtract 29 (leap year) days -> 2000-03-02
+        # 2001-02-31: increment month by 1, subtract 28 days -> 2000-03-03
+        #
+        # We detect when it's done that and correct for it here, ensuring that when a date rolls over
+        # to the next month, the date returned is always the first of that month. For more detail see:
+        # tests/spec/date_series/ops/test_date_series_ops.py::test_add_months
+        new_date_day = self.get_date_part(new_date, "DAY")
+        correction = sqlalchemy.case(
+            (
+                self.get_date_part(new_date, "DAY") < self.get_date_part(date, "DAY"),
+                1 - new_date_day,
+            ),
+            else_=0,
+        )
+        return self.date_add_days(new_date, correction)
+
    def date_add_years(self, date, num_years):
        """Return a SQL expression for `date` plus `num_years` years."""
        return self.date_add("years", date, num_years)
+
+    def date_add(self, units, date, value):
+        value_str = sqlalchemy.cast(value, sqlalchemy.String)
+        modifier = value_str.concat(f" {units}")
+        return SQLFunction("DATE", date, modifier, type_=sqlalchemy.Date)
+
    def to_first_of_year(self, date):
        """Return a SQL expression truncating `date` to the 1st of January."""
        return SQLFunction("DATE", date, "start of year", type_=sqlalchemy.Date)
+
    def to_first_of_month(self, date):
        """Return a SQL expression truncating `date` to the 1st of its month."""
        return SQLFunction("DATE", date, "start of month", type_=sqlalchemy.Date)
+
+    def get_aggregate_subquery(self, aggregate_function, columns, return_type):
+        # horrible edge-case where if a horizontal aggregate is called on
+        # a single literal, sqlite will only return the first row
+        if len(columns) == 1:
+            return columns[0]
+        # Sqlite returns null for greatest/least if any of the inputs are null
+        # Use cyclic coalescence to remove the nulls before applying the aggregate function
+        columns = get_cyclic_coalescence(columns)
+        return aggregate_function(*columns)
+
+    def get_measure_queries(self, grouped_sum, results_query):
+        """
+        Return the SQL queries to fetch the results for a GroupedSum representing
+        a collection of measures that share a denominator.
+        A GroupedSum contains:
+        - denominator: a single column to sum over
+        - numerators: a tuple of columns to sum over
+        - group_bys: a dict of tuples of columns to group by, and the numerators that each group by should be applied to
+
+        results_query is the result of calling get_queries on the dataset that
+        the measures will aggregate over.
+
+        In order to return a result that is the equivalent to using
+        GROUPING SETS, we take each collection of group-by columns (which would be a
+        grouping set in other SQL engines) and calculate the sums for the
+        denominator and the relevant numerator columns for this grouping set (there can
+        be more than one, if measures share a grouping set). Then we UNION ALL each individual
+        measure grouping.
+
+        For each grouping set, the value of GROUPING ID is an integer created by converting a
+        binary string of 0s and 1s for each group by column, where a 1 indicates
+        that the column is NOT a grouping column for that measure
+
+        e.g. we have 4 measures, and a total of 3 group by columns, [sex, region, ehnicity]
+        1) grouped by sex
+        2) grouped by region and ethnicity
+        3) grouped by sex, region and ethnicity
+
+        The grouping id for each of these would be:
+        1) 011 --> 3
+        2) 100 --> 4
+        3) 000 --> 0
+        """
+        measure_queries = []
+
+        # dict of column name to column select query for each group by column,
+        # maintaining the order of the columns
+        all_group_by_cols = {
+            col_name: results_query.c[col_name]
+            for col_name in ordered_set(iter_flatten(grouped_sum.group_bys))
+        }
+
+        denominator = sqlalchemy.func.sum(
+            results_query.c[grouped_sum.denominator]
+        ).label("den")
+
+        for group_bys, numerators in grouped_sum.group_bys.items():
+            # We need to return a column for each numerator in
+            # order to produce the same output columns as the base sql's grouping sets
+            # We don't actually need to calculate the sums multiple times though
+            sum_overs = [denominator] + [sqlalchemy.null] * (
+                len(grouped_sum.numerators)
+            )
+            # Now fill in the numerators that apply to this collection of group bys
+            for numerator in numerators:
+                numerator_index = grouped_sum.numerators.index(numerator)
+                sum_overs[numerator_index + 1] = sqlalchemy.func.sum(
+                    results_query.c[numerator]
+                ).label(f"num_{numerator}")
+            group_by_cols = [all_group_by_cols[col_name] for col_name in group_bys]
+            group_select_cols = [
+                all_group_by_cols[col_name]
+                if col_name in group_bys
+                else sqlalchemy.null
+                for col_name in all_group_by_cols
+            ]
+            grouping_id = get_grouping_level_as_int(
+                all_group_by_cols.values(), group_by_cols
+            )
+            measure_queries.append(
+                sqlalchemy.select(*sum_overs, *group_select_cols, grouping_id).group_by(
+                    *group_by_cols
+                )
+            )
+
+        return [sqlalchemy.union_all(*measure_queries)]