Diff of /submission/submit.py [000000] .. [f9c9f2]

Switch to unified view

a b/submission/submit.py
1
from submission.baselines.common.vec_env.vec_normalize import VecNormalize
2
from submission.baselines.common.vec_env.dummy_vec_env import DummyVecEnv
3
import submission.baselines.ppo2.ppo2 as ppo2
4
import submission.baselines.ppo2.policies as policies
5
from submission.submit_env import make_submit_env
6
from submission.baselines.logger import configure
7
from submission.baselines.common import set_global_seeds
8
9
import tensorflow as tf
10
11
12
def submit(num_timesteps, seed):
13
    config = tf.ConfigProto()
14
    config.gpu_options.allow_growth = True
15
16
    set_global_seeds(seed)
17
    with tf.Session(config=config):
18
        env = DummyVecEnv([make_submit_env])
19
        env = VecNormalize(env)
20
        ppo2.learn(policy=policies.MlpPolicy,
21
                   env=env,
22
                   nsteps=10000,
23
                   nminibatches=16,
24
                   lam=0.95,
25
                   gamma=0.99,
26
                   noptepochs=5,
27
                   log_interval=1,
28
                   vf_coef=0.5,
29
                   ent_coef=0.0,
30
                   lr=3e-4,
31
                   cliprange=0.2,
32
                   save_interval=10,
33
                   load_path="./logs/course_7/00036",
34
                   total_timesteps=num_timesteps
35
                   )
36
37
if __name__ == '__main__':
38
    configure(dir="./logs")
39
    submit(int(1e6), 60730)