--- /dev/null
+++ b/nips/ppo_agent.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+
+import baselines.ppo2.ppo2 as ppo2
+import baselines.ppo2.policies as policies
+from nips.custom_env import make_env
+from nips.remote_vec_env import RemoteVecEnv
+from baselines.logger import configure
+from baselines.common import set_global_seeds
+from baselines.common.vec_env.vec_normalize import VecNormalize
+
+import tensorflow as tf
+import ray
+
+
+def train(num_timesteps, seed):
+    # Total env count: one base worker plus `num_casks` extra "cask"
+    # envs; the casks are subtracted back out of nminibatches below.
+    num_cpus = 1
+    num_casks = 1
+    num_cpus += num_casks
+
+    config = tf.ConfigProto(allow_soft_placement=True,
+                            intra_op_parallelism_threads=num_cpus,
+                            inter_op_parallelism_threads=num_cpus)
+    tf.Session(config=config).__enter__()
+
+    gamma = 0.995
+
+    # Fan the envs out to Ray workers, then normalize returns with the
+    # same discount factor used by PPO.
+    env = RemoteVecEnv([make_env] * num_cpus)
+    env = VecNormalize(env, ret=True, gamma=gamma)
+
+    set_global_seeds(seed)
+    policy = policies.MlpPolicy
+    ppo2.learn(policy=policy,
+               env=env,
+               nsteps=128,
+               nminibatches=num_cpus - num_casks,
+               lam=0.95,
+               gamma=gamma,
+               noptepochs=4,
+               log_interval=1,
+               vf_coef=0.5,
+               ent_coef=0.0,
+               lr=3e-4,
+               cliprange=0.2,
+               save_interval=2,
+               load_path="./logs/course_6/00244",
+               total_timesteps=num_timesteps,
+               num_casks=num_casks)
+
+
+if __name__ == "__main__":
+    ray.init()
+    configure(dir="./logs")
+    train(int(1e6), 60730)
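
For reference, nips/remote_vec_env.py itself is not part of this diff. The sketch below shows one way a Ray-backed RemoteVecEnv could be built against the baselines VecEnv interface, and why ray.init() has to run before train() constructs the envs. The _EnvWorker actor and all internal details are illustrative assumptions, not the actual module; only the Ray actor API and the baselines VecEnv interface are real.

import numpy as np
import ray
from baselines.common.vec_env import VecEnv


@ray.remote
class _EnvWorker(object):
    """Hypothetical Ray actor owning a single environment instance."""

    def __init__(self, env_fn):
        self.env = env_fn()

    def reset(self):
        return self.env.reset()

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        if done:
            # Auto-reset on episode end, as baselines' subprocess
            # workers do, so the vectorized rollout never stalls.
            obs = self.env.reset()
        return obs, rew, done, info


class RemoteVecEnv(VecEnv):
    """Hypothetical VecEnv that steps each env in its own Ray actor."""

    def __init__(self, env_fns):
        self.workers = [_EnvWorker.remote(fn) for fn in env_fns]
        probe = env_fns[0]()  # build one env locally just to read the spaces
        VecEnv.__init__(self, len(env_fns),
                        probe.observation_space, probe.action_space)
        probe.close()
        self._futures = None

    def reset(self):
        return np.stack(ray.get([w.reset.remote() for w in self.workers]))

    def step_async(self, actions):
        self._futures = [w.step.remote(a)
                         for w, a in zip(self.workers, actions)]

    def step_wait(self):
        obs, rews, dones, infos = zip(*ray.get(self._futures))
        return np.stack(obs), np.array(rews), np.array(dones), list(infos)

    def close(self):
        pass  # Ray reclaims the actors when the driver exits

Note that num_casks is not a keyword argument of upstream baselines' ppo2.learn (and load_path may not be either, depending on the baselines revision); these presumably come from this repo's fork of ppo2, so the script must run against that fork rather than stock baselines.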