"""
Tests for ehrQL autocomplete, driven by the pyright language server.

The module parses `autocomplete_definition.py` (statements plus structured
`## type:` comments) and checks that the types pyright infers on hover match
the expected types recorded in those comments.
"""

import ast
import inspect
import re
from enum import Enum
from io import BytesIO
from pathlib import Path
from tokenize import tokenize

import pytest

from ehrql import query_language, tables
from ehrql.query_language import (
    BaseSeries,
    EventSeries,
    PatientSeries,
    days,
    get_tables_from_namespace,
    int_property,
)
from ehrql.tables import core, emis, tpp
from ehrql.utils.module_utils import get_submodules

from .language_server import LanguageServer


# Using the pyright language server to test ehrQL autocompletion
#

# Before we run the tests we first parse the autocomplete_definition.py file
# to pull out:
# - all statements (either methods or direct attribute references)
# - all assignments
# - the expected type for each, parsed from the structured comments
# - the line number and cursor position to hover to get the expected type

autocomplete_filepath = Path(__file__).with_name("autocomplete_definition.py")
autocomplete_file = autocomplete_filepath.read_text()

# We want to ensure that all lines in the autocomplete file are covered by tests
# so we keep track of those that aren't.
# NOTE: line 0 is included because tokenize's ENCODING token reports line 0;
# it is discarded by the blank-line branch of the token loop below.
autocomplete_uncovered_lines = set(range(0, len(autocomplete_file.split("\n")) + 1))

# Using the tokenize method we can parse all tokens, including comments, to
# find the structured comments containing the expected types. For each we
# record the line number that it occurs on, and the expected type.
# Walk every token (comments included) of the autocomplete definition file.
# Blank, comment and import lines are marked as covered; structured
# `## type:` comments record the expected type for the line they sit on.
expected_types = {}
for token_type, string, start, end, line in tokenize(
    BytesIO(autocomplete_file.encode("utf-8")).readline
):
    if (
        line.strip() == ""
        or line.strip().startswith("#")
        or line.strip().startswith("from")
        or line.strip().startswith("import")
    ):
        # Comments and import statements can be ignored
        autocomplete_uncovered_lines.discard(start[0])
    type_comment = re.search("^## type:(?P<expected_type>.+)$", string)
    if type_comment:
        expected_types[start[0]] = type_comment.group("expected_type").strip()


