a b/myosuite/utils/implement_for.py
1
from __future__ import annotations
2
import collections
3
import inspect
4
import sys
5
from copy import copy
6
from functools import wraps
7
from importlib import import_module
8
from typing import Union, Callable, Dict
9
from packaging.version import parse
10
11
class implement_for:
    """A version decorator that checks the version in the environment and implements a function with the fitting one.

    If the specified module is missing or there is no fitting implementation, calling the decorated function
    raises an explicit error.
    In case of intersecting ranges, the last fitting implementation is used.

    This wrapper also works to implement different backends for the same function (e.g. gym vs gymnasium,
    numpy vs jax-numpy etc).

    Resolution is lazy: decorating only registers the implementation. The module import and
    version check happen when the decorated function is first called.

    Args:
        module_name (str or callable): version is checked for the module with this
            name (e.g. "gym"). If a callable is provided, it should return the
            module.
        from_version: version from which the implementation is compatible (inclusive).
            Can be open (None).
        to_version: version from which the implementation is no longer compatible (exclusive).
            Can be open (None).

    Examples:
        >>> @implement_for("gym", "0.13", "0.14")
        ... def fun(self, x):
        ...     # Older gym versions will return x + 1
        ...     return x + 1
        ...
        >>> @implement_for("gym", "0.14", "0.23")
        ... def fun(self, x):
        ...     # More recent gym versions will return x + 2
        ...     return x + 2
        ...
        >>> @implement_for(lambda: import_module("gym"), "0.23", None)
        ... def fun(self, x):
        ...     # More recent gym versions will return x + 2
        ...     return x + 2
        ...
        >>> @implement_for("gymnasium", "0.27", None)
        ... def fun(self, x):
        ...     # If gymnasium is to be used instead of gym, x+3 will be returned
        ...     return x + 3
        ...

        The first implementation is used with gym versions in [0.13, 0.14), i.e. it is
        compatible with gym 0.13+ but not with gym 0.14+.
    """

    # Currently installed implementation per function:
    # dict[func_name] = implement_for instance whose ``fn`` is installed.
    _implementations = {}
    # Every implement_for instance created (appended in __init__).
    _setters = []
    # Modules imported by name through ``import_module``: dict[module_name] = module.
    _cache_modules = {}

    def __init__(
        self,
        module_name: Union[str, Callable],
        from_version: Union[str, None] = None,
        to_version: Union[str, None] = None,
    ):
        self.module_name = module_name
        self.from_version = from_version
        self.to_version = to_version
        # Filled in by __call__ / _call; initialised here so that __repr__ and the
        # module_set bookkeeping never hit a missing attribute.
        self.fn = None
        self.func_name = None
        self.do_set = False
        implement_for._setters.append(self)

    @staticmethod
    def check_version(version, from_version, to_version):
        """Returns True if ``from_version <= version < to_version`` (None bounds are open)."""
        parsed = parse(version)
        return (from_version is None or parsed >= parse(from_version)) and (
            to_version is None or parsed < parse(to_version)
        )

    @staticmethod
    def get_class_that_defined_method(f):
        """Returns the class of a method, if it is defined, and None otherwise.

        Looks up the first component of ``f.__qualname__`` in ``f``'s globals. For a
        module-level function this returns the function object bound to that name,
        which callers map back to the module.
        """
        out = f.__globals__.get(f.__qualname__.split(".")[0], None)
        return out

    @classmethod
    def get_func_name(cls, fn):
        """Returns a unique dotted name for ``fn``.

        Produces a name like ``torchrl.module.Class.method`` or ``torchrl.module.function``.
        """
        # A function's repr is "<function {qualname} at 0x...>"; reading the
        # attributes directly yields the same name without depending on that format.
        qualname = getattr(fn, "__qualname__", None) or fn.__name__
        return f"{fn.__module__}.{qualname}"

    def _get_cls(self, fn):
        """Returns the object on which ``fn`` should be installed, or None.

        This is the class named by the root of ``fn.__qualname__``, or ``fn``'s
        module when that root is a function. It is None when the root name is not
        yet bound in ``fn``'s globals (e.g. the class is not yet defined).
        """
        cls = self.get_class_that_defined_method(fn)
        if cls is None:
            # class not yet defined
            return None
        if inspect.isfunction(cls):
            # The qualname root is a function (fn itself, or an enclosing
            # function): install on the module instead.
            cls = inspect.getmodule(fn)
        return cls

    def module_set(self):
        """Sets the function in its module (or class), if it exists already.

        Also records ``self`` as the current implementation of its function name.
        The previously recorded setter, if it is a different instance, is marked
        as unset.
        """
        func_name = self.get_func_name(self.fn)
        registry = type(self)._implementations
        prev_setter = registry.get(func_name, None)
        if prev_setter is not None and prev_setter is not self:
            prev_setter.do_set = False
        registry[func_name] = self
        # Re-installing the same setter (repeated lazy calls, ``reset``) must
        # keep it flagged as set.
        self.do_set = True
        cls = self._get_cls(self.fn)
        if cls is None:
            # class not yet defined
            return
        setattr(cls, self.fn.__name__, self.fn)

    @classmethod
    def import_module(cls, module_name: Union[Callable, str]) -> str:
        """Imports module and returns its version.

        Args:
            module_name: a module name, or a zero-argument callable returning the module.

        Returns:
            The module's ``__version__``.

        Raises:
            ModuleNotFoundError: propagated from the import when the module is missing.
        """
        if callable(module_name):
            return module_name().__version__
        module = cls._cache_modules.get(module_name, None)
        if module is None:
            # Modules already present in sys.modules are resolved through the
            # import system on every call; only modules first imported here
            # are cached.
            already_imported = module_name in sys.modules
            module = import_module(module_name)
            if not already_imported:
                cls._cache_modules[module_name] = module
        return module.__version__

    # dict[func_name] = list of bound ``_call`` methods, in registration order.
    _lazy_impl = collections.defaultdict(list)

    def _delazify(self, func_name):
        """Runs every registered ``_call`` for ``func_name`` and returns the last result.

        Raises:
            RuntimeError: if no implementation was registered under ``func_name``.
        """
        # .get avoids creating an empty entry in the defaultdict.
        local_calls = implement_for._lazy_impl.get(func_name, ())
        if not local_calls:
            raise RuntimeError(
                f"No implementation of '{func_name}' has been registered."
            )
        out = None
        for local_call in local_calls:
            out = local_call()
        return out

    def __call__(self, fn):
        """Registers ``fn`` as a candidate implementation and returns a lazy wrapper."""
        # function names are unique
        self.func_name = self.get_func_name(fn)
        self.fn = fn
        implement_for._lazy_impl[self.func_name].append(self._call)

        @wraps(fn)
        def _lazy_call_fn(*args, **kwargs):
            # Every call through this wrapper resolves the implementation, which
            # defers imports until the function is actually used. When the owner
            # class/module is available, module_set replaces this wrapper there
            # with the selected implementation.
            return self._delazify(self.func_name)(*args, **kwargs)

        return _lazy_call_fn

    def _call(self):
        """Resolves this implementation and returns the callable to use.

        Returns ``self.fn`` (after installing it via :meth:`module_set`) when the
        module imports and its version is in range. Otherwise it returns the
        already-recorded implementation if there is one, or a stub that raises
        :class:`ModuleNotFoundError` when called.
        """
        # If the module is missing replace the function with the mock.
        fn = self.fn
        func_name = self.func_name
        implementations = implement_for._implementations

        @wraps(fn)
        def unsupported(*args, **kwargs):
            raise ModuleNotFoundError(
                f"Supported version of '{func_name}' has not been found."
            )

        self.do_set = False
        # Return fitting implementation if it was encountered before.
        if func_name in implementations:
            try:
                # check that backends don't conflict
                version = self.import_module(self.module_name)
                if self.check_version(version, self.from_version, self.to_version):
                    self.do_set = True
                if not self.do_set:
                    return implementations[func_name].fn
            except ModuleNotFoundError:
                # then it's ok, there is no conflict
                return implementations[func_name].fn
        else:
            try:
                version = self.import_module(self.module_name)
                if self.check_version(version, self.from_version, self.to_version):
                    self.do_set = True
            except ModuleNotFoundError:
                return unsupported
        if self.do_set:
            self.module_set()
            return fn
        return unsupported

    @classmethod
    def reset(cls, setters_dict: Dict[str, implement_for] = None):
        """Re-installs the setters in ``setters_dict``.

        ``setters_dict`` maps function names to ``implement_for`` instances, like
        ``_implementations``; when None, a shallow copy of ``_implementations`` is
        used. :meth:`~.module_set` is called on each value.
        """
        if setters_dict is None:
            setters_dict = copy(cls._implementations)
        for setter in setters_dict.values():
            setter.module_set()

    def __repr__(self):
        fn = getattr(self, "fn", None)
        fn_name = getattr(fn, "__name__", None)
        cls = self._get_cls(fn) if fn is not None else None
        return (
            f"{self.__class__.__name__}("
            f"module_name={self.module_name}({self.from_version}, {self.to_version}), "
            f"fn_name={fn_name}, cls={cls}, is_set={getattr(self, 'do_set', False)})"
        )