Diff of /tests/test_docs.py [000000] .. [cad161]

Switch to unified view

a b/tests/test_docs.py
1
import ast
2
import inspect
3
import re
4
import sys
5
import textwrap
6
import warnings
7
8
import catalogue
9
import pytest
10
from spacy.tokens.underscore import Underscore
11
12
# Skip this whole test module when mkdocs cannot be imported.
pytest.importorskip("mkdocs")
# torch is optional: record its absence as `torch = None` instead of letting
# the ImportError abort collection, then skip the module explicitly below.
try:
    import torch.nn
except ImportError:
    torch = None

if torch is None:
    pytest.skip("torch not installed", allow_module_level=True)
# Same for rich: skip the module when it cannot be imported.
pytest.importorskip("rich")

# Imported only after the skip guards above, hence the E402 suppression.
from extract_docs_code import extract_docs_code  # noqa: E402

# We don't check documentation for Python <= 3.7:
if sys.version_info < (3, 8):
    # An empty mapping means `test_code_blocks` is parametrized with no case.
    url_to_code = {}
else:
    # Maps each key yielded by `extract_docs_code()` to its code snippet
    # (presumably a documentation URL -> source string; confirm against
    # extract_docs_code).
    url_to_code = dict(extract_docs_code())
    # Sanity check: 50 snippets or fewer suggests the extraction went wrong.
    assert len(url_to_code) > 50
31
32
33
def printer(code: str) -> None:
    """
    Print a code block with numbered lines for easier debugging.

    Each line is prefixed with its 1-based line number, zero-padded to three
    digits and followed by two spaces.

    Parameters
    ----------
    code : str
        Code block to display.
    """
    numbered = (
        f"{lineno:03}  {line}"
        for lineno, line in enumerate(code.split("\n"), start=1)
    )
    print("\n".join(numbered))
47
48
49
def insert_assert_statements(code):
    """
    Turn ``# Out: ...`` comments of a snippet into executable assertions.

    A ``# Out:`` comment is attached to the top-level statement that ends on
    the line right before the comment match:

    - an expression statement becomes ``value = <expr>`` followed by
      ``assert <expected> == str(value)`` (a leading ``print`` call is
      unwrapped so that its argument is what gets compared);
    - a ``for`` loop is kept, but every reference to ``print`` inside it is
      redirected to ``assert_print`` and the collected ``printed`` list is
      compared against the list of expected outputs.

    ``# Out:`` comments attached to any other kind of statement are left
    untouched.

    Parameters
    ----------
    code : str
        Source code of the snippet.

    Returns
    -------
    str
        The source code with the handled ``# Out:`` comments (and the
        statements they annotate) replaced by assertions.

    Notes
    -----
    Relies on ``ast.unparse``, which is only available on Python >= 3.9.
    The generated ``for`` code expects ``assert_print`` to be defined by the
    caller in the namespace where the result is executed.
    """
    # line_table[n - 1] is the character offset at which (1-based) line n starts.
    line_table = [0]
    for line in code.splitlines(keepends=True):
        line_table.append(line_table[-1] + len(line))

    tree = ast.parse(code)
    # Index top-level statements by their last line once, instead of scanning
    # the whole body again for every `# Out:` comment.
    stmt_by_end_line = {stmt.end_lineno: stmt for stmt in tree.body}
    replacements = []

    for match in re.finditer(
        r"^\s*#\s*Out\s*: (.*$(?:\n#\s.*$)*)", code, flags=re.MULTILINE
    ):
        # Number of newlines before the match == 1-based number of the line
        # that precedes the matched text.
        lineno = code[: match.start()].count("\n")
        stmt = stmt_by_end_line.get(lineno)
        if stmt is None:
            continue

        begin = line_table[stmt.lineno - 1]
        end = match.end()
        out_text = textwrap.dedent(match.group(1))

        if isinstance(stmt, ast.Expr):
            expected = out_text.replace("\n# ", "\n")
            # Already-quoted expectations are used verbatim as string literals.
            if not (expected.startswith("'") or expected.startswith('"')):
                expected = repr(expected)
            stmt_str = ast.unparse(stmt)
            if stmt_str.startswith("print("):
                # Keep only the parenthesised argument(s) of the print call.
                stmt_str = stmt_str[len("print") :]
            repl = f"""\
value = {stmt_str}
assert {expected} == str(value)
"""
            replacements.append((begin, end, repl))
        elif isinstance(stmt, ast.For):
            expected = [
                chunk.replace("\n# ", "\n") for chunk in out_text.split("\n# Out: ")
            ]
            # Rename `print` at the AST level: a plain textual replace would
            # also rewrite string literals and identifiers that merely contain
            # "print" (e.g. `pprint`).
            for node in ast.walk(stmt):
                if isinstance(node, ast.Name) and node.id == "print":
                    node.id = "assert_print"
            stmt_str = ast.unparse(stmt)
            repl = f"""\
printed = []
{stmt_str}
assert {expected} == printed
"""
            replacements.append((begin, end, repl))

    # Apply from the end of the file so that earlier offsets stay valid.
    for begin, end, repl in reversed(replacements):
        code = code[:begin] + repl + code[end:]

    return code
