Diff of /ehrql/query_language.py [000000] .. [e988c2]

Switch to unified view

a b/ehrql/query_language.py
1
import dataclasses
2
import datetime
3
import functools
4
import operator
5
import re
6
from collections import ChainMap
7
from collections.abc import Callable
8
from pathlib import Path
9
from typing import Any, Generic, TypeVar, overload
10
11
from ehrql.codes import BaseCode, BaseMultiCodeString
12
from ehrql.file_formats import read_rows
13
from ehrql.query_model import nodes as qm
14
from ehrql.query_model.column_specs import get_column_specs_from_schema
15
from ehrql.query_model.nodes import get_series_type, has_one_row_per_patient
16
from ehrql.query_model.population_validation import validate_population_definition
17
from ehrql.utils import date_utils
18
from ehrql.utils.string_utils import strip_indent
19
20
21
# Type variables used in the type hints throughout this module
T = TypeVar("T")
CodeT = TypeVar("CodeT", bound=BaseCode)
MultiCodeStringT = TypeVar("MultiCodeStringT", bound=BaseMultiCodeString)
# These bounds are given as strings because the classes they name have not been
# defined yet at this point in the module
FloatT = TypeVar("FloatT", bound="FloatFunctions")
DateT = TypeVar("DateT", bound="DateFunctions")
IntT = TypeVar("IntT", bound="IntFunctions")
StrT = TypeVar("StrT", bound="StrFunctions")
28
29
# Names must start with a letter and otherwise contain only letters, digits and
# underscores. The pattern is anchored with `\Z` rather than `$` because `$` also
# matches just before a trailing newline, which would let a name like "age\n" through.
VALID_ATTRIBUTE_NAME_RE = re.compile(r"^[A-Za-z]+[A-Za-z0-9_]*\Z")


# This gets populated by the `__init_subclass__` methods of EventSeries and
# PatientSeries. Its structure is:
#
#   (<type>, <is_patient_level>): <SeriesClass>
#
# For example:
#
#   (bool, False): BoolEventSeries,
#   (bool, True): BoolPatientSeries,
#
REGISTERED_TYPES = {}
42
43
44
class Error(Exception):
    """
    Used to translate errors from the query model into something more
    ehrQL-appropriate
    """

    # Pretend this exception is defined in the top-level `ehrql` module: this allows us
    # to avoid leaking the internal `query_language` module into the error messages
    # without creating circular import problems.
    __module__ = "ehrql"
54
55
56
@dataclasses.dataclass
class DummyDataConfig:
    # Holds the dummy data settings for a dataset; see
    # `Dataset.configure_dummy_data()` for the user-facing description of each field
    population_size: int = 10
    legacy: bool = False
    timeout: int = 60
    # Stored as a query model node (not an ehrQL series), or None if unset
    additional_population_constraint: "qm.Series[bool] | None" = None

    def set_additional_population_constraint(self, additional_population_constraint):
        """
        Validate and store an additional population constraint.

        `additional_population_constraint` is either None (leave the currently stored
        constraint as it is) or a boolean patient series, whose query model node is
        stored.

        Raises ValueError if a constraint would be in effect while `legacy` is set.
        """
        if additional_population_constraint is not None:
            validate_patient_series_type(
                additional_population_constraint,
                types=[bool],
                context="additional population constraint",
            )
            constraint_node = additional_population_constraint._qm_node
        else:
            constraint_node = self.additional_population_constraint
        # Check compatibility _before_ assigning so that a rejected constraint is
        # never left stored on the config
        if self.legacy and constraint_node is not None:
            raise ValueError(
                "Cannot provide an additional population constraint in legacy mode."
            )
        self.additional_population_constraint = constraint_node
77
78
79
class Dataset:
    """
    To create a dataset use the [`create_dataset`](#create_dataset) function.
    """

    def __init__(self):
        # Set attributes with `object.__setattr__` to avoid using the
        # `__setattr__` method on this class, which prohibits use of these
        # attribute names
        object.__setattr__(self, "_variables", {})
        object.__setattr__(self, "dummy_data_config", DummyDataConfig())
        object.__setattr__(self, "_events", {})

    def define_population(self, population_condition):
        """
        Define the condition that patients must meet to be included in the Dataset, in
        the form of a [boolean patient series](#BoolPatientSeries).

        Example usage:
        ```python
        dataset.define_population(patients.date_of_birth < "1990-01-01")
        ```

        For more detail see the how-to guide on [defining
        populations](../how-to/define-population.md).
        """
        if hasattr(self, "population"):
            raise AttributeError(
                "define_population() should be called no more than once"
            )
        validate_patient_series_type(
            population_condition,
            types=[bool],
            context="population definition",
        )
        try:
            validate_population_definition(population_condition._qm_node)
        except qm.ValidationError as exc:
            raise Error(str(exc)) from None
        # Bypass our own `__setattr__`, which refuses the name "population"
        object.__setattr__(self, "population", population_condition)

    def add_column(self, column_name: str, ehrql_query):
        """
        Add a column to the dataset.

        _column_name_<br>
        The name of the new column, as a string.

        _ehrql_query_<br>
        An ehrQL query that returns one row per patient.

        Example usage:
        ```python
        dataset.add_column("age", patients.age_on("2020-01-01"))
        ```

        Using `.add_column` is equivalent to `=` for adding a single column
        but can also be used to add multiple columns, for example by iterating
        over a dictionary. For more details see the guide on
        "[How to assign multiple columns to a dataset programmatically](../how-to/assign-multiple-columns.md)".
        """
        setattr(self, column_name, ehrql_query)

    def configure_dummy_data(
        self,
        *,
        population_size=DummyDataConfig.population_size,
        legacy=DummyDataConfig.legacy,
        timeout=DummyDataConfig.timeout,
        additional_population_constraint=None,
    ):
        """
        Configure the dummy data to be generated.

        _population_size_<br>
        Maximum number of patients to generate.

        Note that you may get fewer patients than this if the generator runs out of time
        – see `timeout` below.

        _legacy_<br>
        Use legacy dummy data.

        _timeout_<br>
        Maximum time in seconds to spend generating dummy data.

        _additional_population_constraint_<br>
        An additional ehrQL query that can be used to constrain the population that will
        be selected for dummy data. This is incompatible with legacy mode.

        For example, if you wanted to ensure that two dates appear in a particular order in your
        dummy data, you could add ``additional_population_constraint = dataset.first_date <
        dataset.second_date``.

        You can also combine constraints with ``&`` as normal in ehrQL.
        E.g. ``additional_population_constraint = patients.sex.is_in(['male', 'female']) & (
        patients.age_on(some_date) < 80)`` would give you dummy data consisting of only men
        and women who were under the age of 80 on some particular date.

        Example usage:
        ```python
        dataset.configure_dummy_data(population_size=10000)
        ```
        """
        self.dummy_data_config.population_size = population_size
        self.dummy_data_config.legacy = legacy
        self.dummy_data_config.timeout = timeout
        self.dummy_data_config.set_additional_population_constraint(
            additional_population_constraint
        )

    def __setattr__(self, name, value):
        # Assigning an attribute defines a new dataset variable: the name must be
        # valid and unused, and the value must be a patient series
        if name == "population":
            raise AttributeError(
                "Cannot set variable 'population'; use define_population() instead"
            )
        _validate_attribute_name(
            name, self._variables | self._events, context="variable"
        )
        validate_patient_series(value, context=f"variable '{name}'")
        self._variables[name] = value

    def __getattr__(self, name):
        # This method is only called when normal attribute lookup fails. If our own
        # internal dicts are missing (i.e. the instance was created without running
        # `__init__`, as happens when copying or unpickling) we must bail out here:
        # touching `self._variables` below would re-enter this method and recurse
        # without end.
        if name in ("_variables", "_events"):
            raise AttributeError(name)
        # Make this method accessible while hiding it from autocomplete until we make it
        # generally available
        if name == "add_event_table":
            return self._internal
        if name in self._variables:
            return self._variables[name]
        if name in self._events:
            return self._events[name]
        if name == "population":
            raise AttributeError(
                "A population has not been defined; define one with define_population()"
            )
        else:
            raise AttributeError(f"Variable '{name}' has not been defined")

    # This method ought to be called `add_event_table` but we're deliberately
    # obfuscating its name for now
    def _internal(self, name, **event_series):
        _validate_attribute_name(name, self._variables | self._events, context="table")
        self._events[name] = EventTable(self, **event_series)

    def _compile(self):
        # Build the query model representation of this dataset from the wrapped
        # query model nodes of the population, variables and event tables
        return qm.Dataset(
            population=self.population._qm_node,
            variables={k: v._qm_node for k, v in self._variables.items()},
            events={k: v._qm_node for k, v in self._events.items()},
            measures=None,
        )
