|
a |
|
b/myosuite/utils/implement_for.py |
|
|
1 |
from __future__ import annotations |
|
|
2 |
import collections |
|
|
3 |
import inspect |
|
|
4 |
import sys |
|
|
5 |
from copy import copy |
|
|
6 |
from functools import wraps |
|
|
7 |
from importlib import import_module |
|
|
8 |
from typing import Union, Callable, Dict |
|
|
9 |
from packaging.version import parse |
|
|
10 |
|
|
|
11 |
class implement_for:
    """A version decorator that checks the version in the environment and implements a function with the fitting one.

    If the specified module is missing or there is no fitting implementation, calling the decorated
    function raises an explicit ``ModuleNotFoundError``.
    If version ranges overlap, the last fitting implementation is used.

    This wrapper also works to implement different backends for the same function (e.g. gym vs gymnasium,
    numpy vs jax-numpy etc).

    Resolution is lazy: decorating a function imports nothing. When the decorated function is first
    called, every implementation registered under the same name is checked in decoration order. The
    chosen one is registered and, when its owning class or module can be located, installed there in
    place of the lazy wrapper.

    Args:
        module_name (str or callable): version is checked for the module with this
            name (e.g. "gym"). If a callable is provided, it should return the
            module.
        from_version: version from which implementation is compatible. Can be open (None).
        to_version: version from which implementation is no longer compatible. Can be open (None).

    Examples:
        >>> @implement_for("gym", "0.13", "0.14")
        ... def fun(self, x):
        ...     # Older gym versions will return x + 1
        ...     return x + 1
        ...
        >>> @implement_for("gym", "0.14", "0.23")
        ... def fun(self, x):
        ...     # More recent gym versions will return x + 2
        ...     return x + 2
        ...
        >>> @implement_for(lambda: import_module("gym"), "0.23", None)
        ... def fun(self, x):
        ...     # More recent gym versions will return x + 2
        ...     return x + 2
        ...
        >>> @implement_for("gymnasium", "0.27", None)
        ... def fun(self, x):
        ...     # If gymnasium is to be used instead of gym, x+3 will be returned
        ...     return x + 3
        ...

        The first implementation is compatible with gym 0.13+, but not with gym 0.14+.
    """

    # Registry of the active implementation per function:
    # dict[func_name] = implement_for instance whose ``fn`` is currently in use.
    _implementations = {}
    # Every implement_for instance ever created, in creation order.
    _setters = []
    # Modules imported by :meth:`import_module` that were not already in ``sys.modules``.
    _cache_modules = {}
    # Pending resolvers: dict[func_name] = list of bound ``_call`` methods, one per
    # decorated implementation, in decoration order.
    _lazy_impl = collections.defaultdict(list)

    def __init__(
        self,
        module_name: Union[str, Callable],
        from_version: Union[str, None] = None,
        to_version: Union[str, None] = None,
    ):
        self.module_name = module_name
        self.from_version = from_version
        self.to_version = to_version
        # Filled in by ``__call__`` once the decorated function is known; initialised
        # here so that ``repr`` works on instances that were never applied/resolved.
        self.fn = None
        self.func_name = None
        # True while this instance provides the active implementation of ``func_name``.
        self.do_set = False
        implement_for._setters.append(self)

    @staticmethod
    def check_version(version, from_version, to_version):
        """Returns whether ``from_version <= version < to_version``.

        Versions are compared with ``packaging.version.parse``. A ``None`` bound is
        open and is not parsed at all.
        """
        return (from_version is None or parse(version) >= parse(from_version)) and (
            to_version is None or parse(version) < parse(to_version)
        )

    @staticmethod
    def get_class_that_defined_method(f):
        """Returns the module-level object named by the first component of ``f.__qualname__``.

        For a method this is its class, or None while that class is not yet bound in
        the module. For a plain function it is whatever is currently bound to that
        name in the module (e.g. an earlier wrapper), or None.
        """
        return f.__globals__.get(f.__qualname__.split(".")[0], None)

    @classmethod
    def get_func_name(cls, fn):
        """Returns a unique name for ``fn``: ``<module>.Class.method`` or ``<module>.function``."""
        # ``__qualname__`` is exactly what a function's repr displays, without having
        # to parse the repr string.
        return f"{fn.__module__}.{fn.__qualname__}"

    def _get_cls(self, fn):
        """Returns the class or module ``fn`` should be installed on, or None if not available yet."""
        cls = self.get_class_that_defined_method(fn)
        if cls is None:
            # class not yet defined
            return None
        if inspect.isfunction(cls):
            # ``fn`` is a module-level function: install it on its module instead.
            return inspect.getmodule(fn)
        return cls

    def module_set(self):
        """Registers this implementation as active and sets the function in its class/module, if it exists already."""
        func_name = self.get_func_name(self.fn)
        prev_setter = type(self)._implementations.get(func_name, None)
        if prev_setter is not None:
            prev_setter.do_set = False
        # Mark ourselves active *after* clearing the previous setter, so re-registering
        # the same instance (e.g. through ``reset``) keeps its flag set.
        self.do_set = True
        type(self)._implementations[func_name] = self
        cls = self._get_cls(self.fn)
        if cls is None:
            # class not yet defined
            return
        setattr(cls, self.fn.__name__, self.fn)

    @classmethod
    def import_module(cls, module_name: Union[Callable, str]) -> str:
        """Imports a module and returns its ``__version__``.

        Args:
            module_name: the module's import name, or a zero-argument callable that
                returns the module object.

        Raises:
            ModuleNotFoundError: propagated when the module cannot be found.
        """
        if callable(module_name):
            return module_name().__version__
        module = cls._cache_modules.get(module_name, None)
        if module is None:
            # Modules that were already imported elsewhere are looked up again on each
            # call; only modules imported for the first time here are cached.
            already_imported = module_name in sys.modules
            module = import_module(module_name)
            if not already_imported:
                cls._cache_modules[module_name] = module
        return module.__version__

    def _delazify(self, func_name):
        """Runs every registered resolver for ``func_name`` and returns the last result.

        Each resolver returns the implementation in effect after it ran, so the
        result of the last one is the function to use.
        """
        for local_call in implement_for._lazy_impl[func_name]:
            out = local_call()
        return out

    def __call__(self, fn):
        # function names are unique (module + qualified name)
        self.func_name = self.get_func_name(fn)
        self.fn = fn
        implement_for._lazy_impl[self.func_name].append(self._call)

        @wraps(fn)
        def _lazy_call_fn(*args, **kwargs):
            # Resolution (and the imports it needs) happens when the function is called,
            # not when it is decorated. Resolving also installs the chosen function on
            # its class/module when that can be located, so later attribute lookups
            # bypass this wrapper.
            return self._delazify(self.func_name)(*args, **kwargs)

        return _lazy_call_fn

    def _call(self):
        """Resolves this implementation against the installed module.

        Returns this implementation's ``fn`` if the module version fits; it is then
        registered and installed through :meth:`module_set`. Otherwise it returns the
        previously registered implementation, if any, or a stub that raises
        ``ModuleNotFoundError``.
        """
        fn = self.fn
        func_name = self.func_name
        implementations = implement_for._implementations

        @wraps(fn)
        def unsupported(*args, **kwargs):
            raise ModuleNotFoundError(
                f"Supported version of '{func_name}' has not been found."
            )

        def fallback():
            # A missing or non-fitting backend here does not conflict with an
            # implementation registered earlier: keep that one if it exists.
            previous = implementations.get(func_name)
            return previous.fn if previous is not None else unsupported

        self.do_set = False
        try:
            version = self.import_module(self.module_name)
        except ModuleNotFoundError:
            # If the module is missing, this implementation cannot be used.
            return fallback()
        if self.check_version(version, self.from_version, self.to_version):
            # The last fitting implementation wins and replaces any earlier one.
            self.module_set()
            return fn
        return fallback()

    @classmethod
    def reset(cls, setters_dict: Dict[str, "implement_for"] = None):
        """Resets the setters in ``setters_dict``.

        ``setters_dict`` maps function names to implement_for instances. It defaults to
        a copy of the current registry. Each value is re-registered and re-installed via
        :meth:`~.module_set`.
        """
        if setters_dict is None:
            setters_dict = copy(cls._implementations)
        for setter in setters_dict.values():
            setter.module_set()

    def __repr__(self):
        fn = self.fn
        fn_name = fn.__name__ if fn is not None else None
        cls = self._get_cls(fn) if fn is not None else None
        return (
            f"{self.__class__.__name__}("
            f"module_name={self.module_name}({self.from_version, self.to_version}), "
            f"fn_name={fn_name}, cls={cls}, is_set={self.do_set})"
        )