[e988c2]: / tests / unit / query_model / test_column_specs.py

Download this file

140 lines (116 with data), 4.0 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import datetime
from ehrql.codes import SNOMEDCTCode
from ehrql.query_model.column_specs import (
ColumnSpec,
get_categories,
get_range,
get_table_specs,
)
from ehrql.query_model.nodes import (
AggregateByPatient,
Case,
Column,
Constraint,
Dataset,
Function,
SelectColumn,
SelectPatientTable,
SelectTable,
TableSchema,
Value,
)
def test_get_column_specs():
    """Check the specs ``get_table_specs`` produces for a small dataset.

    The dataset exposes three variables drawn from a ``patients`` table: a
    date column, a SNOMED-CT code column, and a SNOMED-CT code column that
    carries a categorical constraint. The expected result is a single
    ``"dataset"`` entry holding a non-nullable integer ``patient_id`` plus
    one nullable ``ColumnSpec`` per variable.
    """
    permitted_codes = [SNOMEDCTCode("123000"), SNOMEDCTCode("456000")]
    patients_schema = TableSchema(
        date_of_birth=Column(datetime.date),
        code=Column(SNOMEDCTCode),
        category=Column(
            SNOMEDCTCode,
            constraints=[Constraint.Categorical(permitted_codes)],
        ),
    )
    patients = SelectPatientTable("patients", schema=patients_schema)

    variables = {
        "dob": SelectColumn(patients, "date_of_birth"),
        "code": SelectColumn(patients, "code"),
        "category": SelectColumn(patients, "category"),
    }
    dataset = Dataset(
        population=AggregateByPatient.Exists(patients),
        variables=variables,
        events={},
        measures=None,
    )

    table_specs = get_table_specs(dataset)

    # Code-typed columns are expected to surface as ``str``, and the
    # categories as the plain string form of the permitted codes.
    expected_columns = {
        "patient_id": ColumnSpec(type=int, nullable=False, categories=None),
        "dob": ColumnSpec(type=datetime.date, nullable=True, categories=None),
        "code": ColumnSpec(type=str, nullable=True, categories=None),
        "category": ColumnSpec(
            type=str,
            nullable=True,
            categories=("123000", "456000"),
        ),
    }
    assert table_specs == {"dataset": expected_columns}
# Shared ``events`` table used by the tests that follow. ``event_type`` is
# constrained to the categories "a", "b" and "c"; ``event_name`` is a string
# column with no constraints.
_events_schema = TableSchema(
    event_type=Column(
        str,
        constraints=[Constraint.Categorical(["a", "b", "c"])],
    ),
    event_name=Column(str),
)
events = SelectTable("events", schema=_events_schema)

# Column selections reused across the tests.
event_type = SelectColumn(events, "event_type")
event_name = SelectColumn(events, "event_name")
def test_get_categories_default_implementation():
    """An ``Exists`` aggregation over ``events`` should report no categories."""
    exists_node = AggregateByPatient.Exists(events)
    assert get_categories(exists_node) is None
def test_get_categories_for_select_column():
    """Selecting the constrained ``event_type`` column yields its categories."""
    expected = ("a", "b", "c")
    assert get_categories(event_type) == expected
def test_get_categories_for_min_max():
    """``Min`` and ``Max`` over ``event_type`` keep the column's categories."""
    expected = ("a", "b", "c")
    for aggregate in (AggregateByPatient.Min, AggregateByPatient.Max):
        assert get_categories(aggregate(event_type)) == expected
def test_get_categories_for_case_with_static_values():
    """A ``Case`` whose outcomes are all static values lists them as categories.

    The expected categories follow the order in which the outcomes appear in
    the case mapping; the ``None`` default contributes nothing.
    """
    buckets = {
        Function.LT(event_name, Value("abc")): Value("low"),
        Function.LT(event_name, Value("lmn")): Value("med"),
        Function.GE(event_name, Value("lmn")): Value("high"),
    }
    name_bucket = Case(buckets, default=None)

    expected = ("low", "med", "high")
    assert get_categories(name_bucket) == expected
def test_get_categories_for_case_with_static_default():
    """A static default is appended to the categories of a categorical outcome."""
    type_is_present = Function.Not(Function.IsNull(event_type))
    type_or_missing = Case(
        {type_is_present: event_type},
        default=Value("missing"),
    )

    expected = ("a", "b", "c", "missing")
    assert get_categories(type_or_missing) == expected
def test_get_categories_for_case_with_mixed_categorical_and_noncategorical_values():
    """No categories are reported when one outcome lacks a categorical constraint.

    ``event_type`` is categorical, but the default falls back to the
    unconstrained ``event_name`` column, so the expected result is ``None``.
    """
    type_is_present = Function.Not(Function.IsNull(event_type))
    type_or_name = Case({type_is_present: event_type}, default=event_name)

    assert get_categories(type_or_name) is None
def test_get_range_default_implementation():
    """An unconstrained integer column has an open range on both ends."""
    table = SelectTable("t", schema=TableSchema(i=Column(int)))
    unconstrained = SelectColumn(table, "i")

    assert get_range(unconstrained) == (None, None)
def test_get_range_for_count():
    """A ``Count`` over ``events`` ranges from zero up to a 16-binary-digit maximum."""
    lower, upper = get_range(AggregateByPatient.Count(events))

    assert lower == 0
    # The upper bound is pinned by the length of its binary representation
    # rather than by an exact value.
    binary_digits = format(upper, "b")
    assert len(binary_digits) == 16
def test_get_range_for_select_column_with_range_constraint():
    """A column constrained by ``ClosedRange(0, 100, 10)`` has the range (0, 100)."""
    bounded_column = Column(int, constraints=[Constraint.ClosedRange(0, 100, 10)])
    table = SelectTable("t", schema=TableSchema(i=bounded_column))
    selected = SelectColumn(table, "i")

    assert get_range(selected) == (0, 100)