# Now we use the AST module to walk the trees in the autocomplete file. This
# allows us to:
# - find all statements and assignments
# - for all methods find the method name and the calling object
# - find the line_number and cursor_position for where to hover
# - find the line_number where we expect to see a structured comment with the
#   expected type
class MyVisitor(ast.NodeVisitor):
    """
    Traverses the AST to find all the method calls along
    with the object they were called from
    """

    def __init__(self):
        self.methods = []
        self.statements = []
        self.attributes = []

    def update_coverage(self, node):
        # Mark every source line spanned by this node as covered
        for line_number in range(node.lineno, node.end_lineno + 1):
            autocomplete_uncovered_lines.discard(line_number)

    def get_expected_type(self, line_number, statement):
        # Look up the `## type:` comment recorded for this line; fail loudly
        # if the definition file is missing one
        if line_number not in expected_types:  # pragma: no cover
            assert 0, (
                f"Line {line_number} of the autocomplete_definition.py file does not "
                "have an expected type in a structured comment. You need to provide a "
                "comment like `## type:<expected_type>` immediately after the statement.\n"
                "If this is a multi line statement, then put it on the last line. Typically "
                "this will be immediately after the closing ')'.\n"
                "The statement without an expected type is:\n\n"
                f"{statement}\n"
            )
        return expected_types[line_number]

    def process_attribute(
        self, node, line_number, expected_type_line, statement, cursor_position
    ):
        # Record the attribute access and the hover target used to check its type
        attr_name = node.attr
        calling_object = ast.unparse(node.value)

        self.attributes.append((calling_object, attr_name))

        expected_type = self.get_expected_type(expected_type_line, statement)
        self.statements.append((statement, line_number, cursor_position, expected_type))

    def visit_Attribute(self, node):
        # Occurs when a statement is just an attribute call without assignment
        # e.g. patients.core.date_of_birth
        self.update_coverage(node)
        self.process_attribute(
            node,
            node.end_lineno,
            node.end_lineno,
            ast.unparse(node),
            node.end_col_offset,
        )

    def visit(self, node):
        # Guard against unexpected node types so that anything added to the
        # definition file that we wouldn't test is flagged immediately
        if not isinstance(
            node,
            ast.Module  # always first node, we ignore
            | ast.ImportFrom  # ignore import statements
            | ast.alias  # ignore alias statements
            | ast.Expr  # everything is an expression, we ignore because we are only interested in the expression type
            | ast.Assign  # dealt with by visit_Assign
            | ast.Attribute  # dealt with by visit_Attribute
            | ast.Call,  # dealt with by visit_Call
        ):  # pragma: no cover
            assert 0, (
                f"Line {node.lineno} of autocomplete_definition.py contains an unexpected expression "
                "and so won't be tested. See the comments in that file for what is a valid expression. "
                "If you need to test the expression, then you will need to update this test file. The "
                f"unexpected expression was:\n\n{ast.unparse(node)}\n"
            )
        return super().visit(node)

    def process_function(self, node, statement, cursor_position):
        if not isinstance(node.func, ast.Name):
            # A function with a calling object e.g. patients.age_on()
            method = node.func.attr
            calling_object = ast.unparse(node.func.value)
        else:
            # A function imported directly from a module. So far this only
            # happens for functions defined in the top level of query_language.py
            method = node.func.id
            calling_object = "query_language"

        self.methods.append((calling_object, method))

        expected_type = self.get_expected_type(node.end_lineno, statement)
        line_number = node.func.end_lineno
        self.statements.append((statement, line_number, cursor_position, expected_type))

    def visit_Call(self, node):
        # Occurs when a statement is just a method call without assignment
        # e.g. patients.core.date_of_birth.is_after("2024-01-01")
        self.update_coverage(node)
        self.process_function(node, ast.unparse(node), node.func.end_col_offset)

    def visit_Assign(self, node):
        # Occurs when a statement is an assignment
        # e.g. var_name = patients.core.date_of_birth.is_after("2024-01-01")
        # We process the statement assigned to the variable

        if node.targets[0].id == "date_str":
            # Ignore the line that just defines a variable `date_str`
            self.update_coverage(node)
            return

        if isinstance(node.value, ast.Attribute):  # pragma: no cover
            # Currently we don't actually do this, hence the pragma, but leaving
            # it in so that potential future changes to the autocomplete_definition
            # file can be handled
            self.update_coverage(node)
            self.process_attribute(
                node.value, node.lineno, node.end_lineno, ast.unparse(node), 0
            )
        elif isinstance(node.value, ast.Call):
            self.update_coverage(node)
            self.process_function(node.value, ast.unparse(node), 0)
        else:
            # Sometimes we may want to test autocomplete for something that isn't
            # an attribute or a method call. In which case this is fine with a variable
            # assignment because the autocomplete can come from a hover at cursor
            # position 0 (the start of the variable name)
            self.update_coverage(node)
            statement = ast.unparse(node.value)
            expected_type = self.get_expected_type(node.end_lineno, statement)
            self.statements.append((statement, node.lineno, 0, expected_type))


visitor = MyVisitor()
visitor.visit(ast.parse(autocomplete_file))
autocomplete_file_statements = visitor.statements
autocomplete_file_methods = visitor.methods
autocomplete_file_attributes = visitor.attributes

assert len(autocomplete_uncovered_lines) == 0, (
    "The autocomplete_definition.py contains some lines that aren't covered by the tests:\n"
    f"{autocomplete_uncovered_lines}\n"
    "You have probably added something to it that the AST NodeVisitor isn't expecting."
)


@pytest.fixture(scope="session")
def language_server(tmp_path_factory):
    # Session-scoped: starting the language server once keeps the suite fast
    temp_file_path = tmp_path_factory.mktemp("autocomplete") / "autocomplete.py"
    return LanguageServer(temp_file_path)


class CompletionItemKind(Enum):
    """
    The different kinds of things the language server protocol supports.
    Currently we only have METHOD, VARIABLE and PROPERTY but the full list is here:
    https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#completionItemKind
    """

    METHOD = 2
    VARIABLE = 6
    PROPERTY = 10


