[f9c9f2]: / baselines / common / cmd_util.py

Download this file

88 lines (79 with data), 3.0 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Helpers for scripts like run_atari.py.
"""
import os
import gym
from gym.wrappers import FlattenDictWrapper
from baselines import logger
from baselines.bench import Monitor
from baselines.common import set_global_seeds
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
    """
    Build a SubprocVecEnv made of `num_env` monitored, wrapped Atari envs.

    Each worker gets a rank of `i + start_index` for i in range(num_env).
    The worker's env is created with make_atari, seeded with `seed + rank`,
    wrapped in a Monitor and finally passed through wrap_deepmind together
    with `wrapper_kwargs` (an empty dict is used when None is given).
    Global seeds are set from `seed` before the vectorized env is created.
    """
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    def make_env(rank):  # pylint: disable=C0111
        # Returns a zero-argument callable; the env itself is only built
        # when that callable is invoked.
        def _thunk():
            atari_env = make_atari(env_id)
            atari_env.seed(seed + rank)
            # When the logger has no directory the falsy value is handed to
            # Monitor unchanged; otherwise each rank gets its own sub-path.
            monitor_path = logger.get_dir() and os.path.join(logger.get_dir(), str(rank))
            monitored = Monitor(atari_env, monitor_path)
            return wrap_deepmind(monitored, **wrapper_kwargs)
        return _thunk

    set_global_seeds(seed)
    env_fns = []
    for offset in range(num_env):
        env_fns.append(make_env(offset + start_index))
    return SubprocVecEnv(env_fns)
def make_mujoco_env(env_id, seed):
    """
    Build a single monitored gym.Env for a MuJoCo task.

    Sets the global seeds from `seed`, creates the env with gym.make, wraps
    it in a Monitor that receives logger.get_dir(), seeds the wrapped env
    with `seed` and returns it.
    """
    set_global_seeds(seed)
    monitored = Monitor(gym.make(env_id), logger.get_dir())
    monitored.seed(seed)
    return monitored
def make_robotics_env(env_id, seed, rank=0):
    """
    Build a single flattened, monitored gym.Env for a robotics task.

    Sets the global seeds from `seed`, creates the env with gym.make and
    wraps it in FlattenDictWrapper over the 'observation' and 'desired_goal'
    keys. The result is wrapped in a Monitor that also records the
    'is_success' info keyword; `rank` only selects the Monitor path below the
    logger directory (when the logger has one). The env is seeded with
    `seed` before being returned.
    """
    set_global_seeds(seed)
    flat_env = FlattenDictWrapper(gym.make(env_id), ['observation', 'desired_goal'])
    # A falsy logger directory is forwarded to Monitor unchanged.
    monitor_path = logger.get_dir() and os.path.join(logger.get_dir(), str(rank))
    monitored = Monitor(flat_env, monitor_path, info_keywords=('is_success',))
    monitored.seed(seed)
    return monitored
def arg_parser():
    """
    Return a fresh argparse.ArgumentParser with no arguments registered.

    The parser uses ArgumentDefaultsHelpFormatter, so default values are
    appended to the help text of arguments that define a help string.
    """
    # Imported locally, as in the original, so argparse is only loaded
    # when a parser is actually requested.
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
    help_formatter = ArgumentDefaultsHelpFormatter
    return ArgumentParser(formatter_class=help_formatter)
def atari_arg_parser():
    """
    Return an argparse.ArgumentParser preloaded with the run_atari.py options.

    Registered options: --env (default 'BreakoutNoFrameskip-v4'),
    --seed (int, default 0) and --num-timesteps (int, default int(10e6)).
    """
    option_table = (
        ('--env', dict(help='environment ID', default='BreakoutNoFrameskip-v4')),
        ('--seed', dict(help='RNG seed', type=int, default=0)),
        ('--num-timesteps', dict(type=int, default=int(10e6))),
    )
    cli = arg_parser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli
def mujoco_arg_parser():
    """
    Return an argparse.ArgumentParser preloaded with the run_mujoco.py options.

    Registered options: --env (str, default 'Reacher-v2'),
    --seed (int, default 0) and --num-timesteps (int, default int(1e6)).
    """
    option_table = (
        ('--env', dict(help='environment ID', type=str, default='Reacher-v2')),
        ('--seed', dict(help='RNG seed', type=int, default=0)),
        ('--num-timesteps', dict(type=int, default=int(1e6))),
    )
    cli = arg_parser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli
def robotics_arg_parser():
    """
    Return an argparse.ArgumentParser preloaded with options for robotics runs.

    Registered options: --env (str, default 'FetchReach-v0'),
    --seed (int, default 0) and --num-timesteps (int, default int(1e6)).
    """
    option_table = (
        ('--env', dict(help='environment ID', type=str, default='FetchReach-v0')),
        ('--seed', dict(help='RNG seed', type=int, default=0)),
        ('--num-timesteps', dict(type=int, default=int(1e6))),
    )
    cli = arg_parser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli