# docs/scripts/griffe_ext.py
1
import ast
2
import importlib
3
import inspect
4
import logging
5
import sys
6
from typing import Union
7
8
import astunparse
9
from griffe import Extension, Object, ObjectNode
10
from griffe.docstrings.dataclasses import DocstringSectionParameters
11
from griffe.expressions import Expr
12
from griffe.logger import patch_loggers
13
14
15
def get_logger(name):
    """Fetch the logger called ``name`` and set its level to ERROR.

    Parameters
    ----------
    name
        Name forwarded to :func:`logging.getLogger`.

    Returns
    -------
    logging.Logger
        The named logger; its level is changed in place to ``"ERROR"``.
    """
    configured = logging.getLogger(name)
    configured.setLevel("ERROR")
    return configured
19
20
21
# Give griffe our ERROR-level logger factory. Presumably griffe then builds
# its own loggers through it, which keeps doc builds quiet (confirm against
# the griffe version in use).
patch_loggers(get_logger)

# Logger for this extension, also capped at ERROR level by get_logger.
logger = get_logger(__name__)
24
25
26
class EDSNLPDocstrings(Extension):
    """Griffe extension that copies pipe docstrings onto their factories.

    A *factory* is a module-level ``create_component = <call>(...)(...)``
    assignment, i.e. a call whose callee is itself a call. Its runtime value is
    treated as the *pipe* (class or function): the pipe's ``__module__`` and
    ``__name__`` are used to find the griffe object documenting it.

    Factories and pipes may be visited in either order. Once both are known:

    - the pipe's docstring is attached to the factory;
    - its numpy ``Parameters`` section is annotated with default values.

    The defaults come from the pipe's signature, overridden by the factory's
    ``default_config=`` keyword.
    """

    # Defaults rendered longer than this many characters are truncated.
    MAX_DEFAULT_LENGTH = 50

    def __init__(self):
        super().__init__()

        # griffe path of every visited class/function -> griffe Object
        self.PIPE_OBJ = {}
        # factory griffe path -> (AST node, griffe Object, merged defaults)
        self.FACT_MEM = {}
        # "<__module__>.<__name__>" of the factory's runtime value -> factory path
        self.PIPE_TO_FACT = {}

    def on_instance(self, node: Union[ast.AST, ObjectNode], obj: Object) -> None:
        """Griffe hook: pair factories with pipes and document defaults.

        Parameters
        ----------
        node
            AST node (or griffe ObjectNode) griffe is visiting.
        obj
            Griffe object created for ``node``.
        """
        if (
            isinstance(node, ast.Assign)
            and obj.name == "create_component"
            and isinstance(node.value, ast.Call)
            and isinstance(node.value.func, ast.Call)
        ):
            module_name = obj.path.rsplit(".", 1)[0]
            # Re-execute modules, presumably so the runtime objects reflect the
            # current sources (e.g. on docs rebuilds) -- confirm.
            # NOTE(review): the prefix is "edspdf" while the class name
            # suggests EDS-NLP; possibly copied from EDS-PDF's extension.
            # Confirm whether "edsnlp" was intended before changing it.
            for name, mod in list(sys.modules.items()):
                if name.startswith("edspdf"):
                    importlib.reload(mod)
            module = importlib.reload(importlib.import_module(module_name))

            default_config = self._eval_default_config(node.value.func, module)

            # Import the object to get its evaluated docstring.
            try:
                runtime_obj = getattr(module, obj.name)
                source = inspect.getsource(runtime_obj)
                self.visit(ast.parse(source))
            except ImportError:
                logger.debug(f"Could not get dynamic docstring for {obj.path}")
                return
            except AttributeError:
                logger.debug(f"Object {obj.path} does not have a __doc__ attribute")
                return
            except (OSError, TypeError):
                # inspect.getsource raises these when no source is available
                # (built-ins, dynamically created objects, ...).
                logger.debug(f"Could not get source code for {obj.path}")
                return

            defaults = {**self._signature_defaults(runtime_obj), **default_config}
            self.FACT_MEM[obj.path] = (node, obj, defaults)
            # Assumes the pipe's griffe path equals its runtime
            # "<module>.<name>"; the elif branch below looks it up by obj.path.
            pipe_path = runtime_obj.__module__ + "." + runtime_obj.__name__
            self.PIPE_TO_FACT[pipe_path] = obj.path

            if pipe_path not in self.PIPE_OBJ:
                # Pipe not visited yet: the elif branch finishes the job later.
                return
            obj.docstring = self.PIPE_OBJ[pipe_path].docstring
        elif obj.is_class or obj.is_function:
            self.PIPE_OBJ[obj.path] = obj
            if obj.path not in self.PIPE_TO_FACT:
                return
            _, fact_obj, defaults = self.FACT_MEM[self.PIPE_TO_FACT[obj.path]]
            fact_obj.docstring = obj.docstring
            obj = fact_obj
        else:
            return

        self._annotate_defaults(obj, defaults)

    @staticmethod
    def _eval_default_config(call_node: ast.Call, module) -> dict:
        """Evaluate the ``default_config=`` keyword of ``call_node``.

        The expression is evaluated in ``module``'s namespace. Returns an empty
        dict when the keyword is absent, when evaluation fails, or when the
        result is not a mapping.
        """
        from collections.abc import Mapping

        config_node = next(
            (kw.value for kw in call_node.keywords if kw.arg == "default_config"),
            None,
        )
        if config_node is None:
            return {}
        try:
            # NOTE: eval executes code taken from the documented package's own
            # sources; only run this extension on trusted code.
            config = eval(astunparse.unparse(config_node), module.__dict__)
        except Exception:
            return {}
        return config if isinstance(config, Mapping) else {}

    @staticmethod
    def _signature_defaults(runtime_obj) -> dict:
        """Return ``{parameter_name: default}`` for the pipe's signature.

        Classes are inspected through their ``__init__``; anything else is
        inspected directly. Positional and keyword-only defaults are both
        included. Returns an empty dict when there are no introspectable
        defaults.
        """
        # hasattr(x, "__init__") is true for every Python object, so check for
        # a class explicitly; functions must be inspected themselves.
        callee = runtime_obj.__init__ if inspect.isclass(runtime_obj) else runtime_obj
        # Slot wrappers (e.g. an inherited object.__init__) have no __defaults__.
        positional = getattr(callee, "__defaults__", None) or ()
        defaults = dict(getattr(callee, "__kwdefaults__", None) or {})
        if not positional:
            return defaults
        try:
            spec = inspect.getfullargspec(callee)
        except TypeError:
            return defaults
        # __defaults__ belong to the trailing positional parameters.
        defaults.update(zip(spec.args[-len(positional) :], positional))
        return defaults

    def _annotate_defaults(self, obj: Object, defaults: dict) -> None:
        """Write ``defaults`` into the numpy ``Parameters`` section of ``obj``.

        Matching parameters get ``str(default)``. Griffe expression defaults
        are left untouched. String defaults longer than ``MAX_DEFAULT_LENGTH``
        are truncated with a trailing ``"..."``.
        """
        if obj.docstring is None:
            return

        param_section: Union[DocstringSectionParameters, None] = None
        obj.docstring.parser = "numpy"
        for section in obj.docstring.parsed:
            if isinstance(section, DocstringSectionParameters):
                param_section = section  # the last Parameters section wins

        if param_section is None:
            return

        max_length = self.MAX_DEFAULT_LENGTH
        for param in param_section.value:
            if param.name in defaults:
                param.default = str(defaults[param.name])
            if isinstance(param.default, Expr):
                continue
            if param.default is not None and len(param.default) > max_length:
                param.default = param.default[: max_length - 3] + "..."