|
a |
|
b/docs/scripts/griffe_ext.py |
|
|
1 |
import ast |
|
|
2 |
import importlib |
|
|
3 |
import inspect |
|
|
4 |
import logging |
|
|
5 |
import sys |
|
|
6 |
from typing import Union |
|
|
7 |
|
|
|
8 |
import astunparse |
|
|
9 |
from griffe import Extension, Object, ObjectNode |
|
|
10 |
from griffe.docstrings.dataclasses import DocstringSectionParameters |
|
|
11 |
from griffe.expressions import Expr |
|
|
12 |
from griffe.logger import patch_loggers |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def get_logger(name):
    """Return the standard-library logger ``name`` capped at ERROR level.

    The logger is fetched with ``logging.getLogger`` and its level is set to
    ERROR, so DEBUG/INFO/WARNING records it emits are discarded.
    """
    quiet_logger = logging.getLogger(name)
    quiet_logger.setLevel(logging.ERROR)
    return quiet_logger
|
|
19 |
|
|
|
20 |
|
|
|
21 |
# Give griffe our logger factory so that the loggers it patches are created by
# get_logger (and therefore capped at ERROR level) -- presumably to keep the
# documentation build output quiet.
patch_loggers(get_logger)

# Module logger for this extension; also capped at ERROR, so its debug
# messages are normally discarded.
logger = get_logger(name=__name__)
|
|
24 |
|
|
|
25 |
|
|
|
26 |
class EDSNLPDocstrings(Extension):
    """Griffe extension that documents ``create_component`` factories.

    A module may expose a factory through an assignment of the form
    ``create_component = <registrar>(..., default_config=...)(<target>)``.
    For such a factory, this extension copies the docstring of ``<target>``
    onto the factory and fills the defaults of its numpy-style
    ``Parameters`` section with the signature defaults of ``<target>`` (its
    ``__init__`` when it is a class), overridden by the ``default_config``
    mapping of the registering call.

    Griffe may visit a factory before or after its target, so each side is
    remembered until the other one is seen.
    """

    # Rendered default values longer than this are truncated with "...".
    MAX_DEFAULT_LENGTH = 50

    def __init__(self):
        super().__init__()

        # path of a visited class/function -> its griffe object
        self.PIPE_OBJ = {}
        # factory path -> (assignment node, factory object, merged defaults)
        self.FACT_MEM = {}
        # runtime path (module + name) of a factory's target -> factory path
        self.PIPE_TO_FACT = {}

    def on_instance(self, node: Union[ast.AST, ObjectNode], obj: Object) -> None:
        """Pair factories with their targets and enrich the factory docstring.

        Parameters
        ----------
        node
            Node griffe built ``obj`` from.
        obj
            Griffe object being loaded.
        """
        if (
            isinstance(node, ast.Assign)
            and obj.name == "create_component"
            and isinstance(node.value, ast.Call)
            and isinstance(node.value.func, ast.Call)
        ):
            module_name = obj.path.rsplit(".", 1)[0]
            # Reload every loaded edspdf module, then the factory's module, so
            # the runtime objects inspected below match the current sources
            # (assumed intent: picking up edits while serving docs -- confirm).
            for name, mod in list(sys.modules.items()):
                if name.startswith("edspdf"):
                    importlib.reload(mod)
            module = importlib.reload(importlib.import_module(module_name))

            default_config = self._eval_default_config(node.value.func, module)

            # import object to get its evaluated docstring
            try:
                runtime_obj = getattr(module, obj.name)
                source = inspect.getsource(runtime_obj)
                self.visit(ast.parse(source))
            except ImportError:
                logger.debug(f"Could not get dynamic docstring for {obj.path}")
                return
            except AttributeError:
                logger.debug(f"Could not resolve {obj.path} at runtime")
                return
            except (OSError, TypeError):
                # inspect.getsource raises these when no Python source exists
                logger.debug(f"Could not get the source code of {obj.path}")
                return

            # Every object (functions included) has an ``__init__`` attribute,
            # so test for classes explicitly: only for them is ``__init__`` the
            # callable whose defaults describe the component.
            callee = (
                runtime_obj.__init__ if inspect.isclass(runtime_obj) else runtime_obj
            )
            # default_config values take precedence over signature defaults.
            defaults = {**self._get_callable_defaults(callee), **default_config}
            self.FACT_MEM[obj.path] = (node, obj, defaults)
            pipe_path = runtime_obj.__module__ + "." + runtime_obj.__name__
            self.PIPE_TO_FACT[pipe_path] = obj.path

            if pipe_path not in self.PIPE_OBJ:
                # Target not visited yet: its own visit completes the factory.
                return
            obj.docstring = self.PIPE_OBJ[pipe_path].docstring
        elif obj.is_class or obj.is_function:
            self.PIPE_OBJ[obj.path] = obj
            if obj.path not in self.PIPE_TO_FACT:
                return
            # The factory was seen first: give it this target's docstring.
            _, fact_obj, defaults = self.FACT_MEM[self.PIPE_TO_FACT[obj.path]]
            fact_obj.docstring = obj.docstring
            obj = fact_obj
        else:
            return

        self._apply_defaults(obj, defaults)

    @staticmethod
    def _eval_default_config(registrar_call: ast.Call, module) -> dict:
        """Evaluate the ``default_config=`` keyword of ``registrar_call``.

        The keyword's expression is evaluated in ``module``'s namespace.
        Returns an empty dict when the keyword is missing, evaluates to
        ``None``, or fails to evaluate.
        """
        config_node = next(
            (kw.value for kw in registrar_call.keywords if kw.arg == "default_config"),
            None,
        )
        if config_node is None:
            return {}
        try:
            # NOTE: eval executes code taken from the documented project's own
            # sources; this extension must never be pointed at untrusted code.
            config = eval(astunparse.unparse(config_node), module.__dict__)
        except Exception:
            # Best effort: fall back to signature defaults only.
            logger.debug(f"Could not evaluate default_config in {module.__name__}")
            return {}
        return {} if config is None else config

    @staticmethod
    def _get_callable_defaults(callee) -> dict:
        """Map each parameter of ``callee`` that has a default to that default.

        Both positional and keyword-only defaults are returned. Callables that
        ``inspect.getfullargspec`` cannot introspect yield an empty dict.
        """
        try:
            spec = inspect.getfullargspec(callee)
        except TypeError:
            return {}
        positional_defaults = spec.defaults or ()
        # Positional defaults always belong to the trailing positional params.
        with_defaults = spec.args[len(spec.args) - len(positional_defaults) :]
        return {
            **dict(zip(with_defaults, positional_defaults)),
            **(spec.kwonlydefaults or {}),
        }

    def _apply_defaults(self, obj: Object, defaults: dict) -> None:
        """Write ``defaults`` into the numpy ``Parameters`` section of ``obj``.

        Only parameters documented in the docstring are updated. String
        defaults longer than ``MAX_DEFAULT_LENGTH`` are truncated with "...".
        """
        if obj.docstring is None:
            return

        # Force the numpy parser so the "Parameters" section is recognised.
        obj.docstring.parser = "numpy"
        param_section = None
        for section in obj.docstring.parsed:
            if isinstance(section, DocstringSectionParameters):
                param_section = section  # the last such section wins
        if param_section is None:
            return

        limit = self.MAX_DEFAULT_LENGTH
        for param in param_section.value:
            if param.name in defaults:
                param.default = str(defaults[param.name])
            # Expression defaults are not plain strings; leave them untouched.
            if isinstance(param.default, Expr):
                continue
            if param.default is not None and len(param.default) > limit:
                param.default = param.default[: limit - 3] + "..."