# baselines/a2c/a2c.py
1
import os.path as osp
2
import time
3
import joblib
4
import numpy as np
5
import tensorflow as tf
6
from baselines import logger
7
8
from baselines.common import set_global_seeds, explained_variance
9
from baselines.common.runners import AbstractEnvRunner
10
from baselines.common import tf_util
11
12
from baselines.a2c.utils import discount_with_dones
13
from baselines.a2c.utils import Scheduler, make_path, find_trainable_variables
14
from baselines.a2c.utils import cat_entropy, mse
15
16
class Model(object):
    """A2C model: builds the TensorFlow training graph and exposes it as callables.

    After construction the instance carries:

    * ``train(obs, states, rewards, masks, actions, values)`` -- runs one
      optimizer step and returns ``(policy_loss, value_loss, policy_entropy)``.
    * ``step`` / ``value`` / ``initial_state`` -- forwarded from the acting
      policy instance.
    * ``save(path)`` / ``load(path)`` -- dump / restore the trainable
      variables with joblib.
    * ``train_model`` / ``step_model`` -- the two policy instances.
    """

    def __init__(self, policy, ob_space, ac_space, nenvs, nsteps,
            ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
            alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):
        """Build the graph, the optimizer and the helper closures.

        Args:
            policy: callable invoked as
                ``policy(sess, ob_space, ac_space, nbatch, nsteps, reuse=...)``.
            ob_space: observation space handed to ``policy``.
            ac_space: action space handed to ``policy``.
            nenvs: number of parallel environments.
            nsteps: rollout length per environment.
            ent_coef: weight of the entropy term (subtracted from the loss).
            vf_coef: weight of the value-function loss.
            max_grad_norm: global-norm clipping threshold; ``None`` disables
                clipping.
            lr: initial learning rate given to the ``Scheduler``.
            alpha: RMSProp ``decay``.
            epsilon: RMSProp ``epsilon``.
            total_timesteps: ``nvalues`` given to the ``Scheduler``.
            lrschedule: schedule name given to the ``Scheduler``.
        """
        # NOTE(review): the order in which graph ops are created below is kept
        # verbatim; confirm seeding implications before reordering anything.
        sess = tf_util.make_session()
        nbatch = nenvs * nsteps

        # Feed points: taken actions, advantages, returns, and learning rate.
        actions_ph = tf.placeholder(tf.int32, [nbatch])
        advantages_ph = tf.placeholder(tf.float32, [nbatch])
        returns_ph = tf.placeholder(tf.float32, [nbatch])
        lr_ph = tf.placeholder(tf.float32, [])

        # Acting instance (one step per env) and training instance (whole
        # rollout batch). The second is built with reuse=True, presumably so
        # both share the same variables -- depends on the policy class.
        step_model = policy(sess, ob_space, ac_space, nenvs, 1, reuse=False)
        train_model = policy(sess, ob_space, ac_space, nbatch, nsteps, reuse=True)

        # Loss = policy-gradient loss - ent_coef * entropy + vf_coef * value loss.
        neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=train_model.pi, labels=actions_ph)
        pg_loss = tf.reduce_mean(advantages_ph * neglogpac)
        vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.vf), returns_ph))
        entropy = tf.reduce_mean(cat_entropy(train_model.pi))
        entropy_bonus = entropy * ent_coef
        loss = pg_loss - entropy_bonus + vf_loss * vf_coef

        params = find_trainable_variables("model")
        grads = tf.gradients(loss, params)
        if max_grad_norm is not None:
            grads, _global_norm = tf.clip_by_global_norm(grads, max_grad_norm)
        grads_and_vars = list(zip(grads, params))
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate=lr_ph, decay=alpha, epsilon=epsilon)
        _train = optimizer.apply_gradients(grads_and_vars)

        lr_schedule = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

        def train(obs, states, rewards, masks, actions, values):
            """Run one gradient step on a rollout batch.

            ``rewards`` is fed both as the value target and, minus ``values``,
            as the advantage. ``states``/``masks`` are only fed when
            ``states`` is not None.

            Returns:
                ``(policy_loss, value_loss, policy_entropy)`` from the session.
            """
            advantages = rewards - values
            # Advance the schedule once per sample; the last value is used.
            for _ in range(len(obs)):
                cur_lr = lr_schedule.value()
            feed = {
                train_model.X: obs,
                actions_ph: actions,
                advantages_ph: advantages,
                returns_ph: rewards,
                lr_ph: cur_lr,
            }
            if states is not None:
                feed[train_model.S] = states
                feed[train_model.M] = masks
            fetched = sess.run([pg_loss, vf_loss, entropy, _train], feed)
            policy_loss, value_loss, policy_entropy = fetched[:3]
            return policy_loss, value_loss, policy_entropy

        def save(save_path):
            """Evaluate the trainable variables and dump them to ``save_path``."""
            values = sess.run(params)
            make_path(osp.dirname(save_path))
            joblib.dump(values, save_path)

        def load(load_path):
            """Assign variables from a file written by ``save``.

            Values are matched to variables by position.
            """
            # NOTE(review): joblib.load unpickles data; only load trusted files.
            stored = joblib.load(load_path)
            assign_ops = [var.assign(val) for var, val in zip(params, stored)]
            sess.run(assign_ops)

        self.train = train
        self.train_model = train_model
        self.step_model = step_model
        self.step = step_model.step
        self.value = step_model.value
        self.initial_state = step_model.initial_state
        self.save = save
        self.load = load
        tf.global_variables_initializer().run(session=sess)
