Switch to side-by-side view

--- a
+++ b/myosuite/utils/implement_for.py
@@ -0,0 +1,211 @@
+from __future__ import annotations
+import collections
+import inspect
+import sys
+from copy import copy
+from functools import wraps
+from importlib import import_module
+from typing import Union, Callable, Dict
+from packaging.version import parse
+
+class implement_for:
+    """A version decorator that checks the version in the environment and implements a function with the fitting one.
+
+    If specified module is missing or there is no fitting implementation, call of the decorated function
+    will lead to the explicit error.
+    In case of intersected ranges, last fitting implementation is used.
+
+    This wrapper also works to implement different backends for a same function (eg. gym vs gymnasium,
+    numpy vs jax-numpy etc).
+
+    Args:
+        module_name (str or callable): version is checked for the module with this
+            name (e.g. "gym"). If a callable is provided, it should return the
+            module.
+        from_version: version from which implementation is compatible. Can be open (None).
+        to_version: version from which implementation is no longer compatible. Can be open (None).
+
+    Examples:
+        >>> @implement_for("gym", "0.13", "0.14")
+        >>> def fun(self, x):
+        ...     # Older gym versions will return x + 1
+        ...     return x + 1
+        ...
+        >>> @implement_for("gym", "0.14", "0.23")
+        >>> def fun(self, x):
+        ...     # More recent gym versions will return x + 2
+        ...     return x + 2
+        ...
+        >>> @implement_for(lambda: import_module("gym"), "0.23", None)
+        >>> def fun(self, x):
+        ...     # More recent gym versions will return x + 2
+        ...     return x + 2
+        ...
+        >>> @implement_for("gymnasium", "0.27", None)
+        >>> def fun(self, x):
+        ...     # If gymnasium is to be used instead of gym, x+3 will be returned
+        ...     return x + 3
+        ...
+
+        This indicates that the function is compatible with gym 0.13+, but doesn't with gym 0.14+.
+    """
+
+    # Stores pointers to fitting implementations: dict[func_name] = func_pointer
+    _implementations = {}
+    _setters = []
+    _cache_modules = {}
+
+    def __init__(
+        self,
+        module_name: Union[str, Callable],
+        from_version: str = None,
+        to_version: str = None,
+    ):
+        self.module_name = module_name
+        self.from_version = from_version
+        self.to_version = to_version
+        implement_for._setters.append(self)
+
+    @staticmethod
+    def check_version(version, from_version, to_version):
+        return (from_version is None or parse(version) >= parse(from_version)) and (
+            to_version is None or parse(version) < parse(to_version)
+        )
+
+    @staticmethod
+    def get_class_that_defined_method(f):
+        """Returns the class of a method, if it is defined, and None otherwise."""
+        out = f.__globals__.get(f.__qualname__.split(".")[0], None)
+        return out
+
+    @classmethod
+    def get_func_name(cls, fn):
+        # produces a name like torchrl.module.Class.method or torchrl.module.function
+        first = str(fn).split(".")[0][len("<function ") :]
+        last = str(fn).split(".")[1:]
+        if last:
+            first = [first]
+            last[-1] = last[-1].split(" ")[0]
+        else:
+            last = [first.split(" ")[0]]
+            first = []
+        return ".".join([fn.__module__] + first + last)
+
+    def _get_cls(self, fn):
+        cls = self.get_class_that_defined_method(fn)
+        if cls is None:
+            # class not yet defined
+            return
+        if cls.__class__.__name__ == "function":
+            cls = inspect.getmodule(fn)
+        return cls
+
+    def module_set(self):
+        """Sets the function in its module, if it exists already."""
+        prev_setter = type(self)._implementations.get(self.get_func_name(self.fn), None)
+        if prev_setter is not None:
+            prev_setter.do_set = False
+        type(self)._implementations[self.get_func_name(self.fn)] = self
+        cls = self.get_class_that_defined_method(self.fn)
+        if cls is not None:
+            if cls.__class__.__name__ == "function":
+                cls = inspect.getmodule(self.fn)
+        else:
+            # class not yet defined
+            return
+        setattr(cls, self.fn.__name__, self.fn)
+
+    @classmethod
+    def import_module(cls, module_name: Union[Callable, str]) -> str:
+        """Imports module and returns its version."""
+        if not callable(module_name):
+            module = cls._cache_modules.get(module_name, None)
+            if module is None:
+                if module_name in sys.modules:
+                    sys.modules[module_name] = module = import_module(module_name)
+                else:
+                    cls._cache_modules[module_name] = module = import_module(
+                        module_name
+                    )
+        else:
+            module = module_name()
+        return module.__version__
+
+    _lazy_impl = collections.defaultdict(list)
+
+    def _delazify(self, func_name):
+        for local_call in implement_for._lazy_impl[func_name]:
+            out = local_call()
+        return out
+
+    def __call__(self, fn):
+        # function names are unique
+        self.func_name = self.get_func_name(fn)
+        self.fn = fn
+        implement_for._lazy_impl[self.func_name].append(self._call)
+
+        @wraps(fn)
+        def _lazy_call_fn(*args, **kwargs):
+            # first time we call the function, we also do the replacement.
+            # This will cause the imports to occur only during the first call to fn
+            return self._delazify(self.func_name)(*args, **kwargs)
+
+        return _lazy_call_fn
+
+    def _call(self):
+
+        # If the module is missing replace the function with the mock.
+        fn = self.fn
+        func_name = self.func_name
+        implementations = implement_for._implementations
+
+        @wraps(fn)
+        def unsupported(*args, **kwargs):
+            raise ModuleNotFoundError(
+                f"Supported version of '{func_name}' has not been found."
+            )
+
+        self.do_set = False
+        # Return fitting implementation if it was encountered before.
+        if func_name in implementations:
+            try:
+                # check that backends don't conflict
+                version = self.import_module(self.module_name)
+                if self.check_version(version, self.from_version, self.to_version):
+                    self.do_set = True
+                if not self.do_set:
+                    return implementations[func_name].fn
+            except ModuleNotFoundError:
+                # then it's ok, there is no conflict
+                return implementations[func_name].fn
+        else:
+            try:
+                version = self.import_module(self.module_name)
+                if self.check_version(version, self.from_version, self.to_version):
+                    self.do_set = True
+            except ModuleNotFoundError:
+                return unsupported
+        if self.do_set:
+            self.module_set()
+            return fn
+        return unsupported
+
+    @classmethod
+    def reset(cls, setters_dict: Dict[str, implement_for] = None):
+        """Resets the setters in setter_dict.
+
+        ``setter_dict`` is a copy of implementations. We just need to iterate through its
+        values and call :meth:`~.module_set` for each.
+
+        """
+        if setters_dict is None:
+            setters_dict = copy(cls._implementations)
+        for setter in setters_dict.values():
+            setter.module_set()
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}("
+            f"module_name={self.module_name}({self.from_version, self.to_version}), "
+            f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)}, is_set={self.do_set})"
+        )