import gym
import numpy as np
import os
import pickle
import random
import tempfile
import zipfile
def zipsame(*seqs):
    """Zip several sequences together, asserting they all have equal length.

    Parameters
    ----------
    *seqs: sequences supporting len()
        At least one sequence must be given; the first one defines the
        expected length.

    Returns
    -------
    zipped: iterator
        The result of zip(*seqs).
    """
    expected_len = len(seqs[0])
    for other in seqs[1:]:
        assert len(other) == expected_len
    return zip(*seqs)
def unpack(seq, sizes):
    """Split 'seq' into consecutive pieces whose lengths are given by 'sizes'.

    An integer size yields a list of that many elements; a size of None
    yields a single bare element (not wrapped in a list).

    Example:
        unpack([1,2,3,4,5,6], [3,None,2]) -> ([1,2,3], 4, [5,6])

    This is a generator: the length check (an assert) only runs once
    iteration starts.
    """
    items = list(seq)
    # A None entry consumes exactly one element.
    needed = sum(1 if size is None else size for size in sizes)
    assert needed == len(items), "Trying to unpack %s into %s" % (items, sizes)
    cursor = iter(items)
    for size in sizes:
        if size is None:
            yield next(cursor)
            continue
        yield [next(cursor) for _ in range(size)]
class EzPickle(object):
    """Mixin for objects that are pickled and unpickled via their constructor
    arguments.

    The pickled state is only the positional and keyword arguments handed to
    EzPickle.__init__; unpickling calls the object's class with those
    arguments again and copies the resulting attributes.

    Example usage:

        class Dog(Animal, EzPickle):
            def __init__(self, furcolor, tailkind="bushy"):
                Animal.__init__(self)
                EzPickle.__init__(self, furcolor, tailkind)
                ...

    When this object is unpickled, a new Dog will be constructed by passing the
    provided furcolor and tailkind into the constructor. However, philosophers
    are still not sure whether it is still the same dog.

    This is generally needed only for environments which wrap C/C++ code, such
    as MuJoCo and Atari.
    """

    def __init__(self, *args, **kwargs):
        # Remember exactly what the constructor was called with.
        self._ezpickle_args, self._ezpickle_kwargs = args, kwargs

    def __getstate__(self):
        """Return only the recorded constructor arguments as pickle state."""
        return dict(
            _ezpickle_args=self._ezpickle_args,
            _ezpickle_kwargs=self._ezpickle_kwargs,
        )

    def __setstate__(self, d):
        """Rebuild the object by re-running the constructor of its class."""
        rebuilt = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
        # Adopt every attribute of the freshly constructed instance.
        vars(self).update(vars(rebuilt))
def set_global_seeds(i):
    """Seed every random number generator this module knows about.

    Seeds TensorFlow (only if it can be imported), NumPy's global RNG and
    Python's `random` module with the same value.

    Parameters
    ----------
    i: int
        the seed value
    """
    try:
        import tensorflow as tf
    except ImportError:
        # TensorFlow is optional; seed only the remaining generators.
        pass
    else:
        # TensorFlow 2.x removed the top-level `tf.set_random_seed`; prefer
        # `tf.random.set_seed` when available and fall back to the 1.x name.
        tf_random = getattr(tf, "random", None)
        if hasattr(tf_random, "set_seed"):
            tf_random.set_seed(i)
        else:
            tf.set_random_seed(i)
    np.random.seed(i)
    random.seed(i)
def pretty_eta(seconds_left):
    """Return the number of seconds in a human readable format.

    Only the two most significant units are reported, and seconds are
    never shown.

    Examples:
        2 days
        2 hours and 37 minutes
        less than a minute

    Parameters
    ----------
    seconds_left: int or float
        Number of seconds to be converted to the ETA. Negative values are
        treated as zero.

    Returns
    -------
    eta: str
        String representing the pretty ETA.
    """
    # Floor division / modulo on a negative number would wrap around and
    # produce nonsense such as "23 hours and 59 minutes", so clamp first.
    if seconds_left < 0:
        seconds_left = 0
    minutes_left = seconds_left // 60
    hours_left, minutes_left = divmod(minutes_left, 60)
    days_left, hours_left = divmod(hours_left, 24)

    def helper(cnt, name):
        # int() keeps float inputs from rendering as e.g. "2.0 hours".
        cnt = int(cnt)
        return "{} {}{}".format(cnt, name, 's' if cnt > 1 else '')

    if days_left > 0:
        msg = helper(days_left, 'day')
        if hours_left > 0:
            msg += ' and ' + helper(hours_left, 'hour')
        return msg
    if hours_left > 0:
        msg = helper(hours_left, 'hour')
        if minutes_left > 0:
            msg += ' and ' + helper(minutes_left, 'minute')
        return msg
    if minutes_left > 0:
        return helper(minutes_left, 'minute')
    return 'less than a minute'
