[e988c2]: / ehrql / query_engines / sqlite.py

Download this file

170 lines (147 with data), 8.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import sqlalchemy
from sqlalchemy.sql.functions import Function as SQLFunction
from ehrql.query_engines.base_sql import BaseSQLQueryEngine, get_cyclic_coalescence
from ehrql.query_engines.sqlite_dialect import SQLiteDialect
from ehrql.utils.itertools_utils import iter_flatten
from ehrql.utils.math_utils import get_grouping_level_as_int
from ehrql.utils.sequence_utils import ordered_set
class SQLiteQueryEngine(BaseSQLQueryEngine):
    """
    Query engine which generates SQL for SQLite.

    Overrides the date-handling, horizontal-aggregate and measure-query hooks
    of `BaseSQLQueryEngine` with implementations built on SQLite's `JULIANDAY`,
    `STRFTIME` and `DATE` functions.
    """

    sqlalchemy_dialect = SQLiteDialect

    def date_difference_in_days(self, end, start):
        """
        Return an integer expression for the number of days from `start` to `end`.

        Computed as `JULIANDAY(end) - JULIANDAY(start)` cast to an integer.
        """
        start_day = SQLFunction("JULIANDAY", start)
        end_day = SQLFunction("JULIANDAY", end)
        return sqlalchemy.cast(end_day - start_day, sqlalchemy.Integer)

    def get_date_part(self, date, part):
        """
        Return an integer expression for one component of `date`.

        `part` must be one of "YEAR", "MONTH" or "DAY"; any other value raises
        `KeyError`.
        """
        format_str = {"YEAR": "%Y", "MONTH": "%m", "DAY": "%d"}[part]
        # STRFTIME returns text, so cast the result to get a number
        part_as_str = SQLFunction("STRFTIME", format_str, date)
        return sqlalchemy.cast(part_as_str, sqlalchemy.Integer)

    def date_add_days(self, date, num_days):
        """Return a date expression for `date` shifted by `num_days` days."""
        return self.date_add("days", date, num_days)

    def date_add_months(self, date, num_months):
        """
        Return a date expression for `date` shifted by `num_months` months.

        Where the shifted date overflows the end of the target month the result
        is corrected to be the first day of the following month.
        """
        new_date = self.date_add("months", date, num_months)
        # In cases of day-of-month overflow, SQLite *usually* rolls over to the
        # first of the next month as we want it to.
        # It does this by performing a normalisation when a date arithmetic
        # operation results in a date with an invalid number of days (e.g. 30 Feb
        # 2000); namely it rolls over to the next month by incrementing the month
        # and subtracting the number of days in the month. For all months except
        # February, the invalid day is only ever at most 1 day off (i.e. the 31st,
        # for a month with only 30 days), and so this normalisation results in a
        # rollover to the first of the next month. However for February, the
        # invalid day can be 1 to 3 days off, which means date rollovers can result
        # in the 1st, 2nd or 3rd March.
        #
        # The SQLite docs (https://sqlite.org/lang_datefunc.html) state:
        # "Note that "±NNN months" works by rendering the original date into the
        # YYYY-MM-DD format, adding the ±NNN to the MM month value, then normalizing
        # the result. Thus, for example, the date 2001-03-31 modified by '+1 month'
        # initially yields 2001-04-31, but April only has 30 days so the date is
        # normalized to 2001-05-01"
        #
        # i.e. 2001-03-31 +1 month results in 2001-04-31, but as the number of days
        # is invalid, SQLite rolls over to the next month by incrementing the month
        # by 1, and subtracting the number of days in April (30) from the days,
        # resulting in 2001-05-01
        #
        # In the case of February, a calculation that results in a date with an
        # invalid number of days follows the same normalisation method:
        # 2000-02-30: increment month by 1, subtract 29 (leap year) days -> 2000-03-01
        # 2001-02-30: increment month by 1, subtract 28 days -> 2001-03-02
        # 2000-02-31: increment month by 1, subtract 29 (leap year) days -> 2000-03-02
        # 2001-02-31: increment month by 1, subtract 28 days -> 2001-03-03
        #
        # We detect when it's done that and correct for it here, ensuring that when
        # a date rolls over to the next month, the date returned is always the first
        # of that month. For more detail see:
        # tests/spec/date_series/ops/test_date_series_ops.py::test_add_months
        new_date_day = self.get_date_part(new_date, "DAY")
        correction = sqlalchemy.case(
            (
                # A day-of-month smaller than the one we started with means that
                # SQLite has rolled over into the following month
                new_date_day < self.get_date_part(date, "DAY"),
                # Number of days (zero or negative) needed to get back to the 1st
                1 - new_date_day,
            ),
            else_=0,
        )
        return self.date_add_days(new_date, correction)

    def date_add_years(self, date, num_years):
        """Return a date expression for `date` shifted by `num_years` years."""
        return self.date_add("years", date, num_years)

    def date_add(self, units, date, value):
        """
        Return a date expression for `date` shifted by `value` of `units`.

        Builds a SQLite date modifier by casting `value` to a string and
        appending a space followed by `units` (e.g. "days", "months", "years"),
        then applies it with `DATE(date, modifier)`.
        """
        value_str = sqlalchemy.cast(value, sqlalchemy.String)
        modifier = value_str.concat(f" {units}")
        return SQLFunction("DATE", date, modifier, type_=sqlalchemy.Date)

    def to_first_of_year(self, date):
        """Return a date expression for the first day of the year of `date`."""
        return SQLFunction("DATE", date, "start of year", type_=sqlalchemy.Date)

    def to_first_of_month(self, date):
        """Return a date expression for the first day of the month of `date`."""
        return SQLFunction("DATE", date, "start of month", type_=sqlalchemy.Date)

    def get_aggregate_subquery(self, aggregate_function, columns, return_type):
        """
        Return the expression applying `aggregate_function` across `columns`.

        A single column is returned unchanged. `return_type` is accepted as part
        of the method's signature but is not used by this implementation.
        """
        # horrible edge-case where if a horizontal aggregate is called on
        # a single literal, sqlite will only return the first row
        if len(columns) == 1:
            return columns[0]
        # Sqlite returns null for greatest/least if any of the inputs are null
        # Use cyclic coalescence to remove the nulls before applying the aggregate
        # function
        columns = get_cyclic_coalescence(columns)
        return aggregate_function(*columns)

    def get_measure_queries(self, grouped_sum, results_query):
        """
        Return the SQL queries to fetch the results for a GroupedSum representing
        a collection of measures that share a denominator.

        A GroupedSum contains:
        - denominator: a single column to sum over
        - numerators: a tuple of columns to sum over
        - group_bys: a dict of tuples of columns to group by, and the numerators
          that each group by should be applied to

        results_query is the result of calling get_queries on the dataset that
        the measures will aggregate over.

        In order to return a result that is the equivalent to using
        GROUPING SETS, we take each collection of group-by columns (which would be
        a grouping set in other SQL engines) and calculate the sums for the
        denominator and the relevant numerator columns for this grouping set
        (there can be more than one, if measures share a grouping set). Then we
        UNION ALL each individual measure grouping.

        For each grouping set, the value of GROUPING ID is an integer created by
        converting a binary string of 0s and 1s for each group by column, where a
        1 indicates that the column is NOT a grouping column for that measure

        e.g. we have 3 grouping sets over a total of 3 group by columns,
        [sex, region, ethnicity]
        1) grouped by sex
        2) grouped by region and ethnicity
        3) grouped by sex, region and ethnicity

        The grouping id for each of these would be:
        1) 011 --> 3
        2) 100 --> 4
        3) 000 --> 0

        Returns a list containing a single UNION ALL query.
        """
        measure_queries = []

        # dict of column name to column select query for each group by column,
        # maintaining the order of the columns
        all_group_by_cols = {
            col_name: results_query.c[col_name]
            for col_name in ordered_set(iter_flatten(grouped_sum.group_bys))
        }

        denominator = sqlalchemy.func.sum(
            results_query.c[grouped_sum.denominator]
        ).label("den")

        for group_bys, numerators in grouped_sum.group_bys.items():
            # We need to return a column for each numerator in
            # order to produce the same output columns as the base sql's grouping
            # sets
            # We don't actually need to calculate the sums multiple times though
            sum_overs = [denominator] + [sqlalchemy.null] * (
                len(grouped_sum.numerators)
            )
            # Now fill in the numerators that apply to this collection of group bys
            for numerator in numerators:
                numerator_index = grouped_sum.numerators.index(numerator)
                # Offset by one because the denominator occupies the first slot
                sum_overs[numerator_index + 1] = sqlalchemy.func.sum(
                    results_query.c[numerator]
                ).label(f"num_{numerator}")

            group_by_cols = [all_group_by_cols[col_name] for col_name in group_bys]
            # Every query selects a slot for every group by column so that the
            # queries line up for the UNION ALL; columns outside this grouping set
            # are filled with null
            group_select_cols = [
                all_group_by_cols[col_name]
                if col_name in group_bys
                else sqlalchemy.null
                for col_name in all_group_by_cols
            ]
            grouping_id = get_grouping_level_as_int(
                all_group_by_cols.values(), group_by_cols
            )
            measure_queries.append(
                sqlalchemy.select(*sum_overs, *group_select_cols, grouping_id).group_by(
                    *group_by_cols
                )
            )

        return [sqlalchemy.union_all(*measure_queries)]