--- a +++ b/tests/spec/population/test_population.py @@ -0,0 +1,97 @@ +from ehrql import case, when + +from ..tables import e, p + + +title = "Defining a population" +text = """ +`define_population` is used to limit the population from which data is extracted. +""" + + +def test_population_with_single_table(spec_test): + """ + Extract a column from a patient table after limiting the population by another column. + """ + table_data = { + p: """ + | b1 | i1 + --+----+--- + 1 | F | 10 + 2 | T | 20 + 3 | F | 30 + """, + } + + spec_test( + table_data, + p.i1, + { + 1: 10, + 3: 30, + }, + population=~p.b1, + ) + + +def test_population_with_multiple_tables(spec_test): + """ + Limit the patient population by a column in one table, and return values from another + table. + """ + table_data = { + p: """ + | i1 + --+---- + 1 | 10 + 2 | 20 + 3 | 0 + """, + e: """ + | i1 + --+----- + 1 | 101 + 1 | 102 + 3 | 301 + 4 | 401 + """, + } + + spec_test( + table_data, + e.exists_for_patient(), + { + 1: True, + 2: False, + }, + population=p.i1 > 0, + ) + + +def test_case_with_case_expression(spec_test): + """ + Limit the patient population by a case expression. + """ + table_data = { + p: """ + | i1 + --+--- + 1 | 6 + 2 | 7 + 3 | 9 + 4 | + """, + } + + spec_test( + table_data, + p.i1, + { + 1: 6, + 2: 7, + }, + population=case( + when(p.i1 <= 8).then(True), + when(p.i1 > 8).then(False), + ), + )