--- a +++ b/nips/remote_vec_env.py @@ -0,0 +1,112 @@ +import numpy as np +from baselines.common.vec_env import VecEnv +import ray + +from nips.round2_env import OBSERVATION_SPACE + + +class TaskPool(object): + """Helper class for tracking the status of many in-flight actor tasks.""" + + def __init__(self, timeout=1): + self._tasks = {} + self.timeout = timeout + + def add(self, worker, obj_id): + self._tasks[obj_id] = worker + + def completed(self): + pending = list(self._tasks) + if pending: + ready, _ = ray.wait(pending, num_returns=len(pending), timeout=self.timeout) + if not ready: + return [] + for obj_id in ready: + yield (self._tasks.pop(obj_id), obj_id) + + @property + def count(self): + return len(self._tasks) + + +class Actor(object): + + def __init__(self, aid, env_fn): + self.aid = aid + self.env = env_fn() + + def step(self, action): + ob, reward, done, info = self.env.step(action) + if done: + ob = self.env.reset() + return ob, reward, done, info + + def reset(self): + return self.env.reset() + + def get_spaces(self): + return self.env.observation_space, self.env.action_space + + def get_id(self): + return self.aid + + +class RemoteVecEnv(VecEnv): + def __init__(self, env_fns, spaces=None): + """ + envs: list of gym environments to run in subprocesses + """ + self.waiting = False + self.closed = False + self.task_pool = TaskPool(timeout=10) + + nenvs = len(env_fns) + + self.actors = [] + self.actor_to_i = {} + remote_actor = ray.remote(Actor) + for i in range(nenvs): + actor = remote_actor.remote(i, env_fns[i]) + self.actors.append(actor) + self.actor_to_i[actor] = i + + observation_space, action_space = ray.get(self.actors[0].get_spaces.remote()) + VecEnv.__init__(self, len(env_fns), observation_space, action_space) + + self.results = [([0] * OBSERVATION_SPACE, 0, False, {"bad": True})] * self.num_envs + + def step_async(self, actions): + for actor, action in zip(self.actors, actions): + # print(action, any(action)) + if any(action): + # print(self.actor_to_i[actor], action) + self.task_pool.add(actor, actor.step.remote(action)) + self.waiting = True + + def step_wait(self): + done_ids = set([]) + + count = 0 + while count * 2 < self.task_pool.count: + for actor, obj in self.task_pool.completed(): + _i = self.actor_to_i[actor] + done_ids.add(_i) + self.results[_i] = ray.get(obj) + count += 1 + # print(count, self.task_pool.count, done_ids) + self.waiting = False + obs, rews, dones, infos = zip(*self.results) + for i in range(self.num_envs): + infos[i]["bad"] = i not in done_ids + return np.stack(obs), np.stack(rews), np.stack(dones), infos + + def reset(self): + obj_ids = [actor.reset.remote() for actor in self.actors] + results = ray.get(ray.wait(obj_ids, num_returns=self.num_envs)[0]) + # TODO: should update self.results, but it's ok because this function will be invoked only at first + return np.stack(results) + + def close(self): + if self.closed: + return + self.closed = True