--- a +++ b/tests/unit/test_quiz.py @@ -0,0 +1,384 @@ +from pathlib import Path +from unittest.mock import patch + +import hypothesis.strategies as st +import pytest +from hypothesis import given + +from ehrql import debugger, quiz, weeks +from ehrql.query_engines.debug import DebugQueryEngine +from ehrql.query_language import Dataset +from ehrql.tables.core import ( + clinical_events, + medications, + patients, + practice_registrations, +) + + +def get_engine(): # Used for hypothesis test + path = Path(__file__).parents[1] / "fixtures" / "quiz-example-data" + return DebugQueryEngine(str(path)) + + +@pytest.fixture +def engine(): + return get_engine() + + +def dataset_smoketest( + index_year: int = 2022, + min_age: int = 18, + max_age: int = 80, + year_of_birth_column: bool = False, +) -> Dataset: + year_of_birth = patients.date_of_birth.year + age = index_year - year_of_birth + + dataset = Dataset() + dataset.define_population((age >= min_age) & (age <= max_age)) + dataset.age = age + if year_of_birth_column: + dataset.year_of_birth = patients.date_of_birth.year + return dataset + + +def filtered_medications( + index_year: int = 2023, + interval_weeks: int = 52, + codelist: str | list[str] = ["39113611000001102"], + filter_dates: bool = True, +): + index_date = f"{index_year}-01-01" + if isinstance(codelist, str): + codelist = [codelist] + f = medications.dmd_code.is_in(codelist) + if filter_dates: + f = f & medications.date.is_on_or_between( + index_date - weeks(interval_weeks), index_date + ) + filtered = medications.where(f) + return filtered + + +# Tests for check_answer + + +@pytest.mark.parametrize( + "answer, expected, message", + [ + (1, Dataset(), "Expected Dataset, got int instead."), + (patients, Dataset(), "Expected Dataset, got Table instead."), + (patients.date_of_birth, Dataset(), "Expected Dataset, got Series instead."), + (Dataset(), patients, "Expected Table, got Dataset instead."), + (patients, patients.date_of_birth, "Expected Series, got Table instead."), + ], +) +def test_check_answer_wrong_type_before_evaluation(answer, expected, message): + msg = quiz.check_answer(engine=None, answer=answer, expected=expected) + assert msg == message + + +def test_check_answer_event_frame_not_converted_to_patient_frame(engine): + # Wrong type after evaluation + msg = quiz.check_answer( + engine=engine, + answer=clinical_events, + expected=clinical_events.sort_by(clinical_events.date).last_for_patient(), + ) + assert msg == "Expected PatientTable, got EventTable instead." + + +@pytest.mark.parametrize( + "answer, expected", + [ + (Dataset(), Dataset()), # Same syntax but different objects + (dataset_smoketest(), dataset_smoketest()), + (patients, patients), + (patients.date_of_birth, patients.date_of_birth), + (clinical_events, clinical_events), + (medications.dmd_code, medications.dmd_code), + ], +) +def test_check_answer_same_syntax_correct(engine, answer, expected): + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert msg == "Correct!" + + +def test_check_answer_empty_dataset(engine): + msg = quiz.check_answer( + engine=engine, answer=Dataset(), expected=dataset_smoketest() + ) + assert msg == "The dataset is empty." + + +@pytest.mark.parametrize( + "order, message", + [ + ([0, 1], "Missing column(s): year_of_birth."), + ([1, 0], "Found extra column(s): year_of_birth."), + ], +) +def test_check_answer_dataset_has_missing_or_extra_column(engine, order, message): + datasets = [dataset_smoketest(), dataset_smoketest(year_of_birth_column=True)] + answer, expected = (datasets[i] for i in order) + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert msg == message + + +def test_check_answer_dataset_typo_in_column_name(engine): + answer = dataset_smoketest(year_of_birth_column=False) + answer.yeah_of_birth = patients.date_of_birth.year + expected = dataset_smoketest(year_of_birth_column=True) + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert ( + msg + == "Missing column(s): year_of_birth.\nFound extra column(s): yeah_of_birth." + ) + + +@pytest.mark.parametrize( + "order, message", + [ + ([0, 1], "Missing patient(s): 4, 5, 9."), + ([1, 0], "Found extra patient(s): 4, 5, 9."), + ], +) +def test_check_answer_dataset_has_missing_or_extra_patients(engine, order, message): + datasets = [dataset_smoketest(), dataset_smoketest(min_age=0, max_age=100)] + answer, expected = (datasets[i] for i in order) + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert msg == message + + +def test_check_answer_dataset_column_has_missing_patients(engine): + msg = quiz.check_answer( + engine=engine, + answer=dataset_smoketest(index_year=2023), + expected=dataset_smoketest(), + ) + assert msg == "Incorrect `age` value for patient 1: expected 49, got 50 instead." + + +@pytest.mark.parametrize( + "order, message", + [ + ([0, 1], "Missing patient(s): 7."), + ([1, 0], "Found extra patient(s): 7."), + ], +) +def test_check_answer_patient_series_has_missing_or_extra_patients( + engine, order, message +): + series = [ + practice_registrations.for_patient_on("2013-12-01").practice_pseudo_id, + practice_registrations.for_patient_on("2014-01-01").practice_pseudo_id, + ] + answer, expected = (series[i] for i in order) + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert msg == message + + +def test_check_answer_patient_series_has_incorrect_value(engine): + msg = quiz.check_answer( + engine=engine, + answer=patients.age_on("2023-12-31"), + expected=patients.age_on("2022-12-31"), + ) + assert msg == "Incorrect value for patient 1: expected 49, got 50 instead." + + +@pytest.mark.parametrize( + "order, message", + [ + ([1, 0], "Missing row(s): 1, 3, 6, 9, 10, 13, 15, 17, 19, 20."), + ([0, 1], "Found extra row(s): 1, 3, 6, 9, 10, 13, 15, 17, 19, 20."), + ], +) +def test_check_answer_event_table_has_missing_or_extra_rows(engine, order, message): + tables = [ + clinical_events, + clinical_events.where(clinical_events.snomedct_code.is_in(["60621009"])), + ] + answer, expected = (tables[i] for i in order) + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert msg == message + + +@pytest.mark.parametrize( + "order, message", + [ + ([1, 0], "Missing row(s): 5, 9."), + ([0, 1], "Found extra row(s): 5, 9."), + ], +) +def test_check_answer_event_series_has_missing_or_extra_rows(engine, order, message): + series = [ + medications.dmd_code, + medications.where(medications.date.is_on_or_before("2020-12-01")).dmd_code, + ] + answer, expected = (series[i] for i in order) + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert msg == message + + +def test_check_answer_event_series_has_incorrect_value(engine): + answer = medications.date + expected = medications.dmd_code + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert ( + msg + == "Incorrect value for patient 1, row 1: expected 39113611000001102, got 2014-01-11 instead." + ) + + +def test_check_answer_incorrect_event_selection_for_patient(engine): + events = clinical_events.where( + clinical_events.snomedct_code.is_in(["60621009"]) + ).sort_by(clinical_events.date) + answer = events.first_for_patient() + expected = events.last_for_patient() + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert ( + msg + == "Incorrect `numeric_value` value for patient 2: expected 23.1, got 18.4 instead." + ) + + +def test_check_answer_patient_series_has_incorrect_default(engine): + date_1 = ( + clinical_events.where(clinical_events.snomedct_code.is_in(["60621009"])) + .sort_by(clinical_events.date) + .last_for_patient() + .date + ) + date_2 = ( + clinical_events.where(clinical_events.snomedct_code.is_in(["60621010"])) + .sort_by(clinical_events.date) + .last_for_patient() + .date + ) + answer = (date_1 > date_2) | (date_1.is_null() & date_2.is_null()) + expected = date_1 > date_2 + msg = quiz.check_answer(engine=engine, answer=answer, expected=expected) + assert msg == ( + "Series has the wrong default value for patients with no matching records: " + "expected None but got True" + ) + + +def test_check_answer_unidentified_error_shows_fallback_message(engine): + with patch("ehrql.quiz.check_patient_table_values", return_value=None): + msg = quiz.check_answer( + engine, dataset_smoketest(index_year=2024), dataset_smoketest() + ) + assert msg.startswith("Incorrect answer.\nExpected:") + + +# Hypothesis tests for error message coverage +# A generated answer should be either correct or gives an informative error +# Generate some answers and assert that the fall-back message is not produced + + +@given( + dataset=st.builds( + dataset_smoketest, + index_year=..., + min_age=..., + max_age=..., + year_of_birth_column=..., + ) +) +def test_check_answer_dataset_is_either_correct_or_has_informative_error(dataset): + engine = get_engine() + msg = quiz.check_answer(engine, dataset, dataset_smoketest()) + assert not msg.startswith("Incorrect answer.\nExpected:") + + +@given( + answer=st.builds( + filtered_medications, + index_year=st.integers(min_value=1900, max_value=2100), + interval_weeks=st.integers(min_value=0, max_value=52), + codelist=st.from_regex(r"[1-9][0-9]{5,17}", fullmatch=True), + filter_dates=..., + ) +) +def test_check_answer_filtered_medications_is_either_correct_or_has_informative_error( + answer, +): + engine = get_engine() + msg = quiz.check_answer( + engine, + answer, + filtered_medications(), + ) + assert not msg.startswith("Incorrect answer.\nExpected:") + + +@pytest.mark.parametrize( + "answer,message", + [ + (Dataset(), "Correct!"), + (..., "Skipped."), + ], +) +def test_check(capfd, answer, message): + question = quiz.Question("Create an Empty Dataset.", 0) + question.expected = Dataset() + question.check(answer) + assert capfd.readouterr().err.rstrip() == f"Question 0\n{message}" + + +def test_summarise(capfd): + questions = quiz.Questions() + questions[1] = quiz.Question("Q1") + questions[2] = quiz.Question("Q2") + questions.summarise() + assert capfd.readouterr().err.rstrip() == "\n".join( + [ + "\n\nSummary of your results", + "Correct: 0", + "Incorrect: 0", + "Unanswered: 2", + ] + ) + + +def test_questions(): + questions = quiz.Questions() + questions.set_dummy_tables_path("test_dummy_path") + questions[1] = quiz.Question("Q1") + questions[2] = quiz.Question("Q2") + assert len(list(questions.get_all())) == 2 + assert questions[1].index == 1 + assert questions[2].engine.dsn.name == "test_dummy_path" + + +def test_set_dummy_tables_path_in_debug_context(): + with debugger.activate_debug_context( + dummy_tables_path="foo", render_function=lambda v: v + ): + questions = quiz.Questions() + questions.set_dummy_tables_path("bar") + assert debugger.DEBUG_CONTEXT.query_engine.dsn.name == "bar" + # This should be unset outside of the context manager + assert debugger.DEBUG_CONTEXT is None + + +def test_hint(capfd): + hint = "This is a\nhint" + question = quiz.Question("", 0) + question._hint = hint + question.hint() + assert capfd.readouterr().err.rstrip().startswith(f"Hint for question 0:\n\n{hint}") + + +def test_empty_hint(capfd): + question = quiz.Question("", 0) + question.hint() + # If the hint is not provided there is a default hint message that currently starts + # with "Remember". + assert ( + capfd.readouterr().err.rstrip().startswith("Hint for question 0:\n\nRemember") + )