class RunningAvg(object):
    """Exponentially weighted running estimate of a scalar quantity."""

    def __init__(self, gamma, init_value=None):
        """Keep a running estimate of a quantity. This is a bit like mean
        but more sensitive to recent changes.

        Parameters
        ----------
        gamma: float
            Must be between 0 and 1, where 0 is the most sensitive to recent
            changes.
        init_value: float or None
            Initial value of the estimate. If None, it will be set on the first update.
        """
        self._value = init_value
        self._gamma = gamma

    def update(self, new_val):
        """Update the estimate.

        Parameters
        ----------
        new_val: float
            newly observed value of the estimated quantity.
        """
        if self._value is None:
            # First observation: adopt it as the estimate.
            self._value = new_val
        else:
            self._value = self._gamma * self._value + (1.0 - self._gamma) * new_val

    def __float__(self):
        """Get the current estimate.

        Raises
        ------
        TypeError
            If no initial value was given and update() has not been called yet.
        """
        # __float__ must return an actual float; the stored value may be an
        # int (e.g. init_value=0 or an integer first observation).
        return float(self._value)
def boolean_flag(parser, name, default=False, help=None):
    """Register a pair of on/off command line switches on an argparse parser.

    Parameters
    ----------
    parser: argparse.ArgumentParser
        parser to add the flag to
    name: str
        --<name> will enable the flag, while --no-<name> will disable it
    default: bool or None
        default value of the flag
    help: str
        help string for the flag (attached to the enabling switch)
    """
    # Both switches write to the same attribute; dashes become underscores.
    attribute = name.replace('-', '_')
    parser.add_argument(
        "--" + name,
        action="store_true",
        default=default,
        dest=attribute,
        help=help,
    )
    parser.add_argument(
        "--no-" + name,
        action="store_false",
        dest=attribute,
    )
def get_wrapper_by_name(env, classname):
    """Given a gym environment possibly wrapped multiple times, return the
    wrapper of class named classname, or raise ValueError if no such wrapper
    was applied.

    Parameters
    ----------
    env: gym.Env or gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname

    Raises
    ------
    ValueError
        If the innermost environment is reached without finding a match.
    """
    currentenv = env
    while True:
        # Not every environment exposes `class_name()`; fall back to the
        # type's own name so a missing method does not surface as an
        # AttributeError instead of a match or the documented ValueError.
        name_getter = getattr(currentenv, "class_name", None)
        if callable(name_getter):
            current_name = name_getter()
        else:
            current_name = type(currentenv).__name__

        if classname == current_name:
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            # Peel off one wrapper layer and keep looking.
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname)
def relatively_safe_pickle_dump(obj, path, compression=False):
    """This is just like regular pickle dump, except for the fact that failure
    cases are different:

        - It's never possible that we end up with a pickle in corrupted state.
        - If there was a different file at the path, that file will remain
          unchanged in the event of failure (provided that filesystem rename
          is atomic).
        - It is sometimes possible that we end up with a useless temp file
          (path + ".relatively_safe") which needs to be deleted manually (it
          will be overwritten by the next call for the same path).

    The intended use case is periodic checkpoints of experiment state, such
    that we never corrupt previous checkpoints if the current one fails.

    Parameters
    ----------
    obj: object
        object to pickle
    path: str
        path to the output file
    compression: bool
        if true pickle will be compressed
    """
    temp_storage = path + ".relatively_safe"
    if compression:
        # Using gzip here would be simpler, but the size is limited to 2GB
        with tempfile.NamedTemporaryFile() as uncompressed_file:
            pickle.dump(obj, uncompressed_file)
            # Make sure all bytes are on disk before zipfile reads the file by name.
            uncompressed_file.file.flush()
            with zipfile.ZipFile(temp_storage, "w", compression=zipfile.ZIP_DEFLATED) as myzip:
                myzip.write(uncompressed_file.name, "data")
    else:
        with open(temp_storage, "wb") as f:
            pickle.dump(obj, f)
    # os.replace overwrites an existing destination on every platform, whereas
    # os.rename raises on Windows when `path` already exists (i.e. on every
    # checkpoint after the first).
    os.replace(temp_storage, path)
def pickle_load(path, compression=False):
    """Unpickle a possibly compressed pickle.

    NOTE: unpickling can execute arbitrary code; only use this on files
    that come from a trusted source.

    Parameters
    ----------
    path: str
        path to the file to read
    compression: bool
        if true assumes that the pickle was compressed when created (a zip
        archive holding a single member called "data") and attempts
        decompression.

    Returns
    -------
    obj: object
        the unpickled object
    """
    if not compression:
        with open(path, "rb") as plain_file:
            return pickle.load(plain_file)

    with zipfile.ZipFile(path, "r", compression=zipfile.ZIP_DEFLATED) as archive:
        with archive.open("data") as member:
            return pickle.load(member)