Diff of /unimol/data/data_utils.py [000000] .. [b40915]

Switch to side-by-side view

--- a
+++ b/unimol/data/data_utils.py
@@ -0,0 +1,23 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import contextlib
+
+
+@contextlib.contextmanager
+def numpy_seed(seed, *addl_seeds):
+    """Context manager which seeds the NumPy PRNG with the specified seed and
+    restores the state afterward"""
+    if seed is None:
+        yield
+        return
+    if len(addl_seeds) > 0:
+        seed = int(hash((seed, *addl_seeds)) % 1e6)
+    state = np.random.get_state()
+    np.random.seed(seed)
+    try:
+        yield
+    finally:
+        np.random.set_state(state)