def get_completion_item_kind(table, attr):
    # Map a table attribute to the CompletionItemKind we expect the language
    # server to report for it in the completion dropdown
    if inspect.ismethod(getattr(table, attr)):
        return CompletionItemKind.METHOD.value
    elif attr in table.__class__.__dict__ and isinstance(
        table.__class__.__dict__[attr], property
    ):
        return CompletionItemKind.PROPERTY.value
    else:
        return CompletionItemKind.VARIABLE.value


# If we create a file once, with many lines of code, which the language
# server can then load and parse, then the autocomplete is a lot quicker
# than continually telling the language server about a line changing.
# Therefore for the parameterized tests, we first populate a file with
# all statements.


all_table_attributes = [
    (module, name, attr, get_completion_item_kind(table, attr))
    for module in get_submodules(tables)
    for name, table in get_tables_from_namespace(module)
    for attr in dir(table)
    if not attr.startswith("_")
]

# Maps "<module>.<table>" -> {attr: CompletionItemKind value} for every
# public attribute on every table
expected_dropdown_entries = {}
for module, name, attr, completion_item_kind in all_table_attributes:
    table = f"{module.__name__}.{name}"
    if table not in expected_dropdown_entries:
        expected_dropdown_entries[table] = {}
    expected_dropdown_entries[table][attr] = completion_item_kind

import_statements = [f"import {module.__name__}" for module in get_submodules(tables)]


@pytest.fixture(scope="session")
def language_server_preloaded_with_all_table_drop_downs(
    language_server, tmp_path_factory
):
    # One line per table, each ending with a "." so the completion dropdown
    # can be requested just after the dot
    temp_file_path = tmp_path_factory.mktemp("autocomplete") / "autocomplete.py"
    content = [f"{table_name}." for table_name in expected_dropdown_entries]
    language_server.open_doc(temp_file_path, "\n".join(import_statements + content))
    return language_server


@pytest.mark.parametrize(
    "line, table_name, expected_values",
    [
        (line, *entry)
        for line, entry in enumerate(expected_dropdown_entries.items())
    ],
)
def test_dropdown_completion_for_all_tables(
    language_server_preloaded_with_all_table_drop_downs,
    line,
    table_name,
    expected_values,
):
    # Check the autocompletion for each core table. I.e.
    # what you get after typing: patients.

    results = language_server_preloaded_with_all_table_drop_downs.get_completion_results_from_file(
        line + len(import_statements), len(table_name) + 1
    )

    actual = {
        x.get("label"): x.get("kind")
        for x in results
        if not x.get("label").startswith("_")
    }

    assert actual == expected_values


# If we add a new column type that is hard to support autocomplete
# we can add it here
ignored_columns = [
    # ("table_name", "column_name"),
    ("addresses", "imd_decile"),
    ("addresses", "imd_quintile"),
]

all_columns_for_all_tables = [
    (
        f"{module.__name__}.{table_name}.{column_name}",
        module,
        table_name,
        column_name,
        type(getattr(table, column_name)).__name__,
    )
    for module in get_submodules(tables)
    for table_name, table in get_tables_from_namespace(module)
    for column_name in dir(table)
    if isinstance(getattr(table, column_name), BaseSeries)
    and (table_name, column_name) not in ignored_columns
]


@pytest.fixture(scope="session")
def language_server_preloaded_with_all_types(language_server, tmp_path_factory):
    # One line per column reference so hovering at the end of each line
    # reveals the inferred series type
    temp_file_path = tmp_path_factory.mktemp("autocomplete") / "autocomplete.py"
    content = [thing[0] for thing in all_columns_for_all_tables]
    language_server.open_doc(temp_file_path, "\n".join(import_statements + content))
    return language_server


@pytest.mark.parametrize(
    "line, code_string, module, name, column_name, expected_series_type",
    [
        (line, *entry)
        for line, entry in enumerate(all_columns_for_all_tables)
    ],
)
def test_types_of_all_series_on_all_tables(
    language_server_preloaded_with_all_types,
    line,
    code_string,
    module,
    name,
    column_name,
    expected_series_type,
):
    # For each Series on each table in each submodule of ehrql.tables we check
    # that the inferred series type on hovering matches the expected series
    result = language_server_preloaded_with_all_types.get_element_type_from_file(
        line + len(import_statements), len(code_string) - 1
    )
    assert expected_series_type == result, (
        f"In {module.__name__}.{name}, the series `{column_name}` should be of type `{expected_series_type}` but the language server thinks it's `{result}`"
    )


