"""Tests for the renderers exposed through ``ehrql.renderers.DISPLAY_RENDERERS``."""

import textwrap
from itertools import product

import pytest

from ehrql.query_engines.in_memory_database import (
    PatientColumn,
    PatientTable,
)
from ehrql.renderers import DISPLAY_RENDERERS


# Shared fixture for the table-rendering tests: five patients (ids 1-5) with
# two columns, ``i1`` and ``i2``. Row n holds n01 / n11, which makes it easy to
# see which rows survive head/tail truncation in the expected outputs.
TABLE = PatientTable.parse(
    """
      |  i1 |  i2
    --+-----+-----
    1 | 101 | 111
    2 | 201 | 211
    3 | 301 | 311
    4 | 401 | 411
    5 | 501 | 511
    """
)
@pytest.mark.parametrize("render_format", ["ascii", "html"])
def test_render_table(render_format):
    """Render every row of TABLE, untruncated, in the requested format.

    Surrounding whitespace is stripped from the renderer's output before it is
    compared with the expected text.
    """
    render = DISPLAY_RENDERERS[render_format]
    rendered = render(list(TABLE.to_records())).strip()

    ascii_expected = textwrap.dedent(
        """
        patient_id        | i1                | i2
        ------------------+-------------------+------------------
        1                 | 101               | 111
        2                 | 201               | 211
        3                 | 301               | 311
        4                 | 401               | 411
        5                 | 501               | 511
        """
    ).strip()
    html_expected = (
        "<!-- start debug output -->"
        "<table>"
        "<thead>"
        "<tr><th>patient_id</th><th>i1</th><th>i2</th></tr>"
        "</thead>"
        "<tbody>"
        "<tr><td>1</td><td>101</td><td>111</td></tr>"
        "<tr><td>2</td><td>201</td><td>211</td></tr>"
        "<tr><td>3</td><td>301</td><td>311</td></tr>"
        "<tr><td>4</td><td>401</td><td>411</td></tr>"
        "<tr><td>5</td><td>501</td><td>511</td></tr>"
        "</tbody>"
        "</table>"
        "<!-- end debug output -->"
    )
    expected = {"ascii": ascii_expected, "html": html_expected}[render_format]
    assert rendered == expected, rendered
@pytest.mark.parametrize("render_format", ["ascii", "html"])
def test_render_column(render_format):
    """Render a single PatientColumn; its values appear under a ``value`` header.

    Surrounding whitespace is stripped from the renderer's output before it is
    compared with the expected text.
    """
    expected_output = {
        "ascii": textwrap.dedent(
            """
            patient_id        | value
            ------------------+------------------
            1                 | 101
            2                 | 201
            """
        ).strip(),
        "html": (
            "<!-- start debug output -->"
            "<table>"
            "<thead>"
            "<tr><th>patient_id</th><th>value</th></tr>"
            "</thead>"
            "<tbody>"
            "<tr><td>1</td><td>101</td></tr>"
            "<tr><td>2</td><td>201</td></tr>"
            "</tbody>"
            "</table>"
            "<!-- end debug output -->"
        ),
    }

    c = PatientColumn.parse(
        """
        1 | 101
        2 | 201
        """
    )
    rendered = DISPLAY_RENDERERS[render_format](list(c.to_records())).strip()
    assert rendered == expected_output[render_format], rendered
@pytest.mark.parametrize("render_format", ["ascii", "html"])
def test_render_table_head(render_format):
    """With ``head=2`` only the first two rows are kept, followed by an ellipsis row.

    The renderer's output is compared verbatim (it is not stripped).
    """
    if render_format == "ascii":
        expected = textwrap.dedent(
            """
            patient_id        | i1                | i2
            ------------------+-------------------+------------------
            1                 | 101               | 111
            2                 | 201               | 211
            ...               | ...               | ...
            """
        ).strip()
    else:
        expected = (
            "<!-- start debug output -->"
            "<table>"
            "<thead>"
            "<tr><th>patient_id</th><th>i1</th><th>i2</th></tr>"
            "</thead>"
            "<tbody>"
            "<tr><td>1</td><td>101</td><td>111</td></tr>"
            "<tr><td>2</td><td>201</td><td>211</td></tr>"
            "<tr><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td></tr>"
            "</tbody>"
            "</table>"
            "<!-- end debug output -->"
        )

    rendered = DISPLAY_RENDERERS[render_format](list(TABLE.to_records()), head=2)
    assert rendered == expected, rendered
