[e988c2]: / tests / unit / query_engines / test_mssql_dialect.py

Download this file

151 lines (129 with data), 5.7 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import datetime
import sqlalchemy
from sqlalchemy.sql.visitors import iterate, replacement_traverse
from ehrql.query_engines.mssql_dialect import (
MSSQLDialect,
ScalarSelectAggregation,
SelectStarInto,
)
def test_mssql_date_types():
    # Note: it would be nice to parameterize this test, but given that the
    # inputs are SQLAlchemy expressions I don't know how to do this without
    # constructing the column objects outside of the test, which I don't really
    # want to do.
    date_column = sqlalchemy.Column("date_col", sqlalchemy.Date())
    datetime_column = sqlalchemy.Column("datetime_col", sqlalchemy.DateTime())

    date_value = datetime.date(2021, 5, 15)
    assert _str(date_column == date_value) == "date_col = CAST('20210515' AS DATE)"

    datetime_value = datetime.datetime(2021, 5, 15, 9, 10, 0)
    assert (
        _str(datetime_column == datetime_value)
        == "datetime_col = CAST('2021-05-15T09:10:00' AS DATETIME)"
    )

    # Comparison against None must render as IS NULL, not as a literal cast
    assert _str(date_column == None) == "date_col IS NULL"  # noqa: E711
    assert _str(datetime_column == None) == "datetime_col IS NULL"  # noqa: E711
def test_casts_to_date():
    # This fails unless our MSSQL dialect sets the appropriate minimum server version.
    # By default it will treat DATE as an alias for DATETIME. See:
    # https://github.com/sqlalchemy/sqlalchemy/blob/rel_1_4_46/lib/sqlalchemy/dialects/mssql/base.py#L1623-L1627
    foo = sqlalchemy.table("foo", sqlalchemy.Column("bar"))
    cast_expression = sqlalchemy.cast(foo.c.bar, sqlalchemy.Date)
    rendered = str(cast_expression.compile(dialect=MSSQLDialect()))
    assert rendered == "CAST(foo.bar AS DATE)"
def test_select_star_into():
    # `SELECT * INTO <target>` should wrap the supplied aliased query verbatim
    source = sqlalchemy.table("foo", sqlalchemy.Column("bar"))
    inner_query = sqlalchemy.select(source.c.bar).where(source.c.bar > 1)
    destination = sqlalchemy.table("test")
    clause = SelectStarInto(destination, inner_query.alias())
    expected_sql = (
        "SELECT * INTO test FROM (SELECT foo.bar AS bar \n"
        "FROM foo \n"
        "WHERE foo.bar > 1) AS anon_1"
    )
    assert _str(clause) == expected_sql
def test_select_star_into_can_be_iterated():
    # If we don't define the `get_children()` method on `SelectStarInto` we won't get an
    # error when attempting to iterate the resulting element structure: it will just act
    # as a leaf node. But as we rely heavily on query introspection we need to ensure we
    # can iterate over query structures.
    table = sqlalchemy.table("foo", sqlalchemy.Column("bar"))
    query = sqlalchemy.select(table.c.bar).where(table.c.bar > 1)
    target_table = sqlalchemy.table("test")
    select_into = SelectStarInto(target_table, query.alias())
    # Check that SelectStarInto supports iteration by confirming that we can get back to
    # both the target table and the original table by iterating it. We compare with `is`
    # because SQLAlchemy elements overload `==` to build SQL expressions.
    # Generator expressions (not list comprehensions) so `any()` can short-circuit
    # without materializing the whole traversal.
    assert any(e is table for e in iterate(select_into)), "no `table`"
    assert any(e is target_table for e in iterate(select_into)), "no `target_table`"
def _str(expression):
    """Compile `expression` with our MSSQL dialect, rendering bound parameters
    as inline literals, and return the resulting SQL stripped of whitespace."""
    compile_kwargs = {"literal_binds": True}
    compiled = expression.compile(dialect=MSSQLDialect(), compile_kwargs=compile_kwargs)
    return f"{compiled}".strip()
def test_mssql_float_type():
    column = sqlalchemy.Column("float_col", sqlalchemy.Float())
    # Float literals must be rendered with an explicit CAST(... AS FLOAT)
    assert _str(column == 0.75) == "float_col = CAST(0.75 AS FLOAT)"
    # None comparisons still compile to IS NULL, with no cast involved
    assert _str(column == None) == "float_col IS NULL"  # noqa: E711
    case_expression = sqlalchemy.sql.case((column > 0.5, 0.1), else_=0.75)
    expected_sql = (
        "CASE WHEN (float_col > CAST(0.5 AS FLOAT)) THEN CAST(0.1 AS FLOAT)"
        " ELSE CAST(0.75 AS FLOAT) END"
    )
    assert _str(case_expression) == expected_sql
def test_scalar_select_aggregation():
    t1 = sqlalchemy.table(
        "t1",
        sqlalchemy.Column("c1"),
        sqlalchemy.Column("c2"),
        sqlalchemy.Column("c3"),
    )
    max_of_c1_c2 = ScalarSelectAggregation.build(
        sqlalchemy.func.max, [t1.columns.c1, t1.columns.c2]
    )
    select = sqlalchemy.select(t1.columns.c3).where(max_of_c1_c2 == 1)
    # The aggregation should compile to a scalar subquery over a VALUES list
    expected_sql = (
        "SELECT t1.c3 \n"
        "FROM t1 \n"
        "WHERE ("
        "SELECT max(aggregate_values.value) AS max_1 \n"
        "FROM (VALUES (t1.c1), (t1.c2)) AS aggregate_values (value)"
        ") = 1"
    )
    assert _str(select) == expected_sql
def test_scalar_select_aggregation_can_be_iterated():
    t1 = sqlalchemy.table(
        "t1",
        sqlalchemy.Column("c1"),
        sqlalchemy.Column("c2"),
        sqlalchemy.Column("c3"),
    )
    aggregated = ScalarSelectAggregation.build(
        sqlalchemy.func.max, [t1.columns.c1, t1.columns.c2]
    )
    select = sqlalchemy.select(t1.columns.c3).where(aggregated == 1)
    # Iterating the query must reach all three original columns. We compare
    # object IDs here because these objects overload `__eq__`.
    found = {id(element) for element in iterate(select)}
    wanted = {id(col) for col in (t1.columns.c1, t1.columns.c2, t1.columns.c3)}
    assert wanted <= found
def test_scalar_select_aggregation_supports_replacement_traverse():
    t1 = sqlalchemy.table(
        "t1",
        sqlalchemy.Column("c1"),
        sqlalchemy.Column("c2"),
        sqlalchemy.Column("c3"),
    )
    original = ScalarSelectAggregation.build(
        sqlalchemy.func.max, [t1.columns.c1, t1.columns.c2]
    )

    def swap_c2_for_c3(obj):
        # `replacement_traverse` treats a None return as "leave element unchanged"
        return t1.columns.c3 if obj is t1.columns.c2 else None

    replaced = replacement_traverse(original, {}, swap_c2_for_c3)
    # The original expression is left untouched by the traversal...
    assert _str(original) == (
        "(SELECT max(aggregate_values.value) AS max_1 \n"
        "FROM (VALUES (t1.c1), (t1.c2)) AS aggregate_values (value))"
    )
    # ...while the result has c2 swapped for c3
    assert _str(replaced) == (
        "(SELECT max(aggregate_values.value) AS max_1 \n"
        "FROM (VALUES (t1.c1), (t1.c3)) AS aggregate_values (value))"
    )