Switch to unified view

a b/drl/multiprocessing_env.py
1
#This code is from openai baseline
2
#https://github.com/openai/baselines/tree/master/baselines/common/vec_env
3
4
import numpy as np
5
from multiprocessing import Process, Pipe
6
7
def worker(remote, parent_remote, env_fn_wrapper):
8
    parent_remote.close()
9
    env = env_fn_wrapper.x()
10
    while True:
11
        cmd, data = remote.recv()
12
        if cmd == 'step':
13
            ob, reward, done, info = env.step(data)
14
            if done:
15
                ob = env.reset()
16
            remote.send((ob, reward, done, info))
17
        elif cmd == 'reset':
18
            ob = env.reset()
19
            remote.send(ob)
20
        elif cmd == 'reset_task':
21
            ob = env.reset_task()
22
            remote.send(ob)
23
        elif cmd == 'close':
24
            remote.close()
25
            break
26
        elif cmd == 'get_spaces':
27
            remote.send((env.observation_space, env.action_space))
28
        else:
29
            raise NotImplementedError
30
31
class VecEnv(object):
32
    """
33
    An abstract asynchronous, vectorized environment.
34
    """
35
    def __init__(self, num_envs, observation_space, action_space):
36
        self.num_envs = num_envs
37
        self.observation_space = observation_space
38
        self.action_space = action_space
39
40
    def reset(self):
41
        """
42
        Reset all the environments and return an array of
43
        observations, or a tuple of observation arrays.
44
        If step_async is still doing work, that work will
45
        be cancelled and step_wait() should not be called
46
        until step_async() is invoked again.
47
        """
48
        pass
49
50
    def step_async(self, actions):
51
        """
52
        Tell all the environments to start taking a step
53
        with the given actions.
54
        Call step_wait() to get the results of the step.
55
        You should not call this if a step_async run is
56
        already pending.
57
        """
58
        pass
59
60
    def step_wait(self):
61
        """
62
        Wait for the step taken with step_async().
63
        Returns (obs, rews, dones, infos):
64
         - obs: an array of observations, or a tuple of
65
                arrays of observations.
66
         - rews: an array of rewards
67
         - dones: an array of "episode done" booleans
68
         - infos: a sequence of info objects
69
        """
70
        pass
71
72
    def close(self):
73
        """
74
        Clean up the environments' resources.
75
        """
76
        pass
77
78
    def step(self, actions):
79
        self.step_async(actions)
80
        return self.step_wait()
81
82
    
83
class CloudpickleWrapper(object):
84
    """
85
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
86
    """
87
    def __init__(self, x):
88
        self.x = x
89
    def __getstate__(self):
90
        import cloudpickle
91
        return cloudpickle.dumps(self.x)
92
    def __setstate__(self, ob):
93
        import pickle
94
        self.x = pickle.loads(ob)
95
96
        
97
class SubprocVecEnv(VecEnv):
98
    def __init__(self, env_fns, spaces=None):
99
        """
100
        envs: list of gym environments to run in subprocesses
101
        """
102
        self.waiting = False
103
        self.closed = False
104
        nenvs = len(env_fns)
105
        self.nenvs = nenvs
106
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
107
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
108
            for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
109
        for p in self.ps:
110
            p.daemon = True # if the main process crashes, we should not cause things to hang
111
            p.start()
112
        for remote in self.work_remotes:
113
            remote.close()
114
115
        self.remotes[0].send(('get_spaces', None))
116
        observation_space, action_space = self.remotes[0].recv()
117
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)
118
119
    def step_async(self, actions):
120
        for remote, action in zip(self.remotes, actions):
121
            remote.send(('step', action))
122
        self.waiting = True
123
124
    def step_wait(self):
125
        results = [remote.recv() for remote in self.remotes]
126
        self.waiting = False
127
        obs, rews, dones, infos = zip(*results)
128
        return np.stack(obs), np.stack(rews), np.stack(dones), infos
129
130
    def reset(self):
131
        for remote in self.remotes:
132
            remote.send(('reset', None))
133
        return np.stack([remote.recv() for remote in self.remotes])
134
135
    def reset_task(self):
136
        for remote in self.remotes:
137
            remote.send(('reset_task', None))
138
        return np.stack([remote.recv() for remote in self.remotes])
139
140
    def close(self):
141
        if self.closed:
142
            return
143
        if self.waiting:
144
            for remote in self.remotes:            
145
                remote.recv()
146
        for remote in self.remotes:
147
            remote.send(('close', None))
148
        for p in self.ps:
149
            p.join()
150
            self.closed = True
151
            
152
    def __len__(self):
153
        return self.nenvs