84
85
class Runner(AbstractEnvRunner):
    """Collect fixed-length rollouts and convert rewards to discounted returns.

    Relies on attributes that are not set in this class (``env``, ``model``,
    ``nsteps``, ``obs``, ``states``, ``dones``, ``batch_ob_shape``); they are
    presumably initialised by ``AbstractEnvRunner.__init__`` -- confirm there.
    """

    def __init__(self, env, model, nsteps=5, gamma=0.99):
        """Store the discount factor and defer the rest to the base class.

        Args:
            env: environment forwarded to the base runner.
            model: model forwarded to the base runner; ``run`` calls its
                ``step`` and ``value`` methods.
            nsteps: number of environment steps per rollout.
            gamma: discount factor passed to ``discount_with_dones``.
        """
        super().__init__(env=env, model=model, nsteps=nsteps)
        self.gamma = gamma

    def run(self):
        """Step the environments ``nsteps`` times and build one training batch.

        Returns:
            A tuple ``(mb_obs, mb_states, mb_rewards, mb_masks, mb_actions,
            mb_values)``. ``mb_obs`` is reshaped to ``self.batch_ob_shape``;
            ``mb_rewards`` (discounted returns), ``mb_masks`` (the done flags
            seen *before* each step), ``mb_actions`` and ``mb_values`` are
            flattened env-major. ``mb_states`` is ``self.states`` as it was
            when the rollout started.
        """
        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [], [], [], [], []
        mb_states = self.states
        for _step in range(self.nsteps):
            actions, values, states, _ = self.model.step(self.obs, self.states, self.dones)
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_values.append(values)
            mb_dones.append(self.dones)
            obs, rewards, dones, _ = self.env.step(actions)
            self.states = states
            self.dones = dones
            # Zero the old observation rows of finished envs in place.
            # NOTE(review): self.obs is rebound right below, so this only
            # touches the previous array; kept for behavioural parity.
            for env_idx, done in enumerate(dones):
                if done:
                    self.obs[env_idx] = self.obs[env_idx] * 0
            self.obs = obs
            mb_rewards.append(rewards)
        # One extra entry so dones can be split into "before" and "after" views.
        mb_dones.append(self.dones)

        # Batch of steps -> batch of rollouts (swap step and env axes).
        mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0).reshape(self.batch_ob_shape)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
        mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0)
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        mb_dones = np.asarray(mb_dones, dtype=bool).swapaxes(1, 0)
        mb_masks = mb_dones[:, :-1]
        mb_dones = mb_dones[:, 1:]
        last_values = self.model.value(self.obs, self.states, self.dones).tolist()

        # Discount rewards; bootstrap from the value estimate when the rollout
        # of an env did not end on a terminal step.
        for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
            rewards = rewards.tolist()
            dones = dones.tolist()
            if dones[-1] == 0:
                # Append the bootstrap value, discount, then drop it again.
                rewards = discount_with_dones(rewards + [value], dones + [0], self.gamma)[:-1]
            else:
                rewards = discount_with_dones(rewards, dones, self.gamma)
            mb_rewards[n] = rewards

        mb_rewards = mb_rewards.flatten()
        mb_actions = mb_actions.flatten()
        mb_values = mb_values.flatten()
        mb_masks = mb_masks.flatten()
        return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values
132
133
def learn(policy, env, seed, nsteps=5, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5, lr=7e-4, lrschedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=100):
    """Train an A2C model on ``env`` and close the environment when done.

    Args:
        policy: policy constructor forwarded to ``Model``.
        env: environment exposing ``num_envs``, ``observation_space``,
            ``action_space`` and ``close()``; it is handed to ``Runner``.
        seed: value passed to ``set_global_seeds``.
        nsteps: rollout length per environment for each update.
        total_timesteps: total sample budget; the number of updates is
            ``total_timesteps // (env.num_envs * nsteps)``.
        vf_coef, ent_coef, max_grad_norm, lr, lrschedule, epsilon, alpha:
            forwarded unchanged to ``Model``.
        gamma: discount factor forwarded to ``Runner``.
        log_interval: tabular stats are logged on the first update and on
            every update whose index is a multiple of this value.

    Returns:
        None. ``env.close()`` is always called, including when model
        construction or training raises.
    """
    tf.reset_default_graph()
    set_global_seeds(seed)

    # The function takes ownership of ``env``; release it on error as well.
    try:
        nenvs = env.num_envs
        ob_space = env.observation_space
        ac_space = env.action_space
        model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
            max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
        runner = Runner(env, model, nsteps=nsteps, gamma=gamma)

        nbatch = nenvs*nsteps
        tstart = time.time()
        for update in range(1, total_timesteps//nbatch+1):
            obs, states, rewards, masks, actions, values = runner.run()
            policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
            if update % log_interval == 0 or update == 1:
                # Throughput is only needed when logging; guard against a
                # zero elapsed time on coarse clocks.
                nseconds = time.time()-tstart
                fps = int((update*nbatch)/nseconds) if nseconds > 0 else 0
                ev = explained_variance(values, rewards)
                logger.record_tabular("nupdates", update)
                logger.record_tabular("total_timesteps", update*nbatch)
                logger.record_tabular("fps", fps)
                logger.record_tabular("policy_entropy", float(policy_entropy))
                logger.record_tabular("value_loss", float(value_loss))
                logger.record_tabular("explained_variance", float(ev))
                logger.dump_tabular()
    finally:
        env.close()