import ast
import inspect
import re
import sys
import textwrap
import warnings
import catalogue
import pytest
from spacy.tokens.underscore import Underscore
pytest.importorskip("mkdocs")
try:
import torch.nn
except ImportError:
torch = None
if torch is None:
pytest.skip("torch not installed", allow_module_level=True)
pytest.importorskip("rich")
from extract_docs_code import extract_docs_code # noqa: E402
# We don't check documentation for Python <= 3.7:
if sys.version_info < (3, 8):
url_to_code = {}
else:
url_to_code = dict(extract_docs_code())
# just to make sure something didn't go wrong
assert len(url_to_code) > 50
def printer(code: str) -> None:
"""
Prints a code bloc with lines for easier debugging.
Parameters
----------
code : str
Code bloc.
"""
lines = []
for i, line in enumerate(code.split("\n")):
lines.append(f"{i + 1:03} {line}")
print("\n".join(lines))
def insert_assert_statements(code):
line_table = [0]
for line in code.splitlines(keepends=True):
line_table.append(line_table[-1] + len(line))
tree = ast.parse(code)
replacements = []
for match in re.finditer(
r"^\s*#\s*Out\s*: (.*$(?:\n#\s.*$)*)", code, flags=re.MULTILINE
):
lineno = code[: match.start()].count("\n")
for stmt in tree.body:
if stmt.end_lineno == lineno:
if isinstance(stmt, ast.Expr):
expected = textwrap.dedent(match.group(1)).replace("\n# ", "\n")
begin = line_table[stmt.lineno - 1]
if not (expected.startswith("'") or expected.startswith('"')):
expected = repr(expected)
end = match.end()
stmt_str = ast.unparse(stmt)
if stmt_str.startswith("print("):
stmt_str = stmt_str[len("print") :]
repl = f"""\
value = {stmt_str}
assert {expected} == str(value)
"""
replacements.append((begin, end, repl))
if isinstance(stmt, ast.For):
expected = textwrap.dedent(match.group(1)).split("\n# Out: ")
expected = [line.replace("\n# ", "\n") for line in expected]
begin = line_table[stmt.lineno - 1]
end = match.end()
stmt_str = ast.unparse(stmt).replace("print", "assert_print")
repl = f"""\
printed = []
{stmt_str}
assert {expected} == printed
"""
replacements.append((begin, end, repl))
for begin, end, repl in reversed(replacements):
code = code[:begin] + repl + code[end:]
return code
# TODO: once in a while, it can be interesting to run reset_imports for each code block,
# instead of only once and tests should still pass, but it's way slower.
@pytest.fixture(scope="module")
def reset_imports():
"""
Reset the imports for each test.
"""
# 1. Clear registered functions to avoid using cached ones
for k, m in list(catalogue.REGISTRY.items()):
mod = inspect.getmodule(m)
if mod is not None and mod.__name__.startswith("edsnlp"):
del catalogue.REGISTRY[k]
# Let's ensure that we "bump" into every possible warnings:
# 2. Remove all modules that start with edsnlp, to reimport them
for k in list(sys.modules):
if k.split(".")[0] == "edsnlp":
del sys.modules[k]
# 3. Delete spacy extensions to avoid error when re-importing
Underscore.span_extensions.clear()
Underscore.doc_extensions.clear()
Underscore.token_extensions.clear()
# Note the use of `str`, makes for pretty output
@pytest.mark.parametrize("url", sorted(url_to_code.keys()), ids=str)
def test_code_blocks(url, tmpdir, reset_imports):
code = url_to_code[url]
code_with_asserts = """
def assert_print(*args, sep=" ", end="\\n", file=None, flush=False):
printed.append((sep.join(map(str, args)) + end).rstrip('\\n'))
""" + insert_assert_statements(code)
assert "# Out:" not in code_with_asserts, (
"Unparsed asserts in {url}:\n" + code_with_asserts
)
# We'll import test_code_blocks from here
sys.path.insert(0, str(tmpdir))
test_file = tmpdir.join("test_code_blocks.py")
# Clear all warnings
warnings.resetwarnings()
try:
with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.filterwarnings(module=".*endlines.*", action="ignore")
warnings.filterwarnings(
message="__package__ != __spec__.parent", action="ignore"
)
# First, forget test_code_blocks
sys.modules.pop("test_code_blocks", None)
# Then, reimport it, to let pytest do its assertion rewriting magic
test_file.write_text(code_with_asserts, encoding="utf-8")
import test_code_blocks # noqa: F401
exec(
compile(code_with_asserts, test_file, "exec"),
{"__MODULE__": "__main__"},
)
except Exception:
printer(code_with_asserts)
raise