@pytest.fixture(scope="session")
def language_server_preloaded_with_autocomplete_defs(language_server, tmp_path_factory):
    # Load the autocomplete definition file verbatim so line/column positions
    # recorded by the AST walk line up with the language server's view
    temp_file_path = tmp_path_factory.mktemp("autocomplete") / "autocomplete.py"
    language_server.open_doc(temp_file_path, autocomplete_file)
    return language_server


@pytest.mark.parametrize(
    "statement, line_number, cursor_position, expected_type",
    autocomplete_file_statements,
)
def test_things_in_autocomplete_definition_file(
    language_server_preloaded_with_autocomplete_defs,
    statement,
    line_number,
    cursor_position,
    expected_type,
):
    actual_type = (
        language_server_preloaded_with_autocomplete_defs.get_element_type_from_file(
            line_number - 1, cursor_position
        )
    )
    assert actual_type == expected_type, (
        f"\n\nautocomplete_definition.py, line {line_number}: Expecting {expected_type} (but got {actual_type}) for the statement\n{statement}\n"
    )


def test_all_table_methods():
    all_methods_from_tables = [
        f"{module.__name__}.{table_name}.{method}"
        for module in get_submodules(tables)
        for table_name, table in get_tables_from_namespace(module)
        for method in table.__class__.__dict__
        if callable(getattr(table, method)) and not method.startswith("_")
    ]

    # This is a list of table methods that we currently don't support for autocomplete,
    # or need to ignore
    ignored_table_methods = [
        "ehrql.tables.smoketest.patients.count_for_patient",
        "ehrql.tables.smoketest.patients.exists_for_patient",
        "ehrql.tables.emis.patients.age_on",  # requires DateDifference
        "ehrql.tables.tpp.addresses.for_patient_on",  # Needs last_for_patient
        "ehrql.tables.core.practice_registrations.for_patient_on",  # Needs last_for_patient
        "ehrql.tables.tpp.patients.age_on",  # requires DateDifference
        "ehrql.tables.core.patients.age_on",  # requires DateDifference
        "ehrql.tables.emis.ons_deaths.cause_of_death_is_in",  # need refactor of method
        "ehrql.tables.core.ons_deaths.cause_of_death_is_in",  # need refactor of method
    ]
    tested_table_methods = [
        f"ehrql.tables.{calling_object}.{method}"
        for (calling_object, method) in autocomplete_file_methods
    ]
    missing_methods = [
        method
        for method in all_methods_from_tables
        if method not in ignored_table_methods + tested_table_methods
    ]
    missing_methods_str = "\n - ".join(missing_methods)
    assert not missing_methods, (
        "\nThe following methods do not have autocomplete tests:\n"
        f" - {missing_methods_str}\n"
        "You must add them to the `autocomplete_definition.py` file.\n"
        "If we don't support autocomplete (yet) then add them to the "
        "`ignored_table_methods` list in this test file."
    )


