[9f010e]: /mouse_scripts/SAC/sac.py


import os
import torch
import torch.nn.functional as F
from torch.optim import Adam, AdamW, RMSprop
from .utils1 import soft_update, hard_update
from .model import GaussianPolicyLSTM, GaussianPolicyRNN, QNetworkFF, QNetworkLSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence


class SAC(object):
    def __init__(self, num_inputs, action_space, args):
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha
        self.hidden_size = args.hidden_size
        self.automatic_entropy_tuning = args.automatic_entropy_tuning
        self.device = torch.device("cuda")
        # Target Entropy = −dim(A) (e.g., -6 for HalfCheetah-v2), as given in the paper
        if args.automatic_entropy_tuning:
            self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

    def select_action(self, state, h_prev, c_prev, evaluate=False):
        # Implemented by the recurrent subclasses below.
        pass

    def update_parameters(self, policy_memory, policy_batch_size):
        # Implemented by the recurrent subclasses below.
        pass

    # Save model parameters
    def save_model(self, env_name, suffix="", actor_path=None, critic_path=None):
        if not os.path.exists('models/'):
            os.makedirs('models/')
        if actor_path is None:
            actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path):
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.policy.load_state_dict(torch.load(actor_path))
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path))

    # filter_padded takes a padded batch of size (B, L_max, H) and the corresponding
    # sequence lengths, and returns a tensor of size (sum(seq_lens), H) with the
    # padding rows removed.
    def filter_padded(self, padded_seq, seq_lens):
        # padded_seq: tensor of size (batch_size, max_seq_len, input_dimension), i.e. (B, L_max, H)
        # seq_lens: list containing the length of each individual sequence before padding
        seq_max = max(seq_lens)
        # Reshape the padded batch to (B*L_max, input_dimension)
        t = padded_seq.reshape(padded_seq.shape[0] * padded_seq.shape[1], padded_seq.shape[2])
        iter_max = int(t.shape[0] / seq_max)
        for iter1 in range(iter_max):
            k = [item for item in range(iter1 * seq_max, (iter1 + 1) * seq_max)]
            k = k[:seq_lens[iter1]]
            if iter1 == 0:
                out_t = t[k]
            else:
                out_t = torch.cat((out_t, t[k]), dim=0)
        return out_t
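
    # Illustrative example (added here for clarity; the shapes are hypothetical and
    # not taken from the training pipeline): for a padded batch of shape
    # (B=2, L_max=3, H=4) with seq_lens = [3, 1],
    #
    #   padded = torch.arange(24.).reshape(2, 3, 4)
    #   agent.filter_padded(padded, [3, 1])
    #
    # keeps the 3 valid rows of the first sequence and the 1 valid row of the second,
    # returning a tensor of shape (4, 4), i.e. (sum(seq_lens), H).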


class SACRNN(SAC):
    def __init__(self, num_inputs, action_space, args):
        super(SACRNN, self).__init__(num_inputs, action_space, args)
        self.critic = QNetworkFF(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
        self.critic_target = QNetworkFF(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
        self.policy = GaussianPolicyRNN(num_inputs, action_space.shape[0], args.hidden_size, action_space=None).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, h_prev, c_prev, evaluate=False):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0).unsqueeze(0)
        h_prev = h_prev.to(self.device)
        c_prev = c_prev.to(self.device)
        if not evaluate:
            action, _, _, h_current, c_current, _, lstm_out = self.policy.sample(state, h_prev, c_prev, sampling=True, len_seq=None)
        else:
            _, _, action, h_current, c_current, _, lstm_out = self.policy.sample(state, h_prev, c_prev, sampling=True, len_seq=None)
        return action.detach().cpu().numpy()[0], h_current.detach(), c_current.detach(), lstm_out.detach().cpu().numpy()

    def update_parameters(self, policy_memory, policy_batch_size):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch, h_batch, c_batch, policy_state_batch = policy_memory.sample(batch_size=policy_batch_size)
        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
        h_batch = torch.FloatTensor(h_batch).to(self.device).permute(1, 0, 2)
        c_batch = torch.FloatTensor(c_batch).to(self.device)

        # Critic update
        with torch.no_grad():
            next_state_action, next_state_log_pi, _, _, _, _, _ = self.policy.sample(next_state_batch.unsqueeze(1), h_batch, c_batch, sampling=True)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * min_qf_next_target
        qf1, qf2 = self.critic(state_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(qf1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss
        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # Update the policy network using the newly proposed method
        h0 = torch.zeros(size=(1, len(policy_state_batch), self.hidden_size)).to(self.device)
        c0 = torch.zeros(size=(1, len(policy_state_batch), self.hidden_size)).to(self.device)
        len_seq = list(map(len, policy_state_batch))
        policy_state_batch = torch.FloatTensor(pad_sequence(policy_state_batch, batch_first=True)).to(self.device)
        pi_action_bat, log_prob_bat, _, _, _, mask_seq, _ = self.policy.sample(policy_state_batch, h0, c0, sampling=False, len_seq=len_seq)
        # Mask the padded policy_state_batch according to mask_seq
        policy_state_batch_pi = policy_state_batch.reshape(-1, policy_state_batch.size()[-1])[mask_seq]
        qf1_pi, qf2_pi = self.critic(policy_state_batch_pi, pi_action_bat)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        policy_loss = ((self.alpha * log_prob_bat) - min_qf_pi).mean()  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

        # Sample the hidden (recurrent) weights of the RNN; these are of size (hidden_dim, hidden_dim)
        J_lstm_w = self.policy.lstm.weight_hh_l0
        # Sample the output of the RNN for the policy_state_batch
        lstm_out_r, _ = self.policy.forward_for_simple_dynamics(policy_state_batch, h0, c0, sampling=False, len_seq=len_seq)
        lstm_out_r = lstm_out_r.reshape(-1, lstm_out_r.size()[-1])[mask_seq]
        # Repeat the recurrent weight matrix for every valid time step in the batch
        J_lstm_w = J_lstm_w.unsqueeze(0).repeat(lstm_out_r.size()[0], 1, 1)
        # Scale the recurrent weights by 1 - r^2 (the derivative of tanh for tanh units)
        # and penalize the squared norm of the result
        lstm_out_r = 1 - torch.pow(lstm_out_r, 2)
        R_j = torch.mul(J_lstm_w, lstm_out_r.unsqueeze(-1))
        policy_loss_2 = torch.norm(R_j) ** 2

        # Loss encouraging the minimization of the firing rates of the RNN layer,
        # the linear layer and the output (EMG) means
        lstm_out_r, linear_out = self.policy.forward_for_simple_dynamics(policy_state_batch, h0, c0, sampling=False, len_seq=len_seq)
        lstm_out_r = lstm_out_r.reshape(-1, lstm_out_r.size()[-1])[mask_seq]
        linear_out = linear_out.reshape(-1, linear_out.size()[-1])[mask_seq]
        mean_out_emg, _, _, _, _ = self.policy.forward(policy_state_batch, h0, c0, sampling=False, len_seq=len_seq)
        mean_out_emg = mean_out_emg.reshape(-1, mean_out_emg.size()[-1])[mask_seq]
        policy_loss_3 = torch.norm(lstm_out_r) ** 2 + torch.norm(linear_out) ** 2 + torch.norm(mean_out_emg) ** 2

        # Loss encouraging the minimization of the input and output weights of the RNN
        # and of the layers upstream and downstream of it
        J_lstm_i = self.policy.lstm.weight_ih_l0
        J_in1 = self.policy.linear1.weight
        # J_out = self.policy.linear2.weight
        J_out1 = self.policy.mean_linear.weight
        policy_loss_4 = torch.norm(J_in1) ** 2 + torch.norm(J_lstm_i) ** 2 + torch.norm(J_out1) ** 2

        policy_loss += (0.003 * policy_loss_2) + (0.001 * policy_loss_3) + (0.001 * policy_loss_4)
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        # Entropy temperature (alpha) update
        log_pi = log_prob_bat
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        soft_update(self.critic_target, self.critic, self.tau)
        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), policy_loss_2.item(), policy_loss_3.item(), policy_loss_4.item(), alpha_loss.item(), alpha_tlogs.item()


class SACLSTM(SAC):
    def __init__(self, num_inputs, action_space, args):
        super(SACLSTM, self).__init__(num_inputs, action_space, args)
        self.critic = QNetworkLSTM(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
        self.critic_target = QNetworkLSTM(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
        self.policy = GaussianPolicyLSTM(num_inputs, action_space.shape[0], args.hidden_size, action_space=None).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, h_prev, c_prev, evaluate=False):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0).unsqueeze(0)
        h_prev = h_prev.to(self.device)
        c_prev = c_prev.to(self.device)
        if not evaluate:
            action, _, _, h_current, c_current, lstm_out = self.policy.sample(state, h_prev, c_prev, sampling=True)
        else:
            _, _, action, h_current, c_current, lstm_out = self.policy.sample(state, h_prev, c_prev, sampling=True)
        return action.detach().cpu().numpy()[0], h_current.detach(), c_current.detach(), lstm_out.detach().cpu().numpy()

    def update_parameters(self, policy_memory, policy_batch_size):
        # Sample a batch from memory.
        # *_batch_p denotes a padded batch, i.e. a tensor of size (B, L_max, H);
        # *_batch denotes the corresponding packed batch.
        state_batch_0, action_batch_0, reward_batch_0, next_state_batch_0, mask_batch_0, hidden_in, hidden_out = policy_memory.sample(batch_size=policy_batch_size)
        seq_lengths = list(map(len, state_batch_0))
        state_batch_p = pad_sequence(state_batch_0, batch_first=True)
        action_batch_p = pad_sequence(action_batch_0, batch_first=True)
        reward_batch_p = pad_sequence(reward_batch_0, batch_first=True)
        next_state_batch_p = pad_sequence(next_state_batch_0, batch_first=True)
        mask_batch_p = pad_sequence(mask_batch_0, batch_first=True)
        state_batch_p = torch.FloatTensor(state_batch_p).to(self.device)
        next_state_batch_p = torch.FloatTensor(next_state_batch_p).to(self.device)
        action_batch_p = torch.FloatTensor(action_batch_p).to(self.device)
        reward_batch_p = torch.FloatTensor(reward_batch_p).to(self.device)
        mask_batch_p = torch.FloatTensor(mask_batch_p).to(self.device)
        hidden_in = (hidden_in[0].to(self.device), hidden_in[1].to(self.device))
        hidden_out = (hidden_out[0].to(self.device), hidden_out[1].to(self.device))
        state_batch = pack_padded_sequence(state_batch_p, seq_lengths, batch_first=True, enforce_sorted=False)
        next_state_batch = pack_padded_sequence(next_state_batch_p, seq_lengths, batch_first=True, enforce_sorted=False)
        action_batch = pack_padded_sequence(action_batch_p, seq_lengths, batch_first=True, enforce_sorted=False)
        reward_batch_pack = pack_padded_sequence(reward_batch_p, seq_lengths, batch_first=True, enforce_sorted=False)
        mask_batch_pack = pack_padded_sequence(mask_batch_p, seq_lengths, batch_first=True, enforce_sorted=False)
        reward_batch = self.filter_padded(reward_batch_p, seq_lengths)
        mask_batch = self.filter_padded(mask_batch_p, seq_lengths)
        # From here on we have padded batches of state, action, reward, next_state and
        # mask, together with the corresponding sequence lengths seq_lengths.

        # Critic update
        with torch.no_grad():
            next_state_action_p, next_state_log_pi_p, _, _, _, _ = self.policy.sample(next_state_batch, h_prev=hidden_out[0], c_prev=hidden_out[1], sampling=False)
            next_state_state_action_p = torch.cat((next_state_batch_p, next_state_action_p), dim=2)
            next_state_state_action = pack_padded_sequence(next_state_state_action_p, seq_lengths, batch_first=True, enforce_sorted=False)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_state_action, hidden_out)
            qf1_next_target = self.filter_padded(qf1_next_target, seq_lengths)
            qf2_next_target = self.filter_padded(qf2_next_target, seq_lengths)
            next_state_log_pi = self.filter_padded(next_state_log_pi_p, seq_lengths)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * min_qf_next_target
        state_action_batch_p = torch.cat((state_batch_p, action_batch_p), dim=2)
        state_action_batch = pack_padded_sequence(state_action_batch_p, seq_lengths, batch_first=True, enforce_sorted=False)
        qf1_p, qf2_p = self.critic(state_action_batch, hidden_in)  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1 = self.filter_padded(qf1_p, seq_lengths)
        qf2 = self.filter_padded(qf2_p, seq_lengths)
        qf1_loss = F.mse_loss(qf1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss
        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # Update the policy network using the newly proposed method
        pi_action_bat_p, log_prob_bat_p, _, _, _, _ = self.policy.sample(state_batch, h_prev=hidden_in[0], c_prev=hidden_in[1], sampling=False)
        pi_state_action_batch_p = torch.cat((state_batch_p, pi_action_bat_p), dim=2)
        pi_state_action_batch = pack_padded_sequence(pi_state_action_batch_p, seq_lengths, batch_first=True, enforce_sorted=False)
        qf1_pi_p, qf2_pi_p = self.critic(pi_state_action_batch, hidden_in)
        qf1_pi = self.filter_padded(qf1_pi_p, seq_lengths)
        qf2_pi = self.filter_padded(qf2_pi_p, seq_lengths)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        log_prob_bat = self.filter_padded(log_prob_bat_p, seq_lengths)
        policy_loss = ((self.alpha * log_prob_bat) - min_qf_pi).mean()  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        # Entropy temperature (alpha) update
        log_pi = log_prob_bat
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        soft_update(self.critic_target, self.critic, self.tau)
        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
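

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original training pipeline).
# It assumes a gym-style environment whose action_space exposes .shape, an
# args namespace carrying the hyperparameters read in SAC.__init__, and a
# policy_memory object with the sample() interface used above; the concrete
# values below are typical SAC defaults, not values taken from this repository.
#
#   from types import SimpleNamespace
#   import gym
#
#   args = SimpleNamespace(gamma=0.99, tau=0.005, alpha=0.2, lr=3e-4,
#                          hidden_size=256, automatic_entropy_tuning=True)
#   env = gym.make("HalfCheetah-v2")
#   agent = SACLSTM(env.observation_space.shape[0], env.action_space, args)
#
#   h, c = torch.zeros(1, 1, args.hidden_size), torch.zeros(1, 1, args.hidden_size)
#   state = env.reset()
#   action, h, c, _ = agent.select_action(state, h, c, evaluate=False)
#   next_state, reward, done, _ = env.step(action)
#   # ...store the transition in an episode buffer, and once enough full
#   # episodes have been collected:
#   # losses = agent.update_parameters(policy_memory, policy_batch_size=8)
# ---------------------------------------------------------------------------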