# File: tests/unit/dummy_data/test_query_info.py (revision e988c2)
# 154 lines (130 with data), 4.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import datetime
from ehrql import Dataset, days
from ehrql.codes import CTV3Code
from ehrql.dummy_data.query_info import ColumnInfo, QueryInfo, TableInfo
from ehrql.tables import (
Constraint,
EventFrame,
PatientFrame,
Series,
table,
table_from_rows,
)
@table
class patients(PatientFrame):
    """Minimal one-row-per-patient table used as a fixture by the tests below.

    ``sex`` carries a categorical constraint so that
    ``test_query_info_from_dataset`` can check that column constraints are
    carried through into the resulting ``ColumnInfo``.
    """

    date_of_birth = Series(datetime.date)
    sex = Series(
        str,
        constraints=[
            Constraint.Categorical(
                ["male", "female", "intersex"],
            )
        ],
    )
@table
class events(EventFrame):
    """Minimal event-level table (not one row per patient) used as a fixture."""

    date = Series(datetime.date)
    code = Series(CTV3Code)
def test_query_info_from_dataset():
    """Build a QueryInfo from a compiled dataset and check every table/column."""
    dataset = Dataset()
    dataset.define_population(events.exists_for_patient())
    dataset.date_of_birth = patients.date_of_birth
    dataset.sex = patients.sex
    dataset.has_event = events.where(
        events.code == CTV3Code("abc00")
    ).exists_for_patient()

    query_info = QueryInfo.from_dataset(dataset._compile())

    assert query_info == QueryInfo(
        tables={
            "events": TableInfo(
                name="events",
                has_one_row_per_patient=False,
                # Only columns the dataset actually references are listed, so
                # `events.date` does not appear here. The CTV3Code column is
                # reported with type `str` and records the raw code string.
                columns={
                    "code": ColumnInfo(
                        name="code",
                        type=str,
                        _values_used={"abc00"},
                    )
                },
            ),
            "patients": TableInfo(
                name="patients",
                has_one_row_per_patient=True,
                columns={
                    "date_of_birth": ColumnInfo(
                        name="date_of_birth",
                        type=datetime.date,
                    ),
                    "sex": ColumnInfo(
                        name="sex",
                        type=str,
                        constraints=(
                            Constraint.Categorical(
                                values=("male", "female", "intersex")
                            ),
                        ),
                    ),
                },
            ),
        },
        # The population is defined on `events`, so it is the population table.
        population_table_names=["events"],
        other_table_names=["patients"],
    )
def test_query_info_records_values():
    """Values compared via ==, is_in and contains are recorded on the column."""

    @table
    class test_table(PatientFrame):
        value = Series(str)

    dataset = Dataset()
    dataset.define_population(test_table.exists_for_patient())
    dataset.q1 = (
        # NOTE: If we add examples here we should add the same examples to the inline
        # patient table test below so we can check they are correctly handled in that
        # context
        (test_table.value == "a")
        | test_table.value.is_in(["b", "c"])
        | test_table.value.contains("d")
    )

    query_info = QueryInfo.from_dataset(dataset._compile())
    column_info = query_info.tables["test_table"].columns["value"]

    assert column_info == ColumnInfo(
        name="value",
        type=str,
        _values_used={"a", "b", "c", "d"},
    )
def test_query_info_ignores_inline_patient_tables():
    """Inline patient tables contribute no TableInfo and no recorded values."""
    # InlinePatientTable nodes are unusual from the point of view of dummy data because
    # they come bundled with their own data (presumably based on dummy data generated
    # further upstream in the data processing pipeline) so we don't need to generate any
    # for them. This means that the QueryInfo class can, and should, ignore them.

    @table_from_rows([])
    class inline_table(PatientFrame):
        value = Series(str)

    dataset = Dataset()
    dataset.define_population(events.exists_for_patient())
    # Same comparison examples as `test_query_info_records_values`.
    dataset.q1 = (
        (inline_table.value == "a")
        | inline_table.value.is_in(["b", "c"])
        | inline_table.value.contains("d")
    )

    query_info = QueryInfo.from_dataset(dataset._compile())

    # Only the population table appears; `inline_table` is absent entirely.
    assert query_info == QueryInfo(
        tables={
            "events": TableInfo(
                name="events", has_one_row_per_patient=False, columns={}
            )
        },
        population_table_names=["events"],
        other_table_names=[],
    )
def test_query_info_ignores_complex_comparisons():
    """Only direct column-vs-static-value comparisons have their values recorded."""
    # By "complex" here we just mean anything other than a direct comparison between a
    # selected column and a static value. We don't attempt to handle these, but we want
    # to make sure we don't blow up with an error, or misinterpret them.
    dataset = Dataset()
    dataset.define_population(events.exists_for_patient())
    # Complex: compares a derived value (the year), not the column itself.
    dataset.q1 = patients.date_of_birth.year.is_in([2000, 2010, 2020])
    # Complex: compares the result of date arithmetic, not the column itself.
    dataset.q2 = patients.date_of_birth + days(100) == "2021-10-20"
    # Direct: this is the only value that should be recorded.
    dataset.q3 = patients.date_of_birth == "2022-10-05"

    query_info = QueryInfo.from_dataset(dataset._compile())
    column_info = query_info.tables["patients"].columns["date_of_birth"]

    # The ISO string literal is recorded as a `datetime.date`.
    assert column_info.values_used == [datetime.date(2022, 10, 5)]