@pytest.fixture(scope="session")
def query_language_methods():
    """
    Fixture to extract all methods from the query_language.py file.
    Determines for each method whether it is callable by a PatientSeries
    or an EventSeries or both.

    Currently there is one special case for the decorator @int_property.
    This appears as a method in the below, but is actually an attribute
    after the decorator has done its job. Therefore we deal with them
    separately.
    """
    class_parents = {}
    class_methods = {"object": []}
    class_int_props = {"object": []}
    series_methods = set()
    series_int_props = set()
    other_methods = {}
    ql_classes = inspect.getmembers(query_language, inspect.isclass)
    for class_name, class_obj in ql_classes:
        if class_obj.__module__ != "ehrql.query_language":
            continue
        # Get the base classes using __mro__
        base_classes = [
            base.__name__ for base in class_obj.__mro__ if base.__name__ != class_name
        ]
        class_parents[class_obj] = base_classes

        # Get the methods defined in the class
        methods = [
            method_name
            for method_name, method in vars(class_obj).items()
            if not method_name.startswith("_") and not isinstance(method, int_property)
        ]
        class_methods[class_name] = methods

        # Get the `int_property`s defined in the class
        int_props = [
            method_name
            for method_name, method in vars(class_obj).items()
            if not method_name.startswith("_") and isinstance(method, int_property)
        ]
        class_int_props[class_name] = int_props

    for cls in class_parents:
        if PatientSeries in cls.__mro__:
            for parent in class_parents[cls]:
                series_methods.update(
                    [
                        (parent, method, PatientSeries)
                        for method in class_methods[parent]
                    ]
                )
                series_int_props.update(
                    [
                        (parent, method, PatientSeries)
                        for method in class_int_props[parent]
                    ]
                )
        if EventSeries in cls.__mro__:
            for parent in class_parents[cls]:
                series_methods.update(
                    [(parent, method, EventSeries) for method in class_methods[parent]]
                )
                series_int_props.update(
                    [
                        (parent, method, EventSeries)
                        for method in class_int_props[parent]
                    ]
                )
        if class_methods[cls.__name__]:
            other_methods[cls.__name__] = class_methods[cls.__name__]

    # Currently `other_methods` contains all methods. So we remove the series methods
    # to leave the non-series methods
    for class_name, _, _ in series_methods:
        if class_name in other_methods:
            del other_methods[class_name]

    return (other_methods, series_methods, series_int_props)


def _evaluate_in_definition_namespace(expression):
    """
    Evaluate a calling-object expression pulled from autocomplete_definition.py
    and return the resulting instance.

    The expressions come from our own definition file (not untrusted input),
    so eval() is safe here. The names below are imported in the
    autocomplete_def file, but only present in this file to be used here. We
    pass them as globals. If we just imported them above without using them,
    Ruff would complain; this seems like a better compromise than telling
    Ruff to ignore the unused imports.
    """
    return eval(
        expression,
        {
            "tpp": tpp,
            "core": core,
            "emis": emis,
            "days": days,
            "query_language": query_language,
        },
    )


def test_all_query_model_series_methods(query_language_methods):
    _, series_methods, _ = query_language_methods
    # This is a list of things we currently don't support for autocomplete.
    ignored_query_language_methods = [
        ("BaseSeries", "map_values", PatientSeries),
        ("BaseSeries", "map_values", EventSeries),
        ("CodeFunctions", "to_category", PatientSeries),
        ("CodeFunctions", "to_category", EventSeries),
        ("MultiCodeStringFunctions", "is_in", PatientSeries),  # need first_for_patient
        (
            "MultiCodeStringFunctions",
            "is_not_in",
            PatientSeries,
        ),  # need first_for_patient
        (
            "MultiCodeStringFunctions",
            "contains",
            PatientSeries,
        ),  # need first_for_patient
        (
            "MultiCodeStringFunctions",
            "contains_any_of",
            PatientSeries,
        ),  # need first_for_patient
    ]

    # To determine which methods are currently tested in the autocomplete_definition file we:
    # - loop through each method pulled out of the AST walk earlier
    # - for each method, evaluate the calling object - e.g. for core.patients.date_of_birth.is_after()
    #   the calling object is core.patients.date_of_birth, so we call eval() on it to get
    #   the type instance - in this case an instance of the DatePatientSeries class
    # - we then loop through the __mro__ of the class of the instance method until we find a
    #   class with that method and assign that class's method as "tested". We stop looking at that
    #   point so that if e.g. A inherits from B, and A overrides the method c() on B, then we'd
    #   probably want to test instance_of_a.c() and instance_of_b.c() separately
    tested_methods = []
    for calling_object, method in autocomplete_file_methods:
        calling_object_instance = _evaluate_in_definition_namespace(calling_object)

        if isinstance(calling_object_instance, PatientSeries):
            series_type = PatientSeries
        elif isinstance(calling_object_instance, EventSeries):
            series_type = EventSeries
        else:
            # There are non-Series methods e.g. methods directly on tables such as patients.age_on()
            # and non-Series class methods such as Duration.days(). These are tested separately in
            # other tests
            continue

        for cls in calling_object_instance.__class__.__mro__:  # pragma: no cover
            if (
                cls.__name__,
                method,
                series_type,
            ) in series_methods:
                tested_methods.append((cls.__name__, method, series_type))
                break

    missing_methods = [
        f"{method_name} in {class_name} - called on a {series_type.__name__}"
        for class_name, method_name, series_type in series_methods
        if (class_name, method_name, series_type)
        not in tested_methods + ignored_query_language_methods
    ]
    missing_methods_str = "\n - ".join(missing_methods)

    assert not missing_methods, (
        "\nThe following methods in query_language.py do not have autocomplete tests:\n"
        f" - {missing_methods_str}\n"
        "You must add them to the `autocomplete_definition.py` file.\n"
        "If we don't support autocomplete (yet) then add them to the "
        "`ignored_query_language_methods` list in this test file."
    )