94
95
96
# TODO: once in a while, it can be interesting to run reset_imports for each code block,
97
#  instead of only once and tests should still pass, but it's way slower.
98
@pytest.fixture(scope="module")
def reset_imports():
    """
    Put edsnlp back into a not-yet-imported state.

    The fixture is module-scoped, so the reset happens once for the whole
    test module rather than before every test.
    """
    # 1. Forget registered functions that come from edsnlp, so that cached
    #    ones are not reused
    stale_keys = []
    for key, registered in list(catalogue.REGISTRY.items()):
        owner = inspect.getmodule(registered)
        if owner is None:
            continue
        if owner.__name__.startswith("edsnlp"):
            stale_keys.append(key)
    for key in stale_keys:
        del catalogue.REGISTRY[key]

    # 2. Drop `edsnlp` and all its submodules from the module cache, so that
    #    they get re-imported
    edsnlp_modules = [
        name for name in list(sys.modules) if name.split(".")[0] == "edsnlp"
    ]
    for name in edsnlp_modules:
        del sys.modules[name]

    # 3. Empty the spacy extension tables to avoid errors when re-importing
    for extensions in (
        Underscore.span_extensions,
        Underscore.doc_extensions,
        Underscore.token_extensions,
    ):
        extensions.clear()
119
120
121
# Note the use of `str`, makes for pretty output
122
@pytest.mark.parametrize("url", sorted(url_to_code.keys()), ids=str)
def test_code_blocks(url, tmpdir, reset_imports):
    """
    Run one documentation snippet and check its ``# Out:`` expectations.

    The snippet is rewritten by ``insert_assert_statements``, written to a
    temporary ``test_code_blocks.py`` module and executed with every warning
    turned into an error (apart from the two filters below). On failure, the
    rewritten code is printed with line numbers before re-raising.

    Parameters
    ----------
    url : str
        Key of the snippet in ``url_to_code``.
    tmpdir : py.path.local
        Pytest temporary directory receiving the generated module.
    reset_imports : None
        Requested only for its side effect.
    """
    code = url_to_code[url]
    code_with_asserts = """
def assert_print(*args, sep=" ", end="\\n", file=None, flush=False):
    printed.append((sep.join(map(str, args)) + end).rstrip('\\n'))

""" + insert_assert_statements(code)
    assert "# Out:" not in code_with_asserts, (
        f"Unparsed asserts in {url}:\n" + code_with_asserts
    )
    # We'll import test_code_blocks from here
    module_dir = str(tmpdir)
    sys.path.insert(0, module_dir)
    test_file = tmpdir.join("test_code_blocks.py")

    # Clear all warnings
    warnings.resetwarnings()

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            warnings.filterwarnings(module=".*endlines.*", action="ignore")
            warnings.filterwarnings(
                message="__package__ != __spec__.parent", action="ignore"
            )
            # First, forget test_code_blocks
            sys.modules.pop("test_code_blocks", None)

            # Then, reimport it, to let pytest do its assertion rewriting magic
            test_file.write_text(code_with_asserts, encoding="utf-8")

            import test_code_blocks  # noqa: F401

            # NOTE(review): this executes the snippet a second time, and the
            # namespace key is `__MODULE__`, not `__name__` - confirm both
            # are intended.
            exec(
                compile(code_with_asserts, test_file, "exec"),
                {"__MODULE__": "__main__"},
            )
    except Exception:
        printer(code_with_asserts)
        raise
    finally:
        # Undo the sys.path insertion so that entries do not pile up (one per
        # parametrized case) for the rest of the test session.
        if module_dir in sys.path:
            sys.path.remove(module_dir)