@pytest.mark.parametrize("render_format", ["ascii", "html"])
def test_render_table_tail(render_format):
    """With ``tail=2`` an ellipsis row precedes the last two rows of TABLE.

    The renderer's output is compared verbatim (it is not stripped).
    """
    all_records = list(TABLE.to_records())
    rendered = DISPLAY_RENDERERS[render_format](all_records, tail=2)

    expected_by_format = dict(
        ascii=textwrap.dedent(
            """
            patient_id        | i1                | i2
            ------------------+-------------------+------------------
            ...               | ...               | ...
            4                 | 401               | 411
            5                 | 501               | 511
            """
        ).strip(),
        html=(
            "<!-- start debug output -->"
            "<table>"
            "<thead>"
            "<tr><th>patient_id</th><th>i1</th><th>i2</th></tr>"
            "</thead>"
            "<tbody>"
            "<tr><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td></tr>"
            "<tr><td>4</td><td>401</td><td>411</td></tr>"
            "<tr><td>5</td><td>501</td><td>511</td></tr>"
            "</tbody>"
            "</table>"
            "<!-- end debug output -->"
        ),
    )
    assert rendered == expected_by_format[render_format], rendered
@pytest.mark.parametrize("render_format", ["ascii", "html"])
def test_render_table_head_and_tail(render_format):
    """With ``head=2, tail=2`` the middle row is replaced by a single ellipsis row.

    The renderer's output is compared verbatim (it is not stripped).
    """
    records = list(TABLE.to_records())
    render = DISPLAY_RENDERERS[render_format]
    rendered = render(records, head=2, tail=2)

    ascii_expected = textwrap.dedent(
        """
        patient_id        | i1                | i2
        ------------------+-------------------+------------------
        1                 | 101               | 111
        2                 | 201               | 211
        ...               | ...               | ...
        4                 | 401               | 411
        5                 | 501               | 511
        """
    ).strip()
    html_expected = (
        "<!-- start debug output -->"
        "<table>"
        "<thead>"
        "<tr><th>patient_id</th><th>i1</th><th>i2</th></tr>"
        "</thead>"
        "<tbody>"
        "<tr><td>1</td><td>101</td><td>111</td></tr>"
        "<tr><td>2</td><td>201</td><td>211</td></tr>"
        "<tr><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td></tr>"
        "<tr><td>4</td><td>401</td><td>411</td></tr>"
        "<tr><td>5</td><td>501</td><td>511</td></tr>"
        "</tbody>"
        "</table>"
        "<!-- end debug output -->"
    )
    expected = {"ascii": ascii_expected, "html": html_expected}[render_format]
    assert rendered == expected, rendered
@pytest.mark.parametrize(
    "render_format,head_tail",
    # Both renderers are exercised; previously only "ascii" was, which left
    # the "html" expectation below unused.
    list(product(["ascii", "html"], [(0, 0), (2, 3), (5, 0), (0, 6), (3, 3)])),
)
def test_render_table_bad_head_tail(render_format, head_tail):
    """Head/tail settings that would hide no rows render the full table.

    Each (head, tail) pair is either (0, 0), or covers all five rows of TABLE
    (head alone, tail alone, or head + tail reaching or exceeding the row
    count). In every case the output is expected to be the complete table
    with no ellipsis row. The output is compared verbatim (it is not stripped).
    """
    expected_output = {
        "ascii": textwrap.dedent(
            """
            patient_id        | i1                | i2
            ------------------+-------------------+------------------
            1                 | 101               | 111
            2                 | 201               | 211
            3                 | 301               | 311
            4                 | 401               | 411
            5                 | 501               | 511
            """
        ).strip(),
        "html": (
            "<!-- start debug output -->"
            "<table>"
            "<thead>"
            "<tr><th>patient_id</th><th>i1</th><th>i2</th></tr>"
            "</thead>"
            "<tbody>"
            "<tr><td>1</td><td>101</td><td>111</td></tr>"
            "<tr><td>2</td><td>201</td><td>211</td></tr>"
            "<tr><td>3</td><td>301</td><td>311</td></tr>"
            "<tr><td>4</td><td>401</td><td>411</td></tr>"
            "<tr><td>5</td><td>501</td><td>511</td></tr>"
            "</tbody>"
            "</table>"
            "<!-- end debug output -->"
        ),
    }
    head, tail = head_tail
    truncated = DISPLAY_RENDERERS[render_format](
        list(TABLE.to_records()), head=head, tail=tail
    )
    # Include head/tail in the failure message so the failing case is obvious.
    assert truncated == expected_output[render_format], (truncated, head, tail)