def test_all_query_model_non_series_methods(query_language_methods):
    other_methods, _, _ = query_language_methods
    # Whole classes that we want to ignore from autocomplete
    classes_to_ignore = [
        "Dataset",
        "DateDifference",
        "DummyDataConfig",
        "SortedEventFrameMethods",
    ]
    # Specific class methods we want to ignore
    ignored_query_language_methods = [
        ("EventFrame", "sort_by"),
        ("when", "then"),
        ("WhenThen", "otherwise"),
        # Temporarily ignore this while it's non-public
        ("EventTable", "add_column"),
    ]

    # Same logic as in previous test to find methods actually tested in the autocomplete definition file
    tested_methods = []
    for calling_object, method in autocomplete_file_methods:
        calling_object_instance = _evaluate_in_definition_namespace(calling_object)
        for cls in calling_object_instance.__class__.__mro__:
            if cls.__name__ in other_methods:
                tested_methods.append((cls.__name__, method))
                break

    missing_methods = [
        f"{method} in {class_name}"
        for class_name, methods in other_methods.items()
        for method in methods
        if (class_name, method) not in tested_methods + ignored_query_language_methods
        and class_name not in classes_to_ignore
    ]
    missing_methods_str = "\n - ".join(missing_methods)

    assert not missing_methods, (
        "\nThe following methods in query_language.py do not have autocomplete tests:\n"
        f" - {missing_methods_str}\n"
        "You must add them to the `autocomplete_definition.py` file.\n"
        "If we don't support autocomplete (yet) then add them to the "
        "`ignored_query_language_methods` list in this test file. Alternatively, "
        "to ignore all methods from a class, add the class to the `classes_to_ignore` list."
    )


def test_all_query_model_int_properties(query_language_methods):
    _, _, series_int_props = query_language_methods

    # Currently we test all int_props, but if in the future we don't you'd add
    # them to this list
    ignored_int_props = [
        # ("DateFunctions", "year", PatientSeries),
    ]
    tested_int_props = []
    for calling_object, attribute in autocomplete_file_attributes:
        calling_object_instance = _evaluate_in_definition_namespace(calling_object)

        if isinstance(calling_object_instance, PatientSeries):
            series_type = PatientSeries
        elif isinstance(calling_object_instance, EventSeries):
            series_type = EventSeries
        else:
            # There are non-Series attributes e.g. attributes directly on tables
            # such as patients.sex. These are tested separately in other tests
            continue

        for cls in calling_object_instance.__class__.__mro__:  # pragma: no cover
            if (
                cls.__name__,
                attribute,
                series_type,
            ) in series_int_props:
                tested_int_props.append((cls.__name__, attribute, series_type))
                break

    missing_int_props = [
        f"{attribute} in {class_name} - called on a {series_type.__name__}"
        for class_name, attribute, series_type in series_int_props
        if (class_name, attribute, series_type)
        not in tested_int_props + ignored_int_props
    ]
    missing_int_props_str = "\n - ".join(missing_int_props)

    assert not missing_int_props, (
        "\nThe following `int_property`s in query_language.py do not have autocomplete tests:\n"
        f" - {missing_int_props_str}\n"
        "You must add them to the `autocomplete_definition.py` file.\n"
        "If we don't support autocomplete (yet) then add them to the "
        "`ignored_int_props` list in this test file."
    )