Diff of /tests/test_docs.py [000000] .. [cad161]

Switch to side-by-side view

--- a
+++ b/tests/test_docs.py
@@ -0,0 +1,161 @@
+import ast
+import inspect
+import re
+import sys
+import textwrap
+import warnings
+
+import catalogue
+import pytest
+from spacy.tokens.underscore import Underscore
+
+pytest.importorskip("mkdocs")
+try:
+    import torch.nn
+except ImportError:
+    torch = None
+
+if torch is None:
+    pytest.skip("torch not installed", allow_module_level=True)
+pytest.importorskip("rich")
+
+from extract_docs_code import extract_docs_code  # noqa: E402
+
+# We don't check documentation for Python <= 3.7:
+if sys.version_info < (3, 8):
+    url_to_code = {}
+else:
+    url_to_code = dict(extract_docs_code())
+    # just to make sure something didn't go wrong
+    assert len(url_to_code) > 50
+
+
+def printer(code: str) -> None:
+    """
+    Prints a code bloc with lines for easier debugging.
+
+    Parameters
+    ----------
+    code : str
+        Code bloc.
+    """
+    lines = []
+    for i, line in enumerate(code.split("\n")):
+        lines.append(f"{i + 1:03}  {line}")
+
+    print("\n".join(lines))
+
+
+def insert_assert_statements(code):
+    line_table = [0]
+    for line in code.splitlines(keepends=True):
+        line_table.append(line_table[-1] + len(line))
+
+    tree = ast.parse(code)
+    replacements = []
+
+    for match in re.finditer(
+        r"^\s*#\s*Out\s*: (.*$(?:\n#\s.*$)*)", code, flags=re.MULTILINE
+    ):
+        lineno = code[: match.start()].count("\n")
+        for stmt in tree.body:
+            if stmt.end_lineno == lineno:
+                if isinstance(stmt, ast.Expr):
+                    expected = textwrap.dedent(match.group(1)).replace("\n# ", "\n")
+                    begin = line_table[stmt.lineno - 1]
+                    if not (expected.startswith("'") or expected.startswith('"')):
+                        expected = repr(expected)
+                    end = match.end()
+                    stmt_str = ast.unparse(stmt)
+                    if stmt_str.startswith("print("):
+                        stmt_str = stmt_str[len("print") :]
+                    repl = f"""\
+value = {stmt_str}
+assert {expected} == str(value)
+"""
+                    replacements.append((begin, end, repl))
+                if isinstance(stmt, ast.For):
+                    expected = textwrap.dedent(match.group(1)).split("\n# Out: ")
+                    expected = [line.replace("\n# ", "\n") for line in expected]
+                    begin = line_table[stmt.lineno - 1]
+                    end = match.end()
+                    stmt_str = ast.unparse(stmt).replace("print", "assert_print")
+                    repl = f"""\
+printed = []
+{stmt_str}
+assert {expected} == printed
+"""
+                    replacements.append((begin, end, repl))
+
+    for begin, end, repl in reversed(replacements):
+        code = code[:begin] + repl + code[end:]
+
+    return code
+
+
+# TODO: once in a while, it can be interesting to run reset_imports for each code block,
+#  instead of only once and tests should still pass, but it's way slower.
+@pytest.fixture(scope="module")
+def reset_imports():
+    """
+    Reset the imports for each test.
+    """
+    # 1. Clear registered functions to avoid using cached ones
+    for k, m in list(catalogue.REGISTRY.items()):
+        mod = inspect.getmodule(m)
+        if mod is not None and mod.__name__.startswith("edsnlp"):
+            del catalogue.REGISTRY[k]
+
+    # Let's ensure that we "bump" into every possible warnings:
+    # 2. Remove all modules that start with edsnlp, to reimport them
+    for k in list(sys.modules):
+        if k.split(".")[0] == "edsnlp":
+            del sys.modules[k]
+
+    # 3. Delete spacy extensions to avoid error when re-importing
+    Underscore.span_extensions.clear()
+    Underscore.doc_extensions.clear()
+    Underscore.token_extensions.clear()
+
+
+# Note the use of `str`, makes for pretty output
+@pytest.mark.parametrize("url", sorted(url_to_code.keys()), ids=str)
+def test_code_blocks(url, tmpdir, reset_imports):
+    code = url_to_code[url]
+    code_with_asserts = """
+def assert_print(*args, sep=" ", end="\\n", file=None, flush=False):
+    printed.append((sep.join(map(str, args)) + end).rstrip('\\n'))
+
+""" + insert_assert_statements(code)
+    assert "# Out:" not in code_with_asserts, (
+        "Unparsed asserts in {url}:\n" + code_with_asserts
+    )
+    # We'll import test_code_blocks from here
+    sys.path.insert(0, str(tmpdir))
+    test_file = tmpdir.join("test_code_blocks.py")
+
+    # Clear all warnings
+    warnings.resetwarnings()
+
+    try:
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")
+            warnings.filterwarnings(module=".*endlines.*", action="ignore")
+            warnings.filterwarnings(
+                message="__package__ != __spec__.parent", action="ignore"
+            )
+            # First, forget test_code_blocks
+            sys.modules.pop("test_code_blocks", None)
+
+            # Then, reimport it, to let pytest do its assertion rewriting magic
+            test_file.write_text(code_with_asserts, encoding="utf-8")
+
+            import test_code_blocks  # noqa: F401
+
+            exec(
+                compile(code_with_asserts, test_file, "exec"),
+                {"__MODULE__": "__main__"},
+            )
+    except Exception:
+        printer(code_with_asserts)
+        raise