[f9c9f2]: / submission / submit.py

Download this file

40 lines (34 with data), 1.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from submission.baselines.common.vec_env.vec_normalize import VecNormalize
from submission.baselines.common.vec_env.dummy_vec_env import DummyVecEnv
import submission.baselines.ppo2.ppo2 as ppo2
import submission.baselines.ppo2.policies as policies
from submission.submit_env import make_submit_env
from submission.baselines.logger import configure
from submission.baselines.common import set_global_seeds
import tensorflow as tf
def submit(num_timesteps, seed, load_path="./logs/course_7/00036"):
    """Train a PPO2 MLP policy on the submission environment.

    Seeds the global RNGs, wraps ``make_submit_env`` in a single-env
    ``DummyVecEnv`` plus ``VecNormalize``, and runs ``ppo2.learn`` with a
    fixed set of hyperparameters. Training runs inside a TensorFlow session
    whose GPU options enable ``allow_growth``.

    Args:
        num_timesteps: Forwarded to ``ppo2.learn`` as ``total_timesteps``.
        seed: Passed to ``set_global_seeds`` before the session is created.
        load_path: Forwarded to ``ppo2.learn`` as ``load_path``. It defaults
            to the checkpoint that was previously hard-coded here. Presumably
            ``None`` means "start from scratch", as in upstream baselines;
            TODO confirm against this vendored copy of ppo2.
    """
    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True
    set_global_seeds(seed)
    with tf.Session(config=config):
        env = VecNormalize(DummyVecEnv([make_submit_env]))
        try:
            ppo2.learn(
                policy=policies.MlpPolicy,
                env=env,
                # NOTE(review): nsteps * n_envs (10000 * 1) divides evenly by
                # nminibatches (16); keep that true if either value changes.
                nsteps=10000,
                nminibatches=16,
                lam=0.95,
                gamma=0.99,
                noptepochs=5,
                log_interval=1,
                vf_coef=0.5,
                ent_coef=0.0,
                lr=3e-4,
                cliprange=0.2,
                save_interval=10,
                load_path=load_path,
                total_timesteps=num_timesteps,
            )
        finally:
            # Release the environment even if training raises. ppo2.learn may
            # already close it on success; upstream VecEnv.close tolerates a
            # second call. TODO confirm for this vendored copy.
            env.close()
if __name__ == '__main__':
    # Point the baselines logger at ./logs before training starts.
    configure(dir="./logs")
    # One million environment steps with a fixed seed for reproducibility.
    submit(num_timesteps=10 ** 6, seed=60730)