ehrql/query_engines/sqlite.py

import sqlalchemy
from sqlalchemy.sql.functions import Function as SQLFunction

from ehrql.query_engines.base_sql import BaseSQLQueryEngine, get_cyclic_coalescence
from ehrql.query_engines.sqlite_dialect import SQLiteDialect
from ehrql.utils.itertools_utils import iter_flatten
from ehrql.utils.math_utils import get_grouping_level_as_int
from ehrql.utils.sequence_utils import ordered_set


class SQLiteQueryEngine(BaseSQLQueryEngine):
    sqlalchemy_dialect = SQLiteDialect

    def date_difference_in_days(self, end, start):
        start_day = SQLFunction("JULIANDAY", start)
        end_day = SQLFunction("JULIANDAY", end)
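        # JULIANDAY returns a (possibly fractional) Julian day number; for date-only
        # values the difference is a whole number of days, which the cast to Integer
        # below preserves.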
        return sqlalchemy.cast(end_day - start_day, sqlalchemy.Integer)

    def get_date_part(self, date, part):
        format_str = {"YEAR": "%Y", "MONTH": "%m", "DAY": "%d"}[part]
        part_as_str = SQLFunction("STRFTIME", format_str, date)
        return sqlalchemy.cast(part_as_str, sqlalchemy.Integer)

    def date_add_days(self, date, num_days):
        return self.date_add("days", date, num_days)

    def date_add_months(self, date, num_months):
        new_date = self.date_add("months", date, num_months)
        # In cases of day-of-month overflow, SQLite *usually* rolls over to the first of
        # the next month as we want it to.
        # It does this by performing a normalisation when a date arithmetic operation results
        # in a date with an invalid number of days (e.g. 30 Feb 2000); namely it
        # rolls over to the next month by incrementing the month and subtracting
        # the number of days in the month. For all months except February, the invalid day is
        # at most 1 day off (i.e. the 31st, for a month with only 30 days), and so this
        # normalisation results in a rollover to the first of the next month. However for February,
        # the invalid day can be 1 to 3 days off, which means date rollovers can result in the 1st,
        # 2nd or 3rd March.
        #
        # The SQLite docs (https://sqlite.org/lang_datefunc.html) state:
        # "Note that "±NNN months" works by rendering the original date into the YYYY-MM-DD
        # format, adding the ±NNN to the MM month value, then normalizing the result. Thus,
        # for example, the date 2001-03-31 modified by '+1 month' initially yields 2001-04-31,
        # but April only has 30 days so the date is normalized to 2001-05-01"
        #
        # i.e. 2001-03-31 +1 month results in 2001-04-31, but as the number of days is invalid,
        # SQLite rolls over to the next month by incrementing the month by 1, and subtracting
        # the number of days in April (30) from the days, resulting in 2001-05-01
        #
        # In the case of February, a calculation that results in a date with an invalid number of
        # days follows the same normalisation method:
        # 2000-02-30: increment month by 1, subtract 29 (leap year) days -> 2000-03-01
        # 2001-02-30: increment month by 1, subtract 28 days -> 2001-03-02
        # 2000-02-31: increment month by 1, subtract 29 (leap year) days -> 2000-03-02
        # 2001-02-31: increment month by 1, subtract 28 days -> 2001-03-03
        #
        # We detect when this has happened and correct for it here, ensuring that when a date
        # rolls over to the next month, the date returned is always the first of that month.
        # For more detail see:
        # tests/spec/date_series/ops/test_date_series_ops.py::test_add_months
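        # For example (illustrative values, not drawn from the test spec): adding 1 month
        # to 2001-01-31 initially yields 2001-02-31, which SQLite normalises to 2001-03-03.
        # The day of that result (3) is less than the original day (31), so the correction
        # below subtracts new_date_day - 1 = 2 days, giving 2001-03-01.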
        new_date_day = self.get_date_part(new_date, "DAY")
        correction = sqlalchemy.case(
            (
                new_date_day < self.get_date_part(date, "DAY"),
                1 - new_date_day,
            ),
            else_=0,
        )
        return self.date_add_days(new_date, correction)

    def date_add_years(self, date, num_years):
        return self.date_add("years", date, num_years)

    def date_add(self, units, date, value):
        value_str = sqlalchemy.cast(value, sqlalchemy.String)
        modifier = value_str.concat(f" {units}")
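        # Sketch of the intended SQL (exact rendering depends on the dialect):
        # DATE(<date>, CAST(<value> AS VARCHAR) || ' days'), i.e. SQLite's date()
        # function applied with a "<N> days"/"<N> months"/"<N> years" modifier.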
        return SQLFunction("DATE", date, modifier, type_=sqlalchemy.Date)

    def to_first_of_year(self, date):
        return SQLFunction("DATE", date, "start of year", type_=sqlalchemy.Date)

    def to_first_of_month(self, date):
        return SQLFunction("DATE", date, "start of month", type_=sqlalchemy.Date)

    def get_aggregate_subquery(self, aggregate_function, columns, return_type):
        # Horrible edge case: if a horizontal aggregate is called on a single
        # literal, SQLite will only return the first row
        if len(columns) == 1:
            return columns[0]
        # SQLite returns NULL for greatest/least if any of the inputs are NULL.
        # Use cyclic coalescence to remove the NULLs before applying the aggregate function
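        # (Presumably this turns columns [a, b, c] into expressions along the lines of
        # COALESCE(a, b, c), COALESCE(b, c, a), COALESCE(c, a, b), so that each argument
        # to the aggregate is non-NULL unless every input is NULL.)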
        columns = get_cyclic_coalescence(columns)
        return aggregate_function(*columns)

    def get_measure_queries(self, grouped_sum, results_query):
        """
        Return the SQL queries to fetch the results for a GroupedSum representing
        a collection of measures that share a denominator.

        A GroupedSum contains:
        - denominator: a single column to sum over
        - numerators: a tuple of columns to sum over
        - group_bys: a dict mapping each tuple of group-by columns to the numerators
          that the grouping should be applied to

        results_query is the result of calling get_queries on the dataset that
        the measures will aggregate over.

        In order to return a result that is equivalent to using GROUPING SETS, we
        take each collection of group-by columns (which would be a grouping set in
        other SQL engines) and calculate the sums for the denominator and the
        relevant numerator columns for this grouping set (there can be more than
        one, if measures share a grouping set). Then we UNION ALL each individual
        measure grouping.

        For each grouping set, the GROUPING ID is an integer formed from a binary
        string with one digit per group-by column, where a 1 indicates that the
        column is NOT a grouping column for that measure.

        e.g. we have 4 measures, with a total of 3 group-by columns,
        [sex, region, ethnicity], and the following group-by combinations:
        1) grouped by sex
        2) grouped by region and ethnicity
        3) grouped by sex, region and ethnicity

        The grouping id for each of these would be:
        1) 011 --> 3
        2) 100 --> 4
        3) 000 --> 0
        """
        measure_queries = []

        # dict of column name to column select query for each group-by column,
        # maintaining the order of the columns
        all_group_by_cols = {
            col_name: results_query.c[col_name]
            for col_name in ordered_set(iter_flatten(grouped_sum.group_bys))
        }
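        # e.g. (hypothetical values) group-by tuples ("sex",), ("region", "ethnicity") and
        # ("sex", "region", "ethnicity") flatten and de-duplicate to
        # ["sex", "region", "ethnicity"], giving a stable column order for the grouping ids.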

        denominator = sqlalchemy.func.sum(
            results_query.c[grouped_sum.denominator]
        ).label("den")

        for group_bys, numerators in grouped_sum.group_bys.items():
            # We need to return a column for each numerator in order to produce the
            # same output columns as the base SQL engine's grouping sets.
            # We don't actually need to calculate the sums multiple times though.
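            # e.g. with grouped_sum.numerators == ("a", "b") (hypothetical names) and only
            # "a" applying to this grouping set, sum_overs ends up as
            # [SUM(<denominator>) AS den, SUM(a) AS num_a, NULL].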
            sum_overs = [denominator] + [sqlalchemy.null()] * (
                len(grouped_sum.numerators)
            )
            # Now fill in the numerators that apply to this collection of group bys
            for numerator in numerators:
                numerator_index = grouped_sum.numerators.index(numerator)
                sum_overs[numerator_index + 1] = sqlalchemy.func.sum(
                    results_query.c[numerator]
                ).label(f"num_{numerator}")
            group_by_cols = [all_group_by_cols[col_name] for col_name in group_bys]
            group_select_cols = [
                all_group_by_cols[col_name]
                if col_name in group_bys
                else sqlalchemy.null()
                for col_name in all_group_by_cols
            ]
            grouping_id = get_grouping_level_as_int(
                all_group_by_cols.values(), group_by_cols
            )
            measure_queries.append(
                sqlalchemy.select(*sum_overs, *group_select_cols, grouping_id).group_by(
                    *group_by_cols
                )
            )
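        # The final result is a UNION ALL of one grouped SELECT per grouping set, each
        # with the columns (den, num_<numerator>..., <group-by columns or NULL>, grouping id),
        # mirroring what GROUPING SETS would produce on engines that support it.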

        return [sqlalchemy.union_all(*measure_queries)]