230
231
232
class EventTable:
    # A named collection of event-level series attached to a dataset. Columns are
    # added via keyword arguments, `add_column()` or attribute assignment, and read
    # back as attributes.

    def __init__(self, dataset, **series):
        # Store reference to the parent dataset to aid debug rendering
        object.__setattr__(self, "_dataset", dataset)
        object.__setattr__(self, "_series", {})
        if not series:
            raise ValueError("event tables must be defined with at least one column")
        for name, value in series.items():
            self.add_column(name, value)

    def add_column(self, name, value):
        """
        Add the series `value` to this table as a column called `name`.

        Raises AttributeError if the name is invalid or already used, TypeError if
        the columns have only one value per patient, and Error if the columns are
        drawn from different tables.
        """
        _validate_attribute_name(name, self._series, context="column")
        validate_ehrql_series(value, context=f"column {name!r}")
        # Build the frame from the existing columns plus the new one so that the
        # query model can check they are all compatible; we only record the new
        # column below once that has succeeded
        candidate_columns = self._series | {name: value}
        try:
            qm_node = qm.SeriesCollectionFrame(
                {
                    column_name: column._qm_node
                    for column_name, column in candidate_columns.items()
                }
            )
        except qm.PatientDomainError:
            raise TypeError(
                "event tables must have columns with more than one value per patient; "
                "for single values per patient use dataset variables"
            ) from None
        except qm.DomainMismatchError:
            raise Error(
                "cannot combine series drawn from different tables; "
                "create a new event table for these series"
            ) from None
        self._series[name] = value
        object.__setattr__(self, "_qm_node", qm_node)

    def __setattr__(self, name, value):
        self.add_column(name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails. Column names can never start with an
        # underscore, so a missing private name is never a column. Short-circuiting
        # here also stops `self._series` below from re-entering this method forever
        # on an instance whose `__init__` has not run.
        if name.startswith("_"):
            raise AttributeError(name)
        try:
            return self._series[name]
        except KeyError:
            # Attribute access must fail with AttributeError (not KeyError) so that
            # `hasattr()` and `getattr()` with a default behave correctly
            raise AttributeError(name) from None
270
271
272
def _validate_attribute_name(name, defined_names, context):
273
    if name in defined_names:
274
        raise AttributeError(f"'{name}' is already set and cannot be reassigned")
275
    if name in ("patient_id", "population", "dummy_data_config"):
276
        raise AttributeError(f"'{name}' is not an allowed {context} name")
277
    if not VALID_ATTRIBUTE_NAME_RE.match(name):
278
        raise AttributeError(
279
            f"{context} names must start with a letter, and contain only "
280
            f"alphanumeric characters and underscores (you defined a "
281
            f"{context} '{name}')"
282
        )
283
284
285
def create_dataset():
    """
    A dataset defines the patients you want to include in your population and the
    variables you want to extract for them.

    A dataset definition file must define a dataset called `dataset`:

    ```python
    dataset = create_dataset()
    ```

    Add variables to the dataset as attributes:

    ```python
    dataset.age = patients.age_on("2020-01-01")
    ```
    """
    return Dataset()
303
304
305
# BASIC SERIES TYPES
306
#
307
308
309
@dataclasses.dataclass(frozen=True)
class BaseSeries:
    # The query model node which this ehrQL series wraps
    _qm_node: qm.Node

    def __hash__(self):
        # The issue here is not mutability but the fact that we overload `__eq__` for
        # syntactic sugar, which makes these types spectacularly ill-behaved as dict keys
        raise TypeError(f"unhashable type: {self.__class__.__name__!r}")

    def __bool__(self):
        # Python calls this for `and`, `or`, `not` and chained comparisons, none of
        # which can be overloaded to build a query, so we fail loudly instead
        raise TypeError(
            "The keywords 'and', 'or', and 'not' cannot be used with ehrQL, please "
            "use the operators '&', '|' and '~' instead.\n"
            "(You will also see this error if you try use a chained comparison, "
            "such as 'a < b < c'.)"
        )

    @staticmethod
    def _cast(value):
        # Series have the opportunity to cast arguments to their methods e.g. to convert
        # ISO date strings to date objects. By default, this is a no-op.
        return value

    # These are the basic operations that apply to any series regardless of type or
    # dimension
    @overload
    def __eq__(self: "PatientSeries", other) -> "BoolPatientSeries": ...
    @overload
    def __eq__(self: "EventSeries", other) -> "BoolEventSeries": ...

    def __eq__(self, other):
        """
        Return a boolean series comparing each value in this series with its
        corresponding value in `other`.

        Note that the result of comparing anything with NULL (including NULL itself) is NULL.

        Example usage:
        ```python
        patients.sex == "female"
        ```
        """
        other = self._cast(other)
        return _apply(qm.Function.EQ, self, other)

    @overload
    def __ne__(self: "PatientSeries", other) -> "BoolPatientSeries": ...
    @overload
    def __ne__(self: "EventSeries", other) -> "BoolEventSeries": ...
    def __ne__(self, other):
        """
        Return the inverse of `==` above.

        Note that the same point regarding NULL applies here.

        Example usage:
        ```python
        patients.sex != "unknown"
        ```
        """
        other = self._cast(other)
        return _apply(qm.Function.NE, self, other)

    @overload
    def is_null(self: "PatientSeries") -> "BoolPatientSeries": ...
    @overload
    def is_null(self: "EventSeries") -> "BoolEventSeries": ...
    def is_null(self):
        """
        Return a boolean series which is True for each NULL value in this
        series and False for each non-NULL value.

        Example usage:
        ```python
        patients.date_of_death.is_null()
        ```
        """
        return _apply(qm.Function.IsNull, self)

    @overload
    def is_not_null(self: "PatientSeries") -> "BoolPatientSeries": ...
    @overload
    def is_not_null(self: "EventSeries") -> "BoolEventSeries": ...
    def is_not_null(self):
        """
        Return the inverse of `is_null()` above.

        Example usage:
        ```python
        patients.date_of_death.is_not_null()
        ```
        """
        return self.is_null().__invert__()

    def when_null_then(self: T, other: T) -> T:
        """
        Replace any NULL value in this series with the corresponding value in `other`.

        Note that `other` must be of the same type as this series.

        Example usage:
        ```python
        (patients.date_of_death < "2020-01-01").when_null_then(False)
        ```
        """
        return case(
            when(self.is_not_null()).then(self),
            otherwise=self._cast(other),
        )

    @overload
    def is_in(self: "PatientSeries", other) -> "BoolPatientSeries": ...
    @overload
    def is_in(self: "EventSeries", other) -> "BoolEventSeries": ...
    def is_in(self, other):
        """
        Return a boolean series which is True for each value in this series which is
        contained in `other`.

        See how to combine `is_in` with a codelist in
        [the how-to guide](../how-to/examples.md/#does-each-patient-have-a-clinical-event-matching-a-code-in-a-codelist).

        Example usage:
        ```python
        medications.dmd_code.is_in(["39113311000001107", "39113611000001102"])
        ```

        `other` accepts any of the standard "container" types (tuple, list, set, frozenset,
        or dict) or another event series.
        """
        if isinstance(other, tuple | list | set | frozenset | dict):
            # For iterable arguments, apply any necessary casting and convert to the
            # immutable Set type required by the query model. We don't accept arbitrary
            # iterables here because too many types in Python are iterable and there's
            # the potential for confusion amongst the less experienced of our users.
            other = frozenset(map(self._cast, other))
            return _apply(qm.Function.In, self, other)
        elif isinstance(other, EventSeries):
            # We have to use `_convert` and `_wrap` by hand here (rather than using
            # `_apply` which does this all for us) because we're constructing a
            # `CombineAsSet` query model object which doesn't have a representation in
            # the query language.
            other_as_set = qm.AggregateByPatient.CombineAsSet(_convert(other))
            return _wrap(qm.Function.In, _convert(self), other_as_set)
        elif isinstance(other, PatientSeries):
            raise TypeError(
                "Argument must be an EventSeries (i.e. have many values per patient); "
                "you supplied a PatientSeries with only one value per patient"
            )
        else:
            # If the argument is not a supported ehrQL type then we'll get an error here
            # (including hopefully helpful errors for common mistakes)
            _convert(other)
            # Otherwise it _is_ a supported type, but probably not of the right
            # cardinality
            raise TypeError(
                f"Invalid type: {type(other).__qualname__}\n"
                f"Note `is_in()` usually expects a list of values rather than a single value"
            )

    @overload
    def is_not_in(self: "PatientSeries", other) -> "BoolPatientSeries": ...
    @overload
    def is_not_in(self: "EventSeries", other) -> "BoolEventSeries": ...
    def is_not_in(self, other):
        """
        Return the inverse of `is_in()` above.
        """
        return self.is_in(other).__invert__()

    def map_values(self, mapping, default=None):
        """
        Return a new series with _mapping_ applied to each value. _mapping_ should
        be a dictionary mapping one set of values to another.

        Example usage:
        ```python
        school_year = patients.age_on("2020-09-01").map_values(
            {13: "Year 9", 14: "Year 10", 15: "Year 11"},
            default="N/A"
        )
        ```
        """
        # One `when(...).then(...)` clause per mapping entry, in the mapping's order
        return case(
            *[
                when(self == from_value).then(to_value)
                for from_value, to_value in mapping.items()
            ],
            otherwise=default,
        )
499
500
501
class PatientSeries(BaseSeries):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register the series using its `_type` attribute
        REGISTERED_TYPES[cls._type, True] = cls
506
507
508
class EventSeries(BaseSeries):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register the series using its `_type` attribute
        REGISTERED_TYPES[cls._type, False] = cls

    def count_distinct_for_patient(self) -> "IntPatientSeries":
        """
        Return an [integer patient series](#IntPatientSeries) counting the number of
        distinct values for each patient in the series (ignoring any NULL values).

        Note that if a patient has no values at all in the series the result will
        be zero rather than NULL.

        Example usage:
        ```python
        medications.dmd_code.count_distinct_for_patient()
        ```
        """
        return _apply(qm.AggregateByPatient.CountDistinct, self)
528
529
530
# BOOLEAN SERIES
531
#
532
533
534
class BoolFunctions:
    # Operations available on boolean series; mixed in to both the patient-level and
    # event-level boolean series classes

    def __and__(self: T, other: T) -> T:
        """
        Logical AND

        Return a boolean series which is True where both this series and `other` are
        True, False where either are False, and NULL otherwise.

        Example usage:
        ```python
        is_female_and_alive = patients.is_alive_on("2020-01-01") & patients.sex.is_in(["female"])
        ```
        """
        other = self._cast(other)
        return _apply(qm.Function.And, self, other)

    def __or__(self: T, other: T) -> T:
        """
        Logical OR

        Return a boolean series which is True where either this series or `other` is
        True, False where both are False, and NULL otherwise.

        Example usage:
        ```python
        is_alive = patients.date_of_death.is_null() | patients.date_of_death.is_after("2020-01-01")
        ```
        Note that the above example is equivalent to `patients.is_alive_on("2020-01-01")`.
        """
        other = self._cast(other)
        return _apply(qm.Function.Or, self, other)

    def __invert__(self: T) -> T:
        """
        Logical NOT

        Return a boolean series which is the inverse of this series i.e. where True
        becomes False, False becomes True, and NULL stays as NULL.

        Example usage:
        ```python
        is_born_outside_period = ~ patients.date_of_birth.is_on_or_between("2020-03-01", "2020-06-30")
        ```
        """
        return _apply(qm.Function.Not, self)

    @overload
    def as_int(self: "PatientSeries") -> "IntPatientSeries": ...
    @overload
    def as_int(self: "EventSeries") -> "IntEventSeries": ...
    def as_int(self):
        """
        Return each value in this Boolean series as 1 (True) or 0 (False).
        """
        return _apply(qm.Function.CastToInt, self)
589
590
591
class BoolPatientSeries(BoolFunctions, PatientSeries):
    # Registered under (bool, True) in REGISTERED_TYPES via `__init_subclass__`
    _type = bool
593
594
595
class BoolEventSeries(BoolFunctions, EventSeries):
    # Registered under (bool, False) in REGISTERED_TYPES via `__init_subclass__`
    _type = bool
597
598
599
# METHODS COMMON TO ALL COMPARABLE TYPES
600
#
601
602
603
class ComparableFunctions:
    # Ordering comparisons; each casts `other` with this series' `_cast` and then
    # applies the corresponding query model function

    @overload
    def __lt__(self: "PatientSeries", other) -> "BoolPatientSeries": ...
    @overload
    def __lt__(self: "EventSeries", other) -> "BoolEventSeries": ...
    def __lt__(self, other):
        """
        Return a boolean series which is True for each value in this series that is
        strictly less than its corresponding value in `other` and False otherwise (or NULL
        if either value is NULL).

        Example usage:
        ```python
        is_underage = patients.age_on("2020-01-01") < 18
        ```
        """
        other = self._cast(other)
        return _apply(qm.Function.LT, self, other)

    @overload
    def __le__(self: "PatientSeries", other) -> "BoolPatientSeries": ...
    @overload
    def __le__(self: "EventSeries", other) -> "BoolEventSeries": ...
    def __le__(self, other):
        """
        Return a boolean series which is True for each value in this series that is less
        than or equal to its corresponding value in `other` and False otherwise (or NULL
        if either value is NULL).

        Example usage:
        ```python
        is_underage = patients.age_on("2020-01-01") <= 17
        ```
        """
        other = self._cast(other)
        return _apply(qm.Function.LE, self, other)

    @overload
    def __ge__(self: "PatientSeries", other) -> "BoolPatientSeries": ...
    @overload
    def __ge__(self: "EventSeries", other) -> "BoolEventSeries": ...
    def __ge__(self, other):
        """
        Return a boolean series which is True for each value in this series that is
        greater than or equal to its corresponding value in `other` and False otherwise
        (or NULL if either value is NULL).

        Example usage:
        ```python
        is_adult = patients.age_on("2020-01-01") >= 18
        ```
        """
        other = self._cast(other)
        return _apply(qm.Function.GE, self, other)

    @overload
    def __gt__(self: "PatientSeries", other) -> "BoolPatientSeries": ...
    @overload
    def __gt__(self: "EventSeries", other) -> "BoolEventSeries": ...
    def __gt__(self, other):
        """
        Return a boolean series which is True for each value in this series that is
        strictly greater than its corresponding value in `other` and False otherwise (or
        NULL if either value is NULL).

        Example usage:
        ```python
        is_adult = patients.age_on("2020-01-01") > 17
        ```
        """
        other = self._cast(other)
        return _apply(qm.Function.GT, self, other)
675
676
677
class ComparableAggregations:
    # Per-patient aggregations for series whose values can be ordered

    @overload
    def minimum_for_patient(self: DateT) -> "DatePatientSeries": ...
    @overload
    def minimum_for_patient(self: StrT) -> "StrPatientSeries": ...
    @overload
    def minimum_for_patient(self: IntT) -> "IntPatientSeries": ...
    @overload
    def minimum_for_patient(self: FloatT) -> "FloatPatientSeries": ...
    def minimum_for_patient(self):
        """
        Return the minimum value in the series for each patient (or NULL if the patient
        has no values).

        Example usage:
        ```python
        clinical_events.where(...).numeric_value.minimum_for_patient()
        ```
        """
        return _apply(qm.AggregateByPatient.Min, self)

    @overload
    def maximum_for_patient(self: DateT) -> "DatePatientSeries": ...
    @overload
    def maximum_for_patient(self: StrT) -> "StrPatientSeries": ...
    @overload
    def maximum_for_patient(self: IntT) -> "IntPatientSeries": ...
    @overload
    def maximum_for_patient(self: FloatT) -> "FloatPatientSeries": ...
    def maximum_for_patient(self):
        """
        Return the maximum value in the series for each patient (or NULL if the patient
        has no values).

        Example usage:
        ```python
        clinical_events.where(...).numeric_value.maximum_for_patient()
        ```
        """
        return _apply(qm.AggregateByPatient.Max, self)
717
718
719
# STRING SERIES
720
#
721
722
723
class StrFunctions(ComparableFunctions):
    @overload
    def contains(self: "PatientSeries", other) -> "BoolPatientSeries": ...
    @overload
    def contains(self: "EventSeries", other) -> "BoolEventSeries": ...
    def contains(self, other):
        """
        Return a boolean series which is True for each string in this series which
        contains `other` as a sub-string and False otherwise. For NULL values, the
        result is NULL.

        Example usage:
        ```python
        is_female = patients.sex.contains("fem")
        ```

        `other` can be another string series, in which case corresponding values
        are compared. If either value is NULL the result is NULL.
        """
        other = self._cast(other)
        return _apply(qm.Function.StringContains, self, other)
744
745
746
class StrAggregations(ComparableAggregations):
    "Empty for now"
748
749
750
class StrPatientSeries(StrFunctions, PatientSeries):
    # Registered under (str, True) in REGISTERED_TYPES via `__init_subclass__`
    _type = str
752
753
754
class StrEventSeries(StrFunctions, StrAggregations, EventSeries):
    # Registered under (str, False) in REGISTERED_TYPES via `__init_subclass__`
    _type = str
756
757
758
# NUMERIC SERIES
759
#
760
761
762
class NumericFunctions(ComparableFunctions):
    # Arithmetic and numeric casts. The reflected (`__r*__`) methods are what Python
    # calls when the series is the right-hand operand, e.g. `2 / series`.

    @overload
    def __add__(self: IntT, other: IntT | int) -> IntT: ...
    @overload
    def __add__(self: FloatT, other: FloatT | float) -> FloatT: ...
    def __add__(self, other):
        """
        Return the sum of each corresponding value in this series and `other` (or NULL
        if either is NULL).
        """
        other = self._cast(other)
        return _apply(qm.Function.Add, self, other)

    @overload
    def __radd__(self: IntT, other: IntT | int) -> IntT: ...
    @overload
    def __radd__(self: FloatT, other: FloatT | float) -> FloatT: ...
    def __radd__(self, other):
        # Addition is commutative so we can just swap the operands
        return self + other

    @overload
    def __sub__(self: IntT, other: IntT | int) -> IntT: ...
    @overload
    def __sub__(self: FloatT, other: FloatT | float) -> FloatT: ...
    def __sub__(self, other):
        """
        Return each value in this series with its corresponding value in `other`
        subtracted (or NULL if either is NULL).
        """
        other = self._cast(other)
        return _apply(qm.Function.Subtract, self, other)

    @overload
    def __rsub__(self: IntT, other: IntT | int) -> IntT: ...
    @overload
    def __rsub__(self: FloatT, other: FloatT | float) -> FloatT: ...
    def __rsub__(self, other):
        # `other - self` expressed as `other + (-self)`
        return other + -self

    @overload
    def __mul__(self: IntT, other: IntT | int) -> IntT: ...
    @overload
    def __mul__(self: FloatT, other: FloatT | float) -> FloatT: ...
    def __mul__(self, other):
        """
        Return the product of each corresponding value in this series and `other` (or
        NULL if either is NULL).
        """
        other = self._cast(other)
        return _apply(qm.Function.Multiply, self, other)

    @overload
    def __rmul__(self: IntT, other: IntT | int) -> IntT: ...
    @overload
    def __rmul__(self: FloatT, other: FloatT | float) -> FloatT: ...
    def __rmul__(self, other):
        # Multiplication is commutative so we can just swap the operands
        return self * other

    @overload
    def __truediv__(self: "PatientSeries", other) -> "FloatPatientSeries": ...
    @overload
    def __truediv__(self: "EventSeries", other) -> "FloatEventSeries": ...
    def __truediv__(self, other):
        """
        Return a series with each value in this series divided by its corresponding value
        in `other` (or NULL if either is NULL).

        Note that the result is always a float even if the inputs are integers.
        """
        other = self._cast(other)
        return _apply(qm.Function.TrueDivide, self, other)

    @overload
    def __rtruediv__(self: "PatientSeries", other) -> "FloatPatientSeries": ...
    @overload
    def __rtruediv__(self: "EventSeries", other) -> "FloatEventSeries": ...
    def __rtruediv__(self, other):
        # Called for `other / self`. Division is not commutative, so `other` must be
        # passed as the numerator and this series as the denominator.
        other = self._cast(other)
        return _apply(qm.Function.TrueDivide, other, self)

    @overload
    def __floordiv__(self: "PatientSeries", other) -> "IntPatientSeries": ...
    @overload
    def __floordiv__(self: "EventSeries", other) -> "IntEventSeries": ...
    def __floordiv__(self, other):
        """
        Return a series with each value in this series divided by its corresponding value
        in `other` and then rounded **down** to the nearest integer value (or NULL if either
        is NULL).

        Note that the result is always an integer even if the inputs are floats.
        """
        other = self._cast(other)
        return _apply(qm.Function.FloorDivide, self, other)

    @overload
    def __rfloordiv__(self: "PatientSeries", other) -> "IntPatientSeries": ...
    @overload
    def __rfloordiv__(self: "EventSeries", other) -> "IntEventSeries": ...
    def __rfloordiv__(self, other):
        # Called for `other // self`. As with `__rtruediv__`, `other` is the numerator
        # and this series is the denominator.
        other = self._cast(other)
        return _apply(qm.Function.FloorDivide, other, self)

    def __neg__(self: T) -> T:
        """
        Return the negation of each value in this series.
        """
        return _apply(qm.Function.Negate, self)

    @overload
    def as_int(self: "PatientSeries") -> "IntPatientSeries": ...
    @overload
    def as_int(self: "EventSeries") -> "IntEventSeries": ...
    def as_int(self):
        """
        Return each value in this series rounded down to the nearest integer.
        """
        return _apply(qm.Function.CastToInt, self)

    @overload
    def as_float(self: "PatientSeries") -> "FloatPatientSeries": ...
    @overload
    def as_float(self: "EventSeries") -> "FloatEventSeries": ...
    def as_float(self):
        """
        Return each value in this series as a float (e.g. 10 becomes 10.0).
        """
        return _apply(qm.Function.CastToFloat, self)
888
889
890
class NumericAggregations(ComparableAggregations):
    @overload
    def sum_for_patient(self: FloatT) -> "FloatPatientSeries": ...
    @overload
    def sum_for_patient(self: IntT) -> "IntPatientSeries": ...
    def sum_for_patient(self):
        """
        Return the sum of all values in the series for each patient.
        """
        aggregation = qm.AggregateByPatient.Sum
        return _apply(aggregation, self)

    def mean_for_patient(self) -> "FloatPatientSeries":
        """
        Return the arithmetic mean of any non-NULL values in the series for each
        patient.
        """
        aggregation = qm.AggregateByPatient.Mean
        return _apply(aggregation, self)
907
908
909
class IntFunctions(NumericFunctions):
    """
    Currently only needed for type hints to easily tell the difference between int
    and float series
    """
911
912
913
class IntPatientSeries(IntFunctions, PatientSeries):
    # Patient-level series of `int` values
    _type = int
915
916
917
class IntEventSeries(IntFunctions, NumericAggregations, EventSeries):
    # Event-level series of `int` values; also supports the numeric aggregations
    _type = int
919
920
921
class FloatFunctions(NumericFunctions):
    @staticmethod
    def _cast(value):
        """
        Casting int literals to floats. We do not support casting to float for IntSeries.
        """
        # Anything which is not an int literal is handed back untouched
        return float(value) if isinstance(value, int) else value
930
931
932
class FloatPatientSeries(FloatFunctions, PatientSeries):
    # Patient-level series of `float` values
    _type = float
934
935
936
class FloatEventSeries(FloatFunctions, NumericAggregations, EventSeries):
    # Event-level series of `float` values; also supports the numeric aggregations
    _type = float
938
939
940
# DATE SERIES
941
#
942
943
944
def parse_date_if_str(value):
    """
    Convert a `YYYY-MM-DD` string into a `datetime.date`; any non-string value is
    returned unchanged.

    Raises `ValueError` if a string is supplied which is not in exactly the
    hyphenated `YYYY-MM-DD` format, or which is correctly formatted but is rejected
    by `datetime.date.fromisoformat()` (e.g. "2020-02-30").
    """
    if not isinstance(value, str):
        return value
    # By default, `fromisoformat()` accepts the alternative YYYYMMDD format. We only
    # want to allow the hyphenated version so we pre-validate it.
    #
    # `fullmatch` with explicit ASCII digit classes is used deliberately: with
    # `re.match`, `$` also matches just before a trailing newline and `\d` matches
    # non-ASCII digits, so such strings would bypass the format error below.
    if not re.fullmatch(r"[0-9]{4}-[0-9]{2}-[0-9]{2}", value):
        raise ValueError(f"Dates must be in YYYY-MM-DD format: {value!r}")
    try:
        return datetime.date.fromisoformat(value)
    except ValueError as e:
        # Append the offending value and hide the chained exception
        raise ValueError(f"{e} in {value!r}") from None
956
957
958
def cast_all_arguments(args):
    """
    Apply the `_cast` method of the first series found in `args` to every argument,
    returning the results as a tuple. If `args` contains no series it is returned
    unchanged.
    """
    series_args = [arg for arg in args if isinstance(arg, BaseSeries)]
    if not series_args:
        # This would only be the case if all the arguments were literals, which would
        # be an unusual and pointless bit of ehrQL, but we should handle it without
        # error
        return args
    # Choose the first series arbitrarily, the type checker will enforce that they
    # are all the same type in any case
    cast = series_args[0]._cast
    return tuple(cast(arg) for arg in args)
969
970
971
# This allows us to get type hints for properties by replacing the
972
# @property decorator with this decorator. Currently only needed for
973
# ints. We pass the docstring through so that it can appear in the docs
974
class int_property(Generic[T]):
    def __init__(self, getter: Callable[[Any], T]) -> None:
        # Keep hold of the wrapped function and expose its docstring as our own
        self.getter = getter
        self.__doc__ = getter.__doc__

    def __set__(self, instance, value):
        # Assignment is accepted but deliberately does nothing
        ...

    @overload
    def __get__(self, obj: PatientSeries, objtype=None) -> "IntPatientSeries": ...

    @overload
    def __get__(self, obj: EventSeries, objtype=None) -> "IntEventSeries": ...

    def __get__(self, obj, objtype=None):
        # Attribute access evaluates the wrapped getter against the owning object
        getter = self.getter
        return getter(obj)
989
990
991
class DateFunctions(ComparableFunctions):
    @staticmethod
    def _cast(value):
        # Strings are parsed as dates; everything else passes through unchanged
        return parse_date_if_str(value)

    @int_property
    def year(self):
        """
        Return an integer series giving the year of each date in this series.
        """
        return _apply(qm.Function.YearFromDate, self)

    @int_property
    def month(self):
        """
        Return an integer series giving the month (1-12) of each date in this series.
        """
        return _apply(qm.Function.MonthFromDate, self)

    @int_property
    def day(self):
        """
        Return an integer series giving the day of the month (1-31) of each date in this
        series.
        """
        return _apply(qm.Function.DayFromDate, self)

    def to_first_of_year(self: T) -> T:
        """
        Return a date series with each date in this series replaced by the date of the
        first day in its corresponding calendar year.

        Example usage:
        ```python
        patients.date_of_death.to_first_of_year()
        ```
        """
        return _apply(qm.Function.ToFirstOfYear, self)

    def to_first_of_month(self: T) -> T:
        """
        Return a date series with each date in this series replaced by the date of the
        first day in its corresponding calendar month.

        Example usage:
        ```python
        patients.date_of_death.to_first_of_month()
        ```
        """
        return _apply(qm.Function.ToFirstOfMonth, self)

    @overload
    def is_before(self: PatientSeries, other) -> BoolPatientSeries: ...
    @overload
    def is_before(self: EventSeries, other) -> BoolEventSeries: ...
    def is_before(self, other):
        """
        Return a boolean series which is True for each date in this series that is
        strictly earlier than its corresponding date in `other` and False otherwise
        (or NULL if either value is NULL).

        Example usage:
        ```python
        medications.where(medications.date.is_before("2020-04-01"))
        ```
        """
        return self.__lt__(other)

    @overload
    def is_on_or_before(self: PatientSeries, other) -> BoolPatientSeries: ...
    @overload
    def is_on_or_before(self: EventSeries, other) -> BoolEventSeries: ...
    def is_on_or_before(self, other):
        """
        Return a boolean series which is True for each date in this series that is
        earlier than or the same as its corresponding value in `other` and False
        otherwise (or NULL if either value is NULL).

        Example usage:
        ```python
        medications.where(medications.date.is_on_or_before("2020-03-31"))
        ```
        """
        return self.__le__(other)

    @overload
    def is_after(self: PatientSeries, other) -> BoolPatientSeries: ...
    @overload
    def is_after(self: EventSeries, other) -> BoolEventSeries: ...
    def is_after(self, other):
        """
        Return a boolean series which is True for each date in this series that is
        strictly later than its corresponding date in `other` and False otherwise
        (or NULL if either value is NULL).

        Example usage:
        ```python
        medications.where(medications.date.is_after("2020-03-31"))
        ```
        """
        return self.__gt__(other)

    @overload
    def is_on_or_after(self: PatientSeries, other) -> BoolPatientSeries: ...
    @overload
    def is_on_or_after(self: EventSeries, other) -> BoolEventSeries: ...
    def is_on_or_after(self, other):
        """
        Return a boolean series which is True for each date in this series that is later
        than or the same as its corresponding value in `other` and False otherwise (or
        NULL if either value is NULL).

        Example usage:
        ```python
        medications.where(medications.date.is_on_or_after("2020-04-01"))
        ```
        """
        return self.__ge__(other)

    @overload
    def is_between_but_not_on(self: PatientSeries, start, end) -> BoolPatientSeries: ...
    @overload
    def is_between_but_not_on(self: EventSeries, start, end) -> BoolEventSeries: ...
    def is_between_but_not_on(self, start, end):
        """
        Return a boolean series which is True for each date in this series which is
        strictly between (i.e. not equal to) the corresponding dates in `start` and `end`,
        and False otherwise.

        Example usage:
        ```python
        medications.where(medications.date.is_between_but_not_on("2020-03-31", "2021-04-01"))
        ```
        For each trio of dates being compared, if any date is NULL the result is NULL.
        """
        is_after_start = self > start
        is_before_end = self < end
        return is_after_start & is_before_end

    @overload
    def is_on_or_between(self: PatientSeries, start, end) -> BoolPatientSeries: ...
    @overload
    def is_on_or_between(self: EventSeries, start, end) -> BoolEventSeries: ...
    def is_on_or_between(self, start, end):
        """
        Return a boolean series which is True for each date in this series which is
        between or the same as the corresponding dates in `start` and `end`, and
        False otherwise.

        Example usage:
        ```python
        medications.where(medications.date.is_on_or_between("2020-04-01", "2021-03-31"))
        ```
        For each trio of dates being compared, if any date is NULL the result is NULL.
        """
        is_not_before_start = self >= start
        is_not_after_end = self <= end
        return is_not_before_start & is_not_after_end

    @overload
    def is_during(self: PatientSeries, interval) -> BoolPatientSeries: ...
    @overload
    def is_during(self: EventSeries, interval) -> BoolEventSeries: ...
    def is_during(self, interval):
        """
        The same as `is_on_or_between()` above, but allows supplying a start/end date
        pair as single argument.

        Example usage:
        ```python
        study_period = ("2020-04-01", "2021-03-31")
        medications.where(medications.date.is_during(study_period))
        ```

        Also see the docs on using `is_during` with the
        [`INTERVAL` placeholder](../explanation/measures.md/#the-interval-placeholder).
        """
        # `interval` must unpack into exactly two values
        first_date, last_date = interval
        return self.is_on_or_between(first_date, last_date)

    def __sub__(self, other):
        """
        Return a series giving the difference between each date in this series and
        `other` (see [`DateDifference`](#DateDifference)).

        Example usage:
        ```python
        age_months = (date("2020-01-01") - patients.date_of_birth).months
        ```
        """
        other = self._cast(other)
        if not isinstance(other, datetime.date | DateEventSeries | DatePatientSeries):
            # Let Python try the other operand's reflected method
            return NotImplemented
        return DateDifference(self, other)

    def __rsub__(self, other):
        # Reflected form of the above: here `other` is the left-hand operand
        other = self._cast(other)
        if not isinstance(other, datetime.date | DateEventSeries | DatePatientSeries):
            return NotImplemented
        return DateDifference(other, self)
1189
1190
1191
class DateAggregations(ComparableAggregations):
    def count_episodes_for_patient(self, maximum_gap) -> IntPatientSeries:
        """
        Counts the number of "episodes" for each patient where dates which are no more
        than `maximum_gap` apart are considered part of the same episode. The
        `maximum_gap` duration can be specified in [`days()`](#days) or
        [`weeks()`](#weeks).

        For example, suppose a patient has the following sequence of events:

        Event ID | Date
        -- | --
        A | 2020-01-01
        B | 2020-01-04
        C | 2020-01-06
        D | 2020-01-10
        E | 2020-01-12

        And suppose we count the episodes here using a maximum gap of three days:
        ```python
        .count_episodes_for_patient(days(3))
        ```

        We will get an episode count of two: events A, B and C are considered as one
        episode and events D and E as another.

        Note that events A and C are considered part of the same episode even though
        they are more than three days apart because event B is no more than three days
        apart from both of them. That is, the clock restarts with each new event in an
        episode rather than running from the first event in an episode.
        """
        # Normalise the gap to a number of days, whichever unit it was supplied in
        if isinstance(maximum_gap, days):
            gap_in_days = maximum_gap.value
        elif isinstance(maximum_gap, weeks):
            gap_in_days = maximum_gap.value * 7
        else:
            raise TypeError("`maximum_gap` must be supplied as `days()` or `weeks()`")

        # Only a literal integer is acceptable here (not, for instance, a series)
        if not isinstance(gap_in_days, int):
            unit_name = type(maximum_gap).__name__
            raise ValueError(
                f"`maximum_gap` must be a single, fixed number of {unit_name}"
            )

        return _wrap(
            qm.AggregateByPatient.CountEpisodes,
            source=self._qm_node,
            maximum_gap_days=gap_in_days,
        )
1238
1239
1240
class DatePatientSeries(DateFunctions, PatientSeries):
    # Patient-level series of `datetime.date` values
    _type = datetime.date
1242
1243
1244
class DateEventSeries(DateFunctions, DateAggregations, EventSeries):
    # Event-level series of `datetime.date` values; also supports the date
    # aggregations
    _type = datetime.date
1246
1247
1248
# The default dataclass equality method doesn't work here and while we could define our
1249
# own it wouldn't be very useful for this type
1250
@dataclasses.dataclass(eq=False)
class DateDifference:
    """
    Represents the difference between two dates or date series (i.e. it is what you
    get when you perform subtractions on [DatePatientSeries](#DatePatientSeries.sub)
    or [DateEventSeries](#DateEventSeries.sub)).
    """

    lhs: datetime.date | DateEventSeries | DatePatientSeries
    rhs: datetime.date | DateEventSeries | DatePatientSeries

    @property
    def days(self):
        """
        The value of the date difference in days (can be positive or negative).
        """
        operation = qm.Function.DateDifferenceInDays
        return _apply(operation, self.lhs, self.rhs)

    @property
    def weeks(self):
        """
        The value of the date difference in whole weeks (can be positive or negative).
        """
        # Derived from the difference in days using floor division
        difference_in_days = self.days
        return difference_in_days // 7

    @property
    def months(self):
        """
        The value of the date difference in whole calendar months (can be positive or
        negative).
        """
        operation = qm.Function.DateDifferenceInMonths
        return _apply(operation, self.lhs, self.rhs)

    @property
    def years(self):
        """
        The value of the date difference in whole calendar years (can be positive or
        negative).
        """
        operation = qm.Function.DateDifferenceInYears
        return _apply(operation, self.lhs, self.rhs)
1290
1291
1292
@dataclasses.dataclass
class Duration:
    # Base class for the unit-specific durations defined below it. Each subclass
    # supplies two hooks:
    #
    #   _date_add_static: called as `_date_add_static(date, value)` when both the
    #       date and the duration value are literals
    #   _date_add_qm: the query model operation handed to `_apply()` otherwise
    #
    value: int | IntEventSeries | IntPatientSeries

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every subclass must provide both date-addition hooks
        assert cls._date_add_static is not None
        assert cls._date_add_qm is not None

    # The default dataclass equality/inequality methods don't behave correctly here
    def __eq__(self, other) -> bool:
        """
        Return True if `other` has the same value and units, and False otherwise.

        Hence, the result of `weeks(1) == days(7)` will be False.
        """
        if other.__class__ is not self.__class__:
            return False
        return self.value == other.value

    def __ne__(self, other) -> bool:
        """
        Return the inverse of `==` above.
        """
        # We have to apply different inversion logic depending on whether we have a
        # boolean or a BoolSeries
        is_equal = self == other
        if isinstance(is_equal, bool):
            return not is_equal
        else:
            return is_equal.__invert__()

    def __add__(self, other: T) -> T:
        """
        If `other` is a date or date series, add this duration to `other`
        to produce a new date.

        If `other` is another duration with the same units, add the two durations
        together to produce a new duration.
        """
        other = parse_date_if_str(other)
        if isinstance(self.value, int) and isinstance(other, datetime.date):
            # If both operands are static values we can perform the date arithmetic
            # directly ourselves
            return self._date_add_static(other, self.value)
        elif isinstance(other, datetime.date | DateEventSeries | DatePatientSeries):
            # Otherwise we create the appropriate query model construct
            return _apply(self._date_add_qm, other, self.value)
        elif isinstance(other, self.__class__):
            # Durations of the same type can be added together
            return self.__class__(self.value + other.value)
        else:
            # Nothing else is handled
            return NotImplemented

    def __sub__(self, other: T) -> T:
        """
        Subtract `other` from this duration. `other` must be a
        duration in the same units.
        """
        # Only durations can be subtracted from a duration. Anything else (e.g. a
        # date, which has no `__neg__`) is declined in the same way `__add__`
        # declines it, so that Python raises its standard "unsupported operand"
        # TypeError rather than leaking an AttributeError from the negation below.
        if not isinstance(other, Duration):
            return NotImplemented
        return self.__add__(other.__neg__())

    def __radd__(self, other: T) -> T:
        # Addition is handled identically whichever side the duration is on
        return self.__add__(other)

    def __rsub__(self, other: T) -> T:
        # `other - self` is evaluated as `(-self) + other`
        return self.__neg__().__add__(other)

    def __neg__(self: T) -> T:
        """
        Invert this duration, i.e. count the duration backwards in time
        if it was originally forwards, and vice versa.
        """
        return self.__class__(self.value.__neg__())

    def starting_on(self, date) -> list[tuple[datetime.date, datetime.date]]:
        """
        Return a list of time intervals covering the duration starting on
        `date`. Each interval lasts one unit.

        Example usage:
        ```python
        weeks(3).starting_on("2000-01-01")
        ```
        The above would return:
        ```
        [
            (date(2000, 1, 1), date(2000, 1, 7)),
            (date(2000, 1, 8), date(2000, 1, 14)),
            (date(2000, 1, 15), date(2000, 1, 21)),
        ]
        ```

        Useful for generating the `intervals` arguments to [`Measures`](#Measures).
        """
        return self._generate_intervals(date, self.value, 1, "starting_on")

    def ending_on(self, date) -> list[tuple[datetime.date, datetime.date]]:
        """
        Return a list of time intervals covering the duration ending on
        `date`. Each interval lasts one unit.

        Example usage:
        ```python
        weeks(3).ending_on("2000-01-21")
        ```
        The above would return:
        ```
        [
            (date(2000, 1, 1), date(2000, 1, 7)),
            (date(2000, 1, 8), date(2000, 1, 14)),
            (date(2000, 1, 15), date(2000, 1, 21)),
        ]
        ```

        Useful for generating the `intervals` arguments to [`Measures`](#Measures).
        """
        return self._generate_intervals(date, self.value, -1, "ending_on")

    @classmethod
    def _generate_intervals(cls, date, value, sign, method_name):
        # Shared implementation of `starting_on()` and `ending_on()`: validates the
        # arguments, then passes `value * sign` to `date_utils.generate_intervals()`.
        # `method_name` is only used to build the error messages.
        date = parse_date_if_str(date)
        if not isinstance(date, datetime.date):
            raise TypeError(
                f"{cls.__name__}.{method_name}() can only be used with a literal "
                f"date, not a date series"
            )
        if not isinstance(value, int):
            raise TypeError(
                f"{cls.__name__}.{method_name}() can only be used with a literal "
                f"integer value, not an integer series"
            )
        if value < 0:
            raise ValueError(
                f"{cls.__name__}.{method_name}() can only be used with positive numbers"
            )
        return date_utils.generate_intervals(cls._date_add_static, date, value * sign)
1429
1430
1431
class days(Duration):
    """
    Represents a duration of time specified in days.

    Example usage:
    ```python
    last_medication_date = medications.sort_by(medications.date).last_for_patient().date
    start_date = last_medication_date - days(90)
    end_date = last_medication_date + days(90)
    ```
    """

    # Query model operation, and its literal-date counterpart, for adding days
    _date_add_qm = qm.Function.DateAddDays
    _date_add_static = staticmethod(date_utils.date_add_days)
1445
1446
1447
class weeks(Duration):
    """
    Represents a duration of time specified in weeks.

    Example usage:
    ```python
    last_medication_date = medications.sort_by(medications.date).last_for_patient().date
    start_date = last_medication_date - weeks(12)
    end_date = last_medication_date + weeks(12)
    ```
    """

    _date_add_static = staticmethod(date_utils.date_add_weeks)

    @staticmethod
    def _date_add_qm(date, num_weeks):
        # Built from existing operations: multiply the week count by seven and add
        # the result as days
        return qm.Function.DateAddDays(
            date,
            qm.Function.Multiply(num_weeks, qm.Value(7)),
        )
1465
1466
1467
class months(Duration):
    """
    Represents a duration of time specified in calendar months.

    Example usage:
    ```python
    last_medication_date = medications.sort_by(medications.date).last_for_patient().date
    start_date = last_medication_date - months(3)
    end_date = last_medication_date + months(3)
    ```

    Consider using [`days()`](#days) or [`weeks()`](#weeks) instead -
    see the section on [Ambiguous Dates](#ambiguous-dates) for more.
    """

    # Query model operation, and its literal-date counterpart, for adding months
    _date_add_qm = qm.Function.DateAddMonths
    _date_add_static = staticmethod(date_utils.date_add_months)
1484
1485
1486
class years(Duration):
    """
    Represents a duration of time specified in calendar years.

    Example usage:
    ```python
    last_medication_date = medications.sort_by(medications.date).last_for_patient().date
    start_date = last_medication_date - years(1)
    end_date = last_medication_date + years(1)
    ```

    Consider using [`days()`](#days) or [`weeks()`](#weeks) instead -
    see the section on [Ambiguous Dates](#ambiguous-dates) for more.
    """

    # Query model operation, and its literal-date counterpart, for adding years
    _date_add_qm = qm.Function.DateAddYears
    _date_add_static = staticmethod(date_utils.date_add_years)
1503
1504
1505
# CODE SERIES
1506
#
1507
1508
1509
class CodeFunctions:
    def _cast(self, value):
        # Strings are converted using this series' `_type`; anything else is
        # returned as supplied
        return self._type(value) if isinstance(value, str) else value

    def to_category(self, categorisation, default=None):
        """
        An alias for `map_values` which makes the intention clearer when working with
        codelists.

        For more detail see [`codelist_from_csv()`](#codelist_from_csv) and the
        [how-to guide](../how-to/examples.md/#using-codelists-with-category-columns).
        """
        mapped = self.map_values(categorisation, default=default)
        return mapped
1525
1526
1527
class CodePatientSeries(CodeFunctions, PatientSeries):
    # Patient-level series of clinical codes, registered against the base code type
    _type = BaseCode
1529
1530
1531
class CodeEventSeries(CodeFunctions, EventSeries):
    # Event-level series of clinical codes, registered against the base code type
    _type = BaseCode
1533
1534
1535
class MultiCodeStringFunctions:
    def _cast(self, value):
        # Accepts either a code of the type this series expects (converted to its
        # primitive form) or a string matching the type's regex (returned as-is).
        # Anything else is rejected with a TypeError.
        code_type = self._type._code_type()

        if isinstance(value, code_type):
            # The passed code is of the expected type, so can convert to a string
            return value._to_primitive_type()
        elif isinstance(value, str) and self._type.regex.fullmatch(value):
            # A string that matches the regex for this type
            return value
        else:
            raise TypeError(
                f"Expecting a {code_type}, or a string prefix of a {code_type}"
            )

    def __eq__(self, other):
        """
        This operation is not allowed because it is unlikely you would want to match the
        values in this field with an exact string e.g.
        ```python
        apcs.all_diagnoses == "||I302, K201, J180 || I302, K200, M920"
        ```
        Instead you should use the `contains` or `contains_any_of` methods.
        """
        raise TypeError(
            "This column contains multiple clinical codes combined together in a single "
            "string. If you want to know if a particular code is contained in the string, "
            "please use the `contains()` method"
        )

    def __ne__(self, other):
        """
        See above.
        """
        raise TypeError(
            "This column contains multiple clinical codes combined together in a single "
            "string. If you want to know if a particular code is not contained in the string, "
            "please use the `contains()` method."
        )

    def is_in(self, other):
        """
        This operation is not allowed. To check for the presence of any codes in
        a codelist, please use the `contains_any_of(codelist)` method instead.
        """
        raise TypeError(
            "You are attempting to use `.is_in()` on a column that contains multiple "
            "clinical codes joined together. This is not allowed. If you want to know "
            "if the field contains any of the codes from a codelist, then please use "
            "`.contains_any_of(codelist)` instead."
        )

    def is_not_in(self, other):
        """
        This operation is not allowed. To check for the absence of all codes in a codelist,
        from a column called `column`, please use `~column.contains_any_of(codelist)`.
        NB the `contains_any_of(codelist)` will provide any records that contain any of the
        codes, which is then negated with the `~` operator.
        """
        raise TypeError(
            "You are attempting to use `.is_not_in()` on a column that contains multiple "
            "clinical codes joined together. This is not allowed. If you want to know "
            "if the column does not contain any of the codes from a codelist, then please use "
            "`~column.contains_any_of(codelist)` instead."
        )

    @overload
    def contains(self: PatientSeries, code) -> BoolPatientSeries: ...
    @overload
    def contains(self: EventSeries, code) -> BoolEventSeries: ...
    def contains(self, code):
        """
        Check if the multi code field contains a specific code string and
        return the result as a boolean series. `code` can
        either be a string (and prefix matching works so e.g. "N17" in ICD-10
        would match all acute renal failure), or a clinical code.

        Example usages:
        ```python
        all_diagnoses.contains("N17")
        all_diagnoses.contains(ICD10Code("N170"))
        ```
        """
        code = self._cast(code)
        return _apply(qm.Function.StringContains, self, code)

    @overload
    def contains_any_of(self: PatientSeries, codelist) -> BoolPatientSeries: ...
    @overload
    def contains_any_of(self: EventSeries, codelist) -> BoolEventSeries: ...
    def contains_any_of(self, codelist):
        """
        Check if any of the codes in `codelist` occur in the multi code field and
        return the result as a boolean series.
        As with the `contains(code)` method, the codelist can be a mixture of clinical
        codes and string prefixes, as seen in the example below.

        Example usage:
        ```python
        all_diagnoses.contains_any_of([ICD10Code("N170"), "N17"])
        ```
        """
        conditions = [self.contains(code) for code in codelist]
        if not conditions:
            # `reduce()` would raise its own, much less helpful, TypeError here
            raise TypeError(
                "`contains_any_of()` requires a codelist containing at least one code"
            )
        # OR together the per-code checks
        return functools.reduce(operator.or_, conditions)
1639
1640
1641
class MultiCodeStringPatientSeries(MultiCodeStringFunctions, PatientSeries):
    # Patient-level series of multi-code strings, registered against the base
    # multi-code string type
    _type = BaseMultiCodeString
1643
1644
1645
class MultiCodeStringEventSeries(MultiCodeStringFunctions, EventSeries):
    # Event-level series of multi-code strings, registered against the base
    # multi-code string type
    _type = BaseMultiCodeString
1647
1648
1649
# CONVERT QUERY MODEL SERIES TO EHRQL SERIES
1650
#
1651
1652
1653
def _wrap(qm_cls, *args, **kwargs):
    """
    Build a query model series from `qm_cls` and the supplied arguments, then return
    it wrapped in the ehrQL series class matching its type and dimension
    """
    qm_node = _build(qm_cls, *args, **kwargs)
    type_ = get_series_type(qm_node)
    is_patient_level = has_one_row_per_patient(qm_node)
    try:
        return REGISTERED_TYPES[type_, is_patient_level](qm_node)
    except KeyError:
        # If we don't have a match for exactly this type then we should have one for a
        # superclass. In the case where there are multiple matches, we want the narrowest
        # match. E.g. for ICD10MultiCodeString which inherits from BaseMultiCodeString,
        # which in turn inherits from str, we want to match BaseMultiCodeString as it
        # corresponds to the "closest" series match (in this case MultiCodeStringEventSeries
        # rather than the more generic StrEventSeries)
        candidates = [
            (type_.__mro__.index(target_type), series_cls)
            for (target_type, target_dimension), series_cls in REGISTERED_TYPES.items()
            if issubclass(type_, target_type) and is_patient_level == target_dimension
        ]
        assert candidates, f"No matching query language class for {type_}"
        # The smallest MRO index is the nearest ancestor; ties resolve to the
        # earliest registered class
        _, closest_cls = min(candidates, key=lambda candidate: candidate[0])
        wrapped = closest_cls(qm_node)
        # Record the precise type on the instance
        wrapped._type = type_
        return wrapped
1682
1683
1684
def _build(qm_cls, *args, **kwargs):
1685
    "Construct a query model node, translating any errors as appropriate"
1686
    try:
1687
        return qm_cls(*args, **kwargs)
1688
    except qm.InvalidSortError:
1689
        raise Error(
1690
            "Cannot sort by a constant value"
1691
            # Use `from None` to hide the chained exception
1692
        ) from None
1693
    except qm.DomainMismatchError:
1694
        hints = (
1695
            " * Reduce one series to have only one value per patient by using\n"
1696
            "   an aggregation like `maximum_for_patient()`.\n\n"
1697
            " * Pick a single row for each patient from the table using\n"
1698
            "   `first_for_patient()`."
1699
        )
1700
        if qm_cls is qm.Function.EQ:
1701
            hints = (
1702
                " * Use `x.is_in(y)` instead of `x == y` to check if values from\n"
1703
                "   one series match any of the patient's values in the other.\n\n"
1704
                f"{hints}"
1705
            )
1706
        raise Error(
1707
            "\n"
1708
            "Cannot combine series which are drawn from different tables and both\n"
1709
            "have more than one value per patient.\n"
1710
            "\n"
1711
            "To address this, try one of the following:\n"
1712
            "\n"
1713
            f"{hints}"
1714
            # Use `from None` to hide the chained exception
1715
        ) from None
1716
    except qm.TypeValidationError as exc:
1717
        # We deliberately omit information about the query model operation and field
1718
        # name here because these often don't match what's used in ehrQL and are liable
1719
        # to cause confusion
1720
        raise TypeError(
1721
            f"Expected type '{_format_typespec(exc.expected)}' "
1722
            f"but got '{_format_typespec(exc.received)}'"
1723
            # Use `from None` to hide the chained exception
1724
        ) from None
1725
1726
1727
def _format_typespec(typespec):
    # For now this only renders the typespec as a string and strips the module name
    # prefix from "Series". Dropping the mention of "Series" altogether might be nicer,
    # but that's left for another day.
    rendered = str(typespec)
    qualified_series_name = f"{qm.__name__}.{qm.Series.__qualname__}"
    return rendered.replace(qualified_series_name, "Series")
1732
1733
1734
def _apply(qm_cls, *args):
    """
    Apply the query model operation `qm_cls` to `args`, each of which may be an ehrQL
    series or a static value, and return the result as an ehrQL series
    """
    # Each argument is converted to a query model node, then the resulting node is
    # constructed and wrapped back up in an ehrQL series
    return _wrap(qm_cls, *(_convert(arg) for arg in args))
1743
1744
1745
def _convert(arg):
1746
    # Pass null values through unchanged
1747
    if arg is None:
1748
        return None
1749
    # Unpack tuple arguments
1750
    elif isinstance(arg, tuple):
1751
        return tuple(_convert(a) for a in arg)
1752
    # If it's an ehrQL series then get the wrapped query model node
1753
    elif isinstance(arg, BaseSeries):
1754
        return arg._qm_node
1755
    # If it's a static value then we need to be put in a query model Value wrapper
1756
    elif isinstance(
1757
        arg, bool | int | float | datetime.date | str | BaseCode | frozenset
1758
    ):
1759
        return qm.Value(arg)
1760
    else:
1761
        raise_helpful_error_if_possible(arg)
1762
        raise TypeError(f"Not a valid ehrQL type: {arg!r}")
1763
1764
1765
def Parameter(name, type_):
    """
    Return a placeholder ("parameter") series named `name` of type `type_`.

    Such series can be used to construct a query "template": a structure which becomes
    a query once concrete values are substituted for the parameters it contains
    """
    placeholder = _wrap(qm.Parameter, name, type_)
    return placeholder
1772
1773
1774
# FRAME TYPES
1775
#
1776
1777
1778
class BaseFrame:
    # Common base for ehrQL frames: a thin wrapper around a query model node

    def __init__(self, qm_node):
        self._qm_node = qm_node

    def _select_column(self, name):
        # Selecting a column yields a series of the appropriate type and dimension
        source = self._qm_node
        return _wrap(qm.SelectColumn, source=source, name=name)

    def exists_for_patient(self) -> BoolPatientSeries:
        """
        Return a [boolean patient series](#BoolPatientSeries) which is True for each
        patient that has a row in this frame and False otherwise.

        Example usage:
        ```python
        practice_registrations.for_patient_on("2020-01-01").exists_for_patient()
        ```
        """
        source = self._qm_node
        return _wrap(qm.AggregateByPatient.Exists, source=source)

    def count_for_patient(self) -> IntPatientSeries:
        """
        Return an [integer patient series](#IntPatientSeries) giving the number of rows each
        patient has in this frame.

        Note that if a patient has no rows at all in the frame the result will be zero
        rather than NULL.

        Example usage:
        ```python
        clinical_events.where(clinical_events.date.year == 2020).count_for_patient()
        ```
        """
        source = self._qm_node
        return _wrap(qm.AggregateByPatient.Count, source=source)
1811
1812
1813
class PatientFrame(BaseFrame):
    """
    Frame containing at most one row per patient.
    """

    # Defines no behaviour of its own here. The `table` decorator looks for this class
    # in a schema class's MRO to decide to build a `SelectPatientTable` node.
1817
1818
1819
class EventFrame(BaseFrame):
    """
    Frame which may contain multiple rows per patient.
    """

    def where(self, condition):
        """
        Return a new frame containing only the rows in this frame for which `condition`
        evaluates True.

        Note that this excludes any rows for which `condition` is NULL.

        Example usage:
        ```python
        clinical_events.where(clinical_events.date >= "2020-01-01")
        ```
        """
        filtered = qm.Filter(source=self._qm_node, condition=_convert(condition))
        return self.__class__(filtered)

    def except_where(self, condition):
        """
        Return a new frame containing only the rows in this frame for which `condition`
        evaluates False or NULL i.e. the exact inverse of the rows included by
        `where()`.

        Example usage:
        ```python
        practice_registrations.except_where(practice_registrations.end_date < "2020-01-01")
        ```

        Note that `except_where()` is not the same as `where()` with an inverted condition,
        as the latter would exclude rows where `condition` is NULL.
        """
        condition_node = _convert(condition)
        # Keep rows where the condition is False *or* NULL
        false_or_null = qm.Function.Or(
            lhs=qm.Function.Not(condition_node),
            rhs=qm.Function.IsNull(condition_node),
        )
        filtered = qm.Filter(source=self._qm_node, condition=false_or_null)
        return self.__class__(filtered)

    def sort_by(self, *sort_values):
        """
        Return a new frame with the rows sorted for each patient, by
        each of the supplied `sort_values`.

        Where more than one sort value is supplied then the first (i.e. left-most) value
        has highest priority and each subsequent sort value will only be used as a
        tie-breaker in case of an exact match among previous values.

        Note that NULL is considered smaller than any other value, so you may wish to
        filter out NULL values before sorting.

        Example usage:
        ```python
        clinical_events.sort_by(clinical_events.date, clinical_events.snomedct_code)
        ```
        """
        # A common mistake is passing a column name as a string: catch the easy form
        # of that (the first string argument, if it is non-empty) with a helpful error
        string_args = [value for value in sort_values if isinstance(value, str)]
        if string_args and string_args[0]:
            string_arg = string_args[0]
            raise TypeError(
                f"to sort by a column use a table attribute like "
                f"`{self.__class__.__name__}.{string_arg}` rather than the string "
                f'"{string_arg}"'
            )

        # Series are supplied highest priority first but the most recently applied Sort
        # operation takes precedence, so the sorts are layered on in reverse order
        sorted_node = self._qm_node
        for sort_value in sort_values[::-1]:
            sorted_node = _build(
                qm.Sort, source=sorted_node, sort_by=_convert(sort_value)
            )
        sorted_frame_cls = make_sorted_event_frame_class(self.__class__)
        return sorted_frame_cls(sorted_node)
1904
1905
1906
class SortedEventFrameMethods:
    def _pick_one_row_for_patient(self, position):
        # Shared implementation of `first_for_patient` / `last_for_patient`: build the
        # PatientFrame class matching this frame and wrap the row-picking node in it
        patient_frame_cls = make_patient_frame_class(self.__class__)
        picked = qm.PickOneRowPerPatient(position=position, source=self._qm_node)
        return patient_frame_cls(picked)

    def first_for_patient(self):
        """
        Return a PatientFrame containing, for each patient, the first matching row
        according to whatever sort order has been applied.

        Note that where there are multiple rows tied for first place then the specific
        row returned is picked arbitrarily but consistently i.e. you shouldn't depend on
        getting any particular result, but the result you do get shouldn't change unless
        the data changes.

        Example usage:
        ```python
        medications.sort_by(medications.date).first_for_patient()
        ```
        """
        return self._pick_one_row_for_patient(qm.Position.FIRST)

    def last_for_patient(self):
        """
        Return a PatientFrame containing, for each patient, the last matching row
        according to whatever sort order has been applied.

        Note that where there are multiple rows tied for last place then the specific
        row returned is picked arbitrarily but consistently i.e. you shouldn't depend on
        getting any particular result, but the result you do get shouldn't change unless
        the data changes.

        Example usage:
        ```python
        medications.sort_by(medications.date).last_for_patient()
        ```
        """
        return self._pick_one_row_for_patient(qm.Position.LAST)
1952
1953
1954
@functools.cache
def make_sorted_event_frame_class(cls):
    """
    Return `cls` itself if it already has the SortedEventFrameMethods, otherwise a new
    subclass of it (with the same name) which mixes those methods in
    """
    already_sorted = issubclass(cls, SortedEventFrameMethods)
    if already_sorted:
        return cls
    bases = (SortedEventFrameMethods, cls)
    return type(cls.__name__, bases, {})
1963
1964
1965
@functools.cache
def make_patient_frame_class(cls):
    """
    Build a PatientFrame subclass, named after `cls`, carrying the same series and
    properties as the given EventFrame subclass
    """
    namespace = get_all_series_and_properties_from_class(cls)
    bases = (PatientFrame,)
    return type(cls.__name__, bases, namespace)
1976
1977
1978
def get_all_series_from_class(cls):
    # `Series` is a descriptor, so reading the column objects through class attributes
    # would invoke it; `vars()` avoids that but only sees attributes defined directly
    # on a class. Layering the `vars()` of every class in the MRO with `ChainMap`
    # reproduces the inheritance behaviour.
    #
    # `inspect.getmembers_static` does _almost_ the same thing, but it returns
    # attributes in lexical order whereas we want the original definition order.
    namespaces = [vars(base) for base in cls.__mro__]
    attrs = ChainMap(*namespaces)
    series_by_name = {}
    for key, value in attrs.items():
        if isinstance(value, Series):
            series_by_name[key] = value
    return series_by_name
1989
1990
1991
def get_all_series_and_properties_from_class(cls):
    # Same approach as `get_all_series_from_class` but also capturing attributes with
    # the @property decorator, so that tables can have properties as well as Series.
    # The Series-only function is kept because other uses want just the Series.
    namespaces = [vars(base) for base in cls.__mro__]
    attrs = ChainMap(*namespaces)
    wanted_types = (Series, property)
    selected = {}
    for key, value in attrs.items():
        if isinstance(value, wanted_types):
            selected[key] = value
    return selected
2001
2002
2003
# FRAME CONSTRUCTOR ENTRYPOINTS
2004
#
2005
2006
2007
# A class decorator which replaces the class definition with an appropriately configured
2008
# instance of the class. Obviously this is a _bit_ odd, but I think worth it overall.
2009
# Using classes to define tables is (as far as I can tell) the only way to get nice
2010
# autocomplete and type-checking behaviour for column names. But we don't actually want
2011
# these classes accessible anywhere: users should only be interacting with instances of
2012
# the classes, and having the classes themselves in the module namespaces only makes
2013
# autocomplete more confusing and error prone.
2014
def table(cls: type[T]) -> T:
    mro = cls.__mro__
    # Reject anything which is not a frame schema before choosing the node type
    if PatientFrame not in mro and EventFrame not in mro:
        raise Error("Schema class must subclass either `PatientFrame` or `EventFrame`")
    # PatientFrame takes precedence if both happen to appear in the MRO
    qm_class = qm.SelectPatientTable if PatientFrame in mro else qm.SelectTable

    schema = get_table_schema_from_class(cls)
    # The class is replaced by an instance of itself wrapping the table node
    return cls(qm_class(name=cls.__name__, schema=schema))
2027
2028
2029
def get_table_schema_from_class(cls):
    # The schema has one query model Column per `Series` object found on the class
    columns = {}
    for series in get_all_series_from_class(cls).values():
        columns[series.name] = qm.Column(
            series.type_, constraints=series.constraints
        )
    return qm.TableSchema(**columns)
2036
2037
2038
# Defines a PatientFrame along with the data it contains. Takes a list (or
2039
# any iterable) of row tuples of the form:
2040
#
2041
#    (patient_id, column_1_in_schema, column_2_in_schema, ...)
2042
#
2043
def table_from_rows(rows):
    def make_table(cls):
        # Only classes whose sole direct base is `PatientFrame` are accepted
        extends_only_patient_frame = cls.__bases__ == (PatientFrame,)
        if not extends_only_patient_frame:
            raise Error("`@table_from_rows` can only be used with `PatientFrame`")
        # The rows are materialised into a tuple when the class is decorated
        inline_rows = tuple(rows)
        schema = get_table_schema_from_class(cls)
        return cls(qm.InlinePatientTable(rows=inline_rows, schema=schema))

    return make_table
2054
2055
2056
# Defines a PatientFrame along with the data it contains. Takes a path to
2057
# a file (feather, csv, csv.gz) with rows of the form:
2058
#
2059
#    (patient_id, column_1_in_schema, column_2_in_schema, ...)
2060
#
2061
def table_from_file(path):
    # Accept both strings and path-like objects
    path = Path(path)

    def decorator(cls):
        # Only classes whose sole direct base is `PatientFrame` are accepted
        if cls.__bases__ != (PatientFrame,):
            raise Error("`@table_from_file` can only be used with `PatientFrame`")

        # Derive the schema once and use it both to work out how the file's columns
        # should be read and to define the resulting table
        schema = get_table_schema_from_class(cls)
        column_specs = get_column_specs_from_schema(schema)

        rows = read_rows(path, column_specs)

        qm_node = qm.InlinePatientTable(
            rows=rows,
            schema=schema,
        )
        return cls(qm_node)

    return decorator
2080
2081
2082
# A descriptor which will return the appropriate type of series depending on the type of
2083
# frame it belongs to i.e. a PatientSeries subclass for PatientFrames and an EventSeries
2084
# subclass for EventFrames. This lets schema authors use a consistent syntax when
2085
# defining frames of either type.
2086
class Series(Generic[T]):
    def __init__(
        self,
        type_: type[T],
        *,
        description="",
        constraints=(),
        required=True,
        implementation_notes_to_add_to_description="",
        notes_for_implementors="",
    ):
        # `type_` and `constraints` are what `get_table_schema_from_class` reads to
        # build the query model column for this series
        self.type_ = type_
        # Free-text fields are passed through `strip_indent` before being stored
        self.description = strip_indent(description)
        self.constraints = constraints
        self.required = required
        self.implementation_notes_to_add_to_description = strip_indent(
            implementation_notes_to_add_to_description
        )
        self.notes_for_implementors = strip_indent(notes_for_implementors)

    def __set_name__(self, owner, name):
        # Record the attribute name this descriptor was assigned to on the owning
        # class: it doubles as the column name
        self.name = name

    # The overloads below exist purely for the benefit of type checkers and editor
    # autocomplete: they map each (value type, frame type) pair to the series class
    # returned. NOTE(review): overload order is presumably significant (e.g. the code
    # types listed ahead of `str`) -- confirm before reordering.

    @overload
    def __get__(
        self: "Series[datetime.date]", instance: PatientFrame, owner
    ) -> "DatePatientSeries": ...

    @overload
    def __get__(
        self: "Series[datetime.date]", instance: EventFrame, owner
    ) -> DateEventSeries: ...

    @overload
    def __get__(
        self: "Series[CodeT]", instance: PatientFrame, owner
    ) -> CodePatientSeries: ...

    @overload
    def __get__(
        self: "Series[CodeT]", instance: EventFrame, owner
    ) -> CodeEventSeries: ...

    @overload
    def __get__(
        self: "Series[MultiCodeStringT]", instance: PatientFrame, owner
    ) -> MultiCodeStringPatientSeries: ...

    @overload
    def __get__(
        self: "Series[MultiCodeStringT]", instance: EventFrame, owner
    ) -> MultiCodeStringEventSeries: ...

    @overload
    def __get__(
        self: "Series[bool]", instance: PatientFrame, owner
    ) -> BoolPatientSeries: ...

    @overload
    def __get__(
        self: "Series[bool]", instance: EventFrame, owner
    ) -> BoolEventSeries: ...

    @overload
    def __get__(
        self: "Series[str]", instance: PatientFrame, owner
    ) -> StrPatientSeries: ...

    @overload
    def __get__(
        self: "Series[str]", instance: EventFrame, owner
    ) -> "StrEventSeries": ...

    @overload
    def __get__(
        self: "Series[int]", instance: PatientFrame, owner
    ) -> IntPatientSeries: ...

    @overload
    def __get__(self: "Series[int]", instance: EventFrame, owner) -> IntEventSeries: ...

    @overload
    def __get__(
        self: "Series[float]", instance: PatientFrame, owner
    ) -> FloatPatientSeries: ...

    @overload
    def __get__(
        self: "Series[float]", instance: EventFrame, owner
    ) -> FloatEventSeries: ...

    def __get__(self, instance, owner):
        # Accessed on the class rather than an instance: return the descriptor itself
        if instance is None:  # pragma: no cover
            return self
        # Accessed on a frame instance: delegate to the frame, which selects the column
        # of this name from its underlying query model node
        return instance._select_column(self.name)
2181
2182
2183
def get_tables_from_namespace(namespace):
    """
    Yield an `(attribute_name, table)` pair for every ehrQL table found among the
    attributes of `namespace`
    """
    yield from (
        (attr, value)
        for attr, value in vars(namespace).items()
        if isinstance(value, BaseFrame)
    )
2190
2191
2192
# CASE EXPRESSION FUNCTIONS
2193
#
2194
2195
2196
class when:
    # First half of a `when(condition).then(value)` clause for use in `case()`

    def __init__(self, condition):
        qm_condition = _convert(condition)
        condition_type = get_series_type(qm_condition)
        # Only boolean series make sense as case conditions
        if condition_type is not bool:
            message = (
                f"invalid case condition:\n"
                f"Expecting a boolean series, got series of type"
                f" '{condition_type.__qualname__}'"
            )
            raise TypeError(message)
        self._condition = qm_condition

    def then(self, value):
        # Pair the (already converted) condition with its converted value
        converted_value = _convert(value)
        return WhenThen(self._condition, converted_value)
2210
2211
2212
class WhenThen:
    # A complete `when(condition).then(value)` clause, holding the condition and value
    # it is given

    def __init__(self, condition, value):
        self._condition, self._value = condition, value

    def otherwise(self, value):
        # Shorthand for a single-clause `case()` with a default value
        return case(self, otherwise=value)
2219
2220
2221
def case(*when_thens, otherwise=None):
    """
    Take a sequence of condition-values of the form:
    ```python
    when(condition).then(value)
    ```

    And evaluate them in order, returning the value of the first condition which
    evaluates True. If no condition matches, return the `otherwise` value (or NULL
    if no `otherwise` value is specified).

    Example usage:
    ```python
    category = case(
        when(size < 10).then("small"),
        when(size < 20).then("medium"),
        when(size >= 20).then("large"),
        otherwise="unknown",
    )
    ```

    Note that because the conditions are evaluated in order we don't need the condition
    for "medium" to specify `(size >= 10) & (size < 20)` because by the time the
    condition for "medium" is being evaluated we already know the condition for "small"
    is False.

    A simpler form is available when there is a single condition.  This example:
    ```python
    category = case(
        when(size < 15).then("small"),
        otherwise="large",
    )
    ```

    can be rewritten as:
    ```python
    category = when(size < 15).then("small").otherwise("large")
    ```
    """
    cases = {}
    for clause in when_thens:
        # A bare `when(...)` which never had `.then(...)` called on it
        if isinstance(clause, when):
            raise TypeError(
                "`when(...)` clause missing a `.then(...)` value in `case()` expression"
            )
        # A series wrapping a single-clause Case node: what you get from writing
        # `when(...).then(...).otherwise(...)` inside `case()`
        is_single_case_series = (
            isinstance(clause, BaseSeries)
            and isinstance(clause._qm_node, qm.Case)
            and len(clause._qm_node.cases) == 1
        )
        if is_single_case_series:
            raise TypeError(
                "invalid syntax for `otherwise` in `case()` expression, instead of:\n"
                "\n"
                "    case(\n"
                "        when(...).then(...).otherwise(...)\n"
                "    )\n"
                "\n"
                "You should write:\n"
                "\n"
                "    case(\n"
                "        when(...).then(...),\n"
                "        otherwise=...\n"
                "    )\n"
                "\n"
            )
        if not isinstance(clause, WhenThen):
            raise TypeError(
                "cases must be specified in the form:\n"
                "\n"
                "    when(<CONDITION>).then(<VALUE>)\n"
                "\n"
            )
        if clause._condition in cases:
            raise TypeError("duplicated condition in `case()` expression")
        cases[clause._condition] = clause._value

    if not cases:
        raise TypeError("`case()` expression requires at least one case")
    # With no default and no non-None value there would be nothing to return
    has_some_value = any(value is not None for value in cases.values())
    if otherwise is None and not has_some_value:
        raise TypeError("`case()` expression cannot have all `None` values")
    return _wrap(qm.Case, cases, default=_convert(otherwise))
2302
2303
2304
# HORIZONTAL AGGREGATION FUNCTIONS
2305
#
2306
# These cast all arguments to the first Series. So if we have a Series as
2307
# the first arg then we know the return type. However, if the first arg is
2308
# not a Series, then we don't know the return type. E.g. the following examples
2309
# are tricky:
2310
# maximum_of(10, 10, clinical_events.numeric_value) - will return FloatEventSeries
2311
# maximum_of("2024-01-01", "2023-01-01", clinical_events.date) - will return DateEventSeries
2312
@overload
def maximum_of(value: IntT, other_value, *other_values) -> IntT: ...
@overload
def maximum_of(value: FloatT, other_value, *other_values) -> FloatT: ...
@overload
def maximum_of(value: DateT, other_value, *other_values) -> DateT: ...
# The implementation is left unannotated (as for `minimum_of`): it returns an ehrQL
# series, whose precise type is described by the overloads above
def maximum_of(value, other_value, *other_values):
    """
    Return the maximum value of a collection of Series or Values, disregarding NULLs.

    Example usage:
    ```python
    latest_event_date = maximum_of(event_series_1.date, event_series_2.date, "2001-01-01")
    ```
    """
    args = cast_all_arguments((value, other_value, *other_values))
    # The arguments are passed as a single tuple
    return _apply(qm.Function.MaximumOf, args)
2329
2330
2331
@overload
def minimum_of(value: IntT, other_value, *other_values) -> IntT: ...
@overload
def minimum_of(value: FloatT, other_value, *other_values) -> FloatT: ...
@overload
def minimum_of(value: DateT, other_value, *other_values) -> DateT: ...
def minimum_of(value, other_value, *other_values):
    """
    Return the minimum value of a collection of Series or Values, disregarding NULLs.

    Example usage:
    ```python
    earliest_event_date = minimum_of(event_series_1.date, event_series_2.date, "2001-01-01")
    ```
    """
    all_values = (value, other_value, *other_values)
    # The cast arguments are passed on as a single tuple
    return _apply(qm.Function.MinimumOf, cast_all_arguments(all_values))
2348
2349
2350
# ERROR HANDLING
2351
#
2352
2353
2354
def raise_helpful_error_if_possible(arg):
    """
    Raise a `TypeError` with a targeted message if `arg` is one of the common mistakes
    recognised below (a frame, an uncalled function, or an incomplete `when(...)`
    expression). Otherwise return `None` so the caller can raise its own error.
    """
    if isinstance(arg, BaseFrame):
        raise TypeError(
            f"Expecting a series but got a frame (`{arg.__class__.__name__}`): "
            f"are you missing a column name?"
        )
    if callable(arg):
        # Not every callable has a `__name__` (e.g. `functools.partial` objects or
        # instances defining `__call__`), so fall back to the name of its type rather
        # than failing with an unhelpful AttributeError
        name = getattr(arg, "__name__", type(arg).__name__)
        raise TypeError(
            f"Function referenced but not called: are you missing parentheses on "
            f"`{name}()`?"
        )
    if isinstance(arg, when):
        raise TypeError(
            "Missing `.then(...).otherwise(...)` conditions on a `when(...)` expression"
        )
    if isinstance(arg, WhenThen):
        raise TypeError(
            "Missing `.otherwise(...)` condition on a `when(...).then(...)` expression\n"
            "Note: you can use `.otherwise(None)` to get NULL values"
        )
2374
2375
2376
def validate_ehrql_series(arg, context):
    """
    Raise a `TypeError` unless `arg` is an ehrQL series.

    `context` describes what is being validated and is used to prefix the error
    message (as "invalid <context>:").
    """
    try:
        raise_helpful_error_if_possible(arg)
    except TypeError as e:
        # Re-raise with the context prepended; `from None` hides the chained exception
        raise TypeError(f"invalid {context}:\n{e}") from None
    if not isinstance(arg, BaseSeries):
        raise TypeError(
            f"invalid {context}:\n"
            f"Expecting an ehrQL series, got type '{type(arg).__qualname__}'"
        )
2386
2387
2388
def validate_patient_series(arg, context):
    # Must be an ehrQL series at all before checking its dimension
    validate_ehrql_series(arg, context)
    if isinstance(arg, PatientSeries):
        return
    message = (
        f"invalid {context}:\nExpecting a series with only one value per patient"
    )
    raise TypeError(message)
2394
2395
2396
def validate_patient_series_type(arg, types, context):
    # Must be a patient series at all before checking the type of its values
    validate_patient_series(arg, context)
    if arg._type in types:
        return
    types_desc = humanize_list_of_types(types)
    # Choose the indefinite article based on the first letter of the description
    starts_with_vowel = types_desc[0] in "aeiou"
    article = "an" if starts_with_vowel else "a"
    message = (
        f"invalid {context}:\n"
        f"Expecting {article} {types_desc} series, got series of type"
        f" '{arg._type.__qualname__}'"
    )
    raise TypeError(message)
2406
2407
2408
# Friendlier names for types in error messages; types not listed here are described
# by their `__qualname__` instead (see `humanize_list_of_types`)
HUMAN_TYPES = {bool: "boolean", int: "integer"}
2412
2413
2414
def humanize_list_of_types(types):
    # Describe `types` in words, e.g. "boolean", "integer or float" or
    # "boolean, integer or float"
    names = []
    for type_ in types:
        names.append(HUMAN_TYPES.get(type_, type_.__qualname__))
    leading = ", ".join(names[:-1])
    last = names[-1]
    if leading:
        return f"{leading} or {last}"
    return last
2418
2419
2420
def modify_exception(exc):
    # Hook for making exceptions we didn't raise ourselves more helpful, or for
    # attaching additional context to them
    symbol = _get_operator_error(exc)
    if symbol:
        # A logical operator was involved: attach a note about its precedence rules
        exc.add_note(_format_operator_error_note(symbol))
    return exc
2428
2429
2430
def _get_operator_error(exc):
2431
    # Because `and`, `or` and `not` are control-flow primitives in Python they are not
2432
    # overridable and so we're forced to use the bitwise operators for logical
2433
    # operations. However these have different precedence rules from those governing the
2434
    # standard operators and so it's easy to accidentally do the wrong thing. Here we
2435
    # identify errors associated with the logical operators so we can add a note trying
2436
    # to explain what might have happened.
2437
    if not isinstance(exc, TypeError):
2438
        return
2439
    # Sadly we have to do this via string matching on the exception text
2440
    if match := re.match(
2441
        r"(unsupported operand type\(s\) for|bad operand type for unary) ([|&~]):",
2442
        str(exc),
2443
    ):
2444
        return match.group(2)
2445
2446
2447
def _format_operator_error_note(operator):
2448
    if operator == "|":
2449
        example_bad = "a == b | x == y"
2450
        example_good = "(a == b) | (x == y)"
2451
    elif operator == "&":
2452
        example_bad = "a == b & x == y"
2453
        example_good = "(a == b) & (x == y)"
2454
    elif operator == "~":
2455
        example_bad = "~ a == b"
2456
        example_good = "~ (a == b)"
2457
    else:
2458
        assert False
2459
    return (
2460
        f"\n"
2461
        f"WARNING: The `{operator}` operator has surprising precedence rules, meaning\n"
2462
        "you may need to add more parentheses to get the correct behaviour.\n"
2463
        f"\n"
2464
        f"For example, instead of writing:\n"
2465
        f"\n"
2466
        f"    {example_bad}\n"
2467
        f"\n"
2468
        f"You should write:\n"
2469
        f"\n"
2470
        f"    {example_good}"
2471
    )