--- /dev/null
+++ b/nips/round2_train.py
@@ -0,0 +1,95 @@
+#!/usr/bin/env python
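+"""PPO training for the NIPS 2018 AI for Prosthetics Challenge, Round 2.
+
+Example invocation (flag values are illustrative, not tuned):
+    python -m nips.round2_train --num-cpus 16 --num-steps 128 \
+        --num-timesteps 5000000 --log-dir ./logs/round2
+"""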
+
+import argparse
+import ray
+import tensorflow as tf
+from baselines.logger import configure
+from baselines.common import set_global_seeds
+from baselines.common.vec_env.vec_normalize import VecNormalize
+import baselines.ppo2.policies as policies
+import baselines.ppo2.ppo2 as ppo2
+from nips.round2_env import CustomEnv, CustomActionWrapper
+from nips.remote_vec_env import RemoteVecEnv
+
+
+def create_env():
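+    # NOTE: `args` is read from module scope (set in __main__ below). This
+    # thunk is assumed to be shipped to the Ray workers via cloudpickle,
+    # which captures the referenced `args` object along with the function.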
+    env = CustomEnv(visualization=args.vis, integrator_accuracy=args.accuracy)
+    env = CustomActionWrapper(env, action_repeat=args.repeat)
+    return env
+
+
+def train():
+    config = tf.ConfigProto(
+        allow_soft_placement=True,
+        intra_op_parallelism_threads=args.num_cpus,
+        inter_op_parallelism_threads=args.num_cpus
+    )
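+    # Entering the session installs it as the default, so ppo2 builds and
+    # runs its graph in it without an explicit `with` block.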
+    tf.Session(config=config).__enter__()
+
+    env = RemoteVecEnv([create_env] * args.num_cpus)
+    env = VecNormalize(env, ret=True, gamma=args.gamma)
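+    # VecNormalize maintains running statistics to normalize observations
+    # and, with ret=True, the discounted returns used as value targets.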
+
+    ppo2.learn(
+        policy=policies.MlpPolicy, env=env,
+        total_timesteps=args.num_timesteps, nminibatches=args.num_minibatches,
+        nsteps=args.num_steps, noptepochs=args.num_epochs, lr=args.learning_rate,
+        gamma=args.gamma, lam=args.lam, ent_coef=args.ent_coef,
+        vf_coef=args.vf_coef, cliprange=args.clip_range,
+        log_interval=args.log_interval, save_interval=args.save_interval,
+        load_path=args.checkpoint_path,
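+        # num_casks is not an upstream baselines argument; this call assumes
+        # the patched ppo2 shipped with this repo.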
+        num_casks=args.num_casks
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='PPO training for NIPS 2018: AI for Prosthetics Challenge, Round 2')
+    parser.add_argument('--num-timesteps', default=int(1e6), type=int, help='total number of environment timesteps')
+    # batch size per update = num_steps * num_cpus (one env per CPU)
+    parser.add_argument('--num-steps', default=128, type=int, help='number of steps per update')
+    # must evenly divide the batch size (num_steps * num_cpus)
+    parser.add_argument('--num-minibatches', default=1, type=int, help='number of training minibatches per update')
+    parser.add_argument('--num-epochs', default=4, type=int, help='number of training epochs per update')
+    parser.add_argument('--learning-rate', default=3e-4, type=float, help='learning rate')
+    # RL domain
+    parser.add_argument('--gamma', default=0.99, type=float, help='discount factor')
+    # PPO specific
+    parser.add_argument('--lam', default=0.95, type=float, help='GAE lambda (advantage estimation discount factor)')
+    parser.add_argument('--ent-coef', default=0.001, type=float, help='policy entropy coefficient')
+    parser.add_argument('--vf-coef', default=0.5, type=float, help='value function loss coefficient')
+    parser.add_argument('--clip-range', default=0.2, type=float, help='clipping range')
+    # env related
+    parser.add_argument('--seed', default=6730, type=int, help='random seed')
+    parser.add_argument('--accuracy', default=5e-5, type=float, help='simulator integrator accuracy')
+    parser.add_argument('--repeat', default=1, type=int, help='number of times each action is repeated')
+    parser.add_argument('--vis', action='store_true', help='enable simulator visualization')
+    # training settings
+    parser.add_argument('--num-cpus', default=1, type=int, help='number of CPUs (one env per CPU)')
+    parser.add_argument('--num-casks', default=0, type=int, help='number of casks, for acceleration')
+    parser.add_argument('--num-gpus', default=0, type=int, help='number of GPUs')
+    parser.add_argument('--log-dir', default='./logs', type=str, help='logging events output directory')
+    parser.add_argument('--log-interval', default=1, type=int, help='number of updates between logging events')
+    parser.add_argument('--save-interval', default=1, type=int, help='number of updates between checkpoint saves')
+    parser.add_argument('--checkpoint-path', default=None, type=str, help='path to load the model checkpoint from')
+    args = parser.parse_args()
+    print(args)
+
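+    # Ray must be initialized before RemoteVecEnv can create its remote env workers.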
+    ray.init(num_cpus=args.num_cpus, num_gpus=args.num_gpus)
+    set_global_seeds(args.seed)
+    configure(dir=args.log_dir)
+    train()