|
a |
|
b/drl/multiprocessing_env.py |
|
|
1 |
# This code is from OpenAI Baselines.
|
|
2 |
#https://github.com/openai/baselines/tree/master/baselines/common/vec_env |
|
|
3 |
|
|
|
4 |
import numpy as np |
|
|
5 |
from multiprocessing import Process, Pipe |
|
|
6 |
|
|
|
7 |
def worker(remote, parent_remote, env_fn_wrapper):
    """
    Command loop executed inside each environment subprocess.

    remote: this process's end of the pipe; (cmd, data) tuples are received
        from it and replies are sent back through it.
    parent_remote: the other end of the pipe; it is closed immediately
        because this process never uses it.
    env_fn_wrapper: object whose ``x`` attribute is a zero-argument callable
        that builds the environment.

    Supported commands:
    - 'step': call env.step(data); if the episode is done the environment is
      reset and the fresh observation replaces the terminal one. Replies
      with (ob, reward, done, info).
    - 'reset': replies with env.reset().
    - 'reset_task': replies with env.reset_task().
    - 'get_spaces': replies with (env.observation_space, env.action_space).
    - 'close': closes ``remote`` and leaves the loop (no reply).

    Raises NotImplementedError for any other command.
    """
    parent_remote.close()
    env = env_fn_wrapper.x()
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == 'step':
                ob, reward, done, info = env.step(data)
                if done:
                    # Auto-reset so the caller always gets a usable observation.
                    ob = env.reset()
                remote.send((ob, reward, done, info))
            elif cmd == 'reset':
                ob = env.reset()
                remote.send(ob)
            elif cmd == 'reset_task':
                ob = env.reset_task()
                remote.send(ob)
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'get_spaces':
                remote.send((env.observation_space, env.action_space))
            else:
                raise NotImplementedError('unknown command: {!r}'.format(cmd))
    finally:
        # Release the environment's resources however the loop ends.
        # Environments without a close() method are left untouched.
        close_env = getattr(env, 'close', None)
        if callable(close_env):
            close_env()
|
|
30 |
|
|
|
31 |
class VecEnv(object):
    """
    Base class for an asynchronous, vectorized environment.

    Every method except step() is a no-op here; concrete subclasses are
    expected to override them.
    """

    def __init__(self, num_envs, observation_space, action_space):
        """Record the number of environments and their shared spaces."""
        self.num_envs, self.observation_space, self.action_space = (
            num_envs,
            observation_space,
            action_space,
        )

    def reset(self):
        """
        Reset every environment.

        Intended to return an array of observations, or a tuple of
        observation arrays. Work started by step_async() that is still in
        flight is cancelled, and step_wait() should not be called again
        until step_async() has been invoked once more.
        """
        return None

    def step_async(self, actions):
        """
        Ask every environment to begin a step with the given actions.

        The results are collected with step_wait(). Do not call this while
        an earlier step_async() is still pending.
        """
        return None

    def step_wait(self):
        """
        Block until the step started by step_async() has finished.

        Intended to return (obs, rews, dones, infos):
         - obs: an array of observations, or a tuple of arrays of
           observations.
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects
        """
        return None

    def close(self):
        """Release the resources held by the environments."""
        return None

    def step(self, actions):
        """Synchronous step: step_async(actions) followed by step_wait()."""
        self.step_async(actions)
        outcome = self.step_wait()
        return outcome
|
|
81 |
|
|
|
82 |
|
|
|
83 |
class CloudpickleWrapper(object):
    """
    Container that serializes its payload with cloudpickle
    (otherwise multiprocessing tries to use pickle).
    """

    def __init__(self, x):
        """Store the wrapped object on the ``x`` attribute."""
        self.x = x

    def __getstate__(self):
        """Return the payload serialized with cloudpickle.dumps."""
        # Imported lazily so cloudpickle is only needed when pickling happens.
        from cloudpickle import dumps
        return dumps(self.x)

    def __setstate__(self, ob):
        """Rebuild the payload from the bytes produced by __getstate__."""
        from pickle import loads
        restored = loads(ob)
        self.x = restored
|
|
95 |
|
|
|
96 |
|
|
|
97 |
class SubprocVecEnv(VecEnv):
    """
    Vectorized environment that runs each environment in its own subprocess
    and talks to it over a pipe using the command protocol of worker().
    """

    def __init__(self, env_fns, spaces=None):
        """
        env_fns: non-empty list of zero-argument callables, each building one
            gym environment inside its own subprocess
        spaces: currently unused by this implementation; the spaces are
            always queried from the first environment

        Raises ValueError if env_fns is empty.
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        if nenvs == 0:
            # Previously surfaced as an opaque tuple-unpacking ValueError.
            raise ValueError('env_fns must contain at least one environment factory')
        self.nenvs = nenvs
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [
            Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
            for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)
        ]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        # The parent only uses its own ends of the pipes.
        for remote in self.work_remotes:
            remote.close()

        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def _drain_pending_step(self):
        """Consume the replies of an unfinished step_async() so the pipes are back in sync."""
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
            self.waiting = False

    def step_async(self, actions):
        """Send one 'step' command per environment, pairing remotes with actions."""
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        """
        Collect the step replies from every subprocess.

        Returns (obs, rews, dones, infos) where the first three are stacked
        with np.stack and infos is a tuple with one entry per environment.
        """
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """
        Reset every environment and return the stacked observations.

        Replies of a step_async() that was never waited on are discarded
        first, so they cannot be mistaken for reset observations.
        """
        self._drain_pending_step()
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def reset_task(self):
        """
        Send 'reset_task' to every environment and return the stacked replies.

        Pending step replies are discarded first, as in reset().
        """
        self._drain_pending_step()
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        """
        Tell every subprocess to exit and wait for them to finish.

        Does nothing if already closed. Pending step replies are read first
        so the workers are not left blocked in the middle of a step.
        """
        if self.closed:
            return
        self._drain_pending_step()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True

    def __len__(self):
        """Return the number of environments."""
        return self.nenvs