|
a |
|
b/tests/test_docs.py |
|
|
1 |
import ast |
|
|
2 |
import inspect |
|
|
3 |
import re |
|
|
4 |
import sys |
|
|
5 |
import textwrap |
|
|
6 |
import warnings |
|
|
7 |
|
|
|
8 |
import catalogue |
|
|
9 |
import pytest |
|
|
10 |
from spacy.tokens.underscore import Underscore |
|
|
11 |
|
|
|
12 |
pytest.importorskip("mkdocs") |
|
|
13 |
try: |
|
|
14 |
import torch.nn |
|
|
15 |
except ImportError: |
|
|
16 |
torch = None |
|
|
17 |
|
|
|
18 |
if torch is None: |
|
|
19 |
pytest.skip("torch not installed", allow_module_level=True) |
|
|
20 |
pytest.importorskip("rich") |
|
|
21 |
|
|
|
22 |
from extract_docs_code import extract_docs_code # noqa: E402 |
|
|
23 |
|
|
|
24 |
# We don't check documentation for Python <= 3.7: |
|
|
25 |
if sys.version_info < (3, 8): |
|
|
26 |
url_to_code = {} |
|
|
27 |
else: |
|
|
28 |
url_to_code = dict(extract_docs_code()) |
|
|
29 |
# just to make sure something didn't go wrong |
|
|
30 |
assert len(url_to_code) > 50 |
|
|
31 |
|
|
|
32 |
|
|
|
33 |
def printer(code: str) -> None: |
|
|
34 |
""" |
|
|
35 |
Prints a code bloc with lines for easier debugging. |
|
|
36 |
|
|
|
37 |
Parameters |
|
|
38 |
---------- |
|
|
39 |
code : str |
|
|
40 |
Code bloc. |
|
|
41 |
""" |
|
|
42 |
lines = [] |
|
|
43 |
for i, line in enumerate(code.split("\n")): |
|
|
44 |
lines.append(f"{i + 1:03} {line}") |
|
|
45 |
|
|
|
46 |
print("\n".join(lines)) |
|
|
47 |
|
|
|
48 |
|
|
|
49 |
def insert_assert_statements(code): |
|
|
50 |
line_table = [0] |
|
|
51 |
for line in code.splitlines(keepends=True): |
|
|
52 |
line_table.append(line_table[-1] + len(line)) |
|
|
53 |
|
|
|
54 |
tree = ast.parse(code) |
|
|
55 |
replacements = [] |
|
|
56 |
|
|
|
57 |
for match in re.finditer( |
|
|
58 |
r"^\s*#\s*Out\s*: (.*$(?:\n#\s.*$)*)", code, flags=re.MULTILINE |
|
|
59 |
): |
|
|
60 |
lineno = code[: match.start()].count("\n") |
|
|
61 |
for stmt in tree.body: |
|
|
62 |
if stmt.end_lineno == lineno: |
|
|
63 |
if isinstance(stmt, ast.Expr): |
|
|
64 |
expected = textwrap.dedent(match.group(1)).replace("\n# ", "\n") |
|
|
65 |
begin = line_table[stmt.lineno - 1] |
|
|
66 |
if not (expected.startswith("'") or expected.startswith('"')): |
|
|
67 |
expected = repr(expected) |
|
|
68 |
end = match.end() |
|
|
69 |
stmt_str = ast.unparse(stmt) |
|
|
70 |
if stmt_str.startswith("print("): |
|
|
71 |
stmt_str = stmt_str[len("print") :] |
|
|
72 |
repl = f"""\ |
|
|
73 |
value = {stmt_str} |
|
|
74 |
assert {expected} == str(value) |
|
|
75 |
""" |
|
|
76 |
replacements.append((begin, end, repl)) |
|
|
77 |
if isinstance(stmt, ast.For): |
|
|
78 |
expected = textwrap.dedent(match.group(1)).split("\n# Out: ") |
|
|
79 |
expected = [line.replace("\n# ", "\n") for line in expected] |
|
|
80 |
begin = line_table[stmt.lineno - 1] |
|
|
81 |
end = match.end() |
|
|
82 |
stmt_str = ast.unparse(stmt).replace("print", "assert_print") |
|
|
83 |
repl = f"""\ |
|
|
84 |
printed = [] |
|
|
85 |
{stmt_str} |
|
|
86 |
assert {expected} == printed |
|
|
87 |
""" |
|
|
88 |
replacements.append((begin, end, repl)) |
|
|
89 |
|
|
|
90 |
for begin, end, repl in reversed(replacements): |
|
|
91 |
code = code[:begin] + repl + code[end:] |
|
|
92 |
|
|
|
93 |
return code |
|
|
94 |
|
|
|
95 |
|
|
|
96 |
# TODO: once in a while, it can be interesting to run reset_imports for each code block, |
|
|
97 |
# instead of only once and tests should still pass, but it's way slower. |
|
|
98 |
@pytest.fixture(scope="module") |
|
|
99 |
def reset_imports(): |
|
|
100 |
""" |
|
|
101 |
Reset the imports for each test. |
|
|
102 |
""" |
|
|
103 |
# 1. Clear registered functions to avoid using cached ones |
|
|
104 |
for k, m in list(catalogue.REGISTRY.items()): |
|
|
105 |
mod = inspect.getmodule(m) |
|
|
106 |
if mod is not None and mod.__name__.startswith("edsnlp"): |
|
|
107 |
del catalogue.REGISTRY[k] |
|
|
108 |
|
|
|
109 |
# Let's ensure that we "bump" into every possible warnings: |
|
|
110 |
# 2. Remove all modules that start with edsnlp, to reimport them |
|
|
111 |
for k in list(sys.modules): |
|
|
112 |
if k.split(".")[0] == "edsnlp": |
|
|
113 |
del sys.modules[k] |
|
|
114 |
|
|
|
115 |
# 3. Delete spacy extensions to avoid error when re-importing |
|
|
116 |
Underscore.span_extensions.clear() |
|
|
117 |
Underscore.doc_extensions.clear() |
|
|
118 |
Underscore.token_extensions.clear() |
|
|
119 |
|
|
|
120 |
|
|
|
121 |
# Note the use of `str`, makes for pretty output |
|
|
122 |
@pytest.mark.parametrize("url", sorted(url_to_code.keys()), ids=str) |
|
|
123 |
def test_code_blocks(url, tmpdir, reset_imports): |
|
|
124 |
code = url_to_code[url] |
|
|
125 |
code_with_asserts = """ |
|
|
126 |
def assert_print(*args, sep=" ", end="\\n", file=None, flush=False): |
|
|
127 |
printed.append((sep.join(map(str, args)) + end).rstrip('\\n')) |
|
|
128 |
|
|
|
129 |
""" + insert_assert_statements(code) |
|
|
130 |
assert "# Out:" not in code_with_asserts, ( |
|
|
131 |
"Unparsed asserts in {url}:\n" + code_with_asserts |
|
|
132 |
) |
|
|
133 |
# We'll import test_code_blocks from here |
|
|
134 |
sys.path.insert(0, str(tmpdir)) |
|
|
135 |
test_file = tmpdir.join("test_code_blocks.py") |
|
|
136 |
|
|
|
137 |
# Clear all warnings |
|
|
138 |
warnings.resetwarnings() |
|
|
139 |
|
|
|
140 |
try: |
|
|
141 |
with warnings.catch_warnings(): |
|
|
142 |
warnings.simplefilter("error") |
|
|
143 |
warnings.filterwarnings(module=".*endlines.*", action="ignore") |
|
|
144 |
warnings.filterwarnings( |
|
|
145 |
message="__package__ != __spec__.parent", action="ignore" |
|
|
146 |
) |
|
|
147 |
# First, forget test_code_blocks |
|
|
148 |
sys.modules.pop("test_code_blocks", None) |
|
|
149 |
|
|
|
150 |
# Then, reimport it, to let pytest do its assertion rewriting magic |
|
|
151 |
test_file.write_text(code_with_asserts, encoding="utf-8") |
|
|
152 |
|
|
|
153 |
import test_code_blocks # noqa: F401 |
|
|
154 |
|
|
|
155 |
exec( |
|
|
156 |
compile(code_with_asserts, test_file, "exec"), |
|
|
157 |
{"__MODULE__": "__main__"}, |
|
|
158 |
) |
|
|
159 |
except Exception: |
|
|
160 |
printer(code_with_asserts) |
|
|
161 |
raise |