--- a +++ b/unimol/data/data_utils.py @@ -0,0 +1,23 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import contextlib + + +@contextlib.contextmanager +def numpy_seed(seed, *addl_seeds): + """Context manager which seeds the NumPy PRNG with the specified seed and + restores the state afterward""" + if seed is None: + yield + return + if len(addl_seeds) > 0: + seed = int(hash((seed, *addl_seeds)) % 1e6) + state = np.random.get_state() + np.random.seed(seed) + try: + yield + finally: + np.random.set_state(state)