SAC/sac.py

import os
import torch
import torch.nn.functional as F
from torch.optim import Adam
from .utils import soft_update, hard_update
from .model import Actor, Critic
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from .replay_memory import PolicyReplayMemory
import ipdb


class SAC_Agent():
    def __init__(self,
                 num_inputs: int,
                 action_space,  # gym-style action space; .shape[0] is used below
                 hidden_size: int,
                 lr: float,
                 gamma: float,
                 tau: float,
                 alpha: float,
                 automatic_entropy_tuning: bool,
                 model: str,
                 multi_policy_loss: bool,
                 alpha_usim: float,
                 beta_usim: float,
                 gamma_usim: float,
                 zeta_nusim: float,
                 cuda: bool):

        if cuda:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Save the regularization parameters
        self.alpha_usim = alpha_usim
        self.beta_usim = beta_usim
        self.gamma_usim = gamma_usim
        self.zeta_nusim = zeta_nusim

        ### SET CRITIC NETWORKS ###
        self.critic = Critic(num_inputs, action_space.shape[0], hidden_size).to(self.device)
        self.critic_target = Critic(num_inputs, action_space.shape[0], hidden_size).to(self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=lr)
        hard_update(self.critic_target, self.critic)

        ### SET ACTOR NETWORK ###
        self.actor = Actor(num_inputs, action_space.shape[0], hidden_size, model, action_space=None).to(self.device)
        self.actor_optim = Adam(self.actor.parameters(), lr=lr)

        ### SET TRAINING VARIABLES ###
        self.model = model
        self.multi_policy_loss = multi_policy_loss
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.hidden_size = hidden_size
        self.automatic_entropy_tuning = automatic_entropy_tuning

        # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2), as given in the paper
        if automatic_entropy_tuning:
            self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=lr)
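
    # alpha above is the SAC entropy temperature; alpha_usim, beta_usim, gamma_usim
    # and zeta_nusim weight the auxiliary regularizers defined below
    # (_policy_loss_2, _policy_loss_3, _policy_loss_4 and the experimental
    # neural-activity constraint, respectively), as combined in update_parameters().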
|
|
    # This loss encourages simple, low-dimensional dynamics in the RNN activity
    def _policy_loss_2(self, policy_state_batch, h0, len_seq, mask_seq):

        # Sample the hidden weights of the RNN
        J_rnn_w = self.actor.rnn.weight_hh_l0  # shape (hidden_dim, hidden_dim)

        # Sample the output of the RNN for the policy_state_batch
        rnn_out_r, _ = self.actor.forward_for_simple_dynamics(policy_state_batch, h0, sampling=False, len_seq=len_seq)
        rnn_out_r = rnn_out_r.reshape(-1, rnn_out_r.size()[-1])[mask_seq]

        # Reshape the policy hidden weights to match the masked RNN output
        J_rnn_w = J_rnn_w.unsqueeze(0).repeat(rnn_out_r.size()[0], 1, 1)
        rnn_out_r = 1 - torch.pow(rnn_out_r, 2)

        R_j = torch.mul(J_rnn_w, rnn_out_r.unsqueeze(-1))

        policy_loss_2 = torch.norm(R_j)**2

        return policy_loss_2
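
    # Assuming self.actor.rnn is a vanilla tanh RNN (the weight_hh_l0 naming and the
    # 1 - r**2 factor both point to that), R_j above is diag(1 - r_t**2) @ W_hh,
    # i.e. the Jacobian of the hidden-state update dr_t/dr_{t-1}; penalizing its
    # squared Frobenius norm discourages fast, high-dimensional recurrent dynamics.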
|
|
    # This loss encourages minimization of the firing rates of the linear and RNN layers.
    def _policy_loss_3(self, policy_state_batch, h0, len_seq, mask_seq):

        # Sample the output of the RNN and of the linear layer for the policy_state_batch
        rnn_out_r, linear_out = self.actor.forward_for_simple_dynamics(policy_state_batch, h0, sampling=False, len_seq=len_seq)
        rnn_out_r = rnn_out_r.reshape(-1, rnn_out_r.size()[-1])[mask_seq]
        linear_out = linear_out.reshape(-1, linear_out.size()[-1])[mask_seq]

        policy_loss_3 = torch.norm(rnn_out_r)**2 + torch.norm(linear_out)**2

        return policy_loss_3

    # This loss encourages minimization of the input and output weights of the RNN
    # and of the layers downstream/upstream of the RNN.
    def _policy_loss_4(self):

        # Sample the input weights of the RNN and of the upstream linear layer
        J_rnn_i = self.actor.rnn.weight_ih_l0
        J_in1 = self.actor.linear1.weight

        # Sample the output weights
        J_out1 = self.actor.mean_linear.weight
        J_out2 = self.actor.log_std_linear.weight

        policy_loss_4 = torch.norm(J_in1)**2 + torch.norm(J_rnn_i)**2 + torch.norm(J_out1)**2 + torch.norm(J_out2)**2

        return policy_loss_4
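
    # Both regularizers above are squared-L2 (Frobenius) penalties: _policy_loss_3
    # on the unit activations ("firing rates") of the RNN and linear layers,
    # _policy_loss_4 on the RNN input weights and on the weights of the layers
    # feeding into and out of the RNN (mean and log-std heads).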
|
|
    # Loss that constrains a subset of the RNN units to the experimental neural data
    def _policy_loss_exp_neural_constrain(self, policy_state_batch, h0, len_seq, neural_activity_batch, na_idx_batch, mask_seq):
        # Find the loss for the neural-activity constraint
        lstm_out = self.actor.forward_lstm(policy_state_batch, h0, sampling=False, len_seq=len_seq)
        lstm_out = lstm_out.reshape(-1, lstm_out.size()[-1])[mask_seq]
        lstm_activity = lstm_out[:, 0:neural_activity_batch.shape[-1]]

        # Filter the neural activity batch and the lstm activity batch using na_idx_batch
        with torch.no_grad():
            na_idx_batch = na_idx_batch.squeeze(-1) > 0

        lstm_activity = lstm_activity[na_idx_batch]
        neural_activity_batch = neural_activity_batch[na_idx_batch]

        policy_loss_exp_c = F.mse_loss(lstm_activity, neural_activity_batch)

        return policy_loss_exp_c
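
    # The first neural_activity_batch.shape[-1] RNN units are compared against the
    # recorded activity; na_idx_batch presumably marks the time steps with a valid
    # experimental recording, so the MSE is evaluated only on those entries.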
|
|

    def select_action(self, state: np.ndarray, h_prev: torch.Tensor, evaluate=False) -> (np.ndarray, torch.Tensor, np.ndarray, np.ndarray):

        state = torch.FloatTensor(state).to(self.device).unsqueeze(0).unsqueeze(0)
        h_prev = h_prev.to(self.device)

        ### IF TRAINING ###
        if not evaluate:
            # get an action sampled from the Gaussian policy
            action, _, _, h_current, _, rnn_out, rnn_in = self.actor.sample(state, h_prev, sampling=True, len_seq=None)
        ### IF TESTING ###
        else:
            # get the mean action, without exploration noise
            _, _, action, h_current, _, rnn_out, rnn_in = self.actor.sample(state, h_prev, sampling=True, len_seq=None)

        return action.detach().cpu().numpy()[0], h_current.detach(), rnn_out.detach().cpu().numpy(), rnn_in.detach().cpu().numpy()
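
    # Usage sketch (comments only; `env`, `obs`, and the initial hidden-state shape
    # are assumptions based on how h0 is built in update_parameters):
    #     h = torch.zeros(1, 1, agent.hidden_size)
    #     action, h, rnn_out, rnn_in = agent.select_action(obs, h)                 # stochastic, for training
    #     action, h, rnn_out, rnn_in = agent.select_action(obs, h, evaluate=True)  # deterministic, for evaluation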
|
|

    def update_parameters(self, policy_memory: PolicyReplayMemory, policy_batch_size: int) -> (float, float, float):

        ### SAMPLE FROM REPLAY ###
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch, h_batch, policy_state_batch, neural_activity_batch, na_idx_batch = policy_memory.sample(batch_size=policy_batch_size)

        ### CONVERT DATA TO TENSOR ###
        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
        h_batch = torch.FloatTensor(h_batch).to(self.device).permute(1, 0, 2)
        neural_activity_batch = torch.FloatTensor(neural_activity_batch).to(self.device)
        na_idx_batch = torch.FloatTensor(na_idx_batch).to(self.device)

        h0 = torch.zeros(size=(1, next_state_batch.shape[0], self.hidden_size)).to(self.device)

        ### SAMPLE NEXT Q VALUE FOR CRITIC LOSS ###
        with torch.no_grad():
            next_state_action, next_state_log_pi, _, _, _, _, _ = self.actor.sample(next_state_batch.unsqueeze(1), h_batch, sampling=True)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)
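
        # Standard SAC soft Bellman target; mask_batch is presumably 1 for
        # non-terminal and 0 for terminal transitions, so the bootstrap term
        # is dropped at episode boundaries.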
|
|

        ### CALCULATE CRITIC LOSS ###
        qf1, qf2 = self.critic(state_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(qf1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(qf2, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q2(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf_loss = qf1_loss + qf2_loss

        ### TAKE GRADIENT STEP ###
        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()
|
|

        ### SAMPLE FROM ACTOR NETWORK ###
        h0 = torch.zeros(size=(1, len(policy_state_batch), self.hidden_size)).to(self.device)
        len_seq = list(map(len, policy_state_batch))
        policy_state_batch = torch.FloatTensor(pad_sequence(policy_state_batch, batch_first=True)).to(self.device)
        pi_action_bat, log_prob_bat, _, _, mask_seq, _, _ = self.actor.sample(policy_state_batch, h0, sampling=False, len_seq=len_seq)

        ### MASK POLICY STATE BATCH ###
        policy_state_batch_pi = policy_state_batch.reshape(-1, policy_state_batch.size()[-1])[mask_seq]

        ### GET VALUE OF CURRENT STATE AND ACTION PAIRS ###
        qf1_pi, qf2_pi = self.critic(policy_state_batch_pi, pi_action_bat)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        ### CALCULATE POLICY LOSS ###
        task_loss = ((self.alpha * log_prob_bat) - min_qf_pi).mean()  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
        policy_loss = task_loss

        ############################
        # ADDITIONAL POLICY LOSSES #
        ############################

        if self.multi_policy_loss:

            loss_simple_dynamics = self._policy_loss_2(policy_state_batch, h0, len_seq, mask_seq)
            loss_activations_min = self._policy_loss_3(policy_state_batch, h0, len_seq, mask_seq)
            loss_weights_min = self._policy_loss_4()
            loss_exp_constrain = self._policy_loss_exp_neural_constrain(policy_state_batch, h0, len_seq, neural_activity_batch, na_idx_batch, mask_seq)

            ### CALCULATE FINAL POLICY LOSS ###
            # To implement nuSim training, use a weighting of 1e+04 with loss_exp_constrain
            policy_loss += (self.alpha_usim * loss_simple_dynamics) \
                         + (self.beta_usim * loss_activations_min) \
                         + (self.gamma_usim * loss_weights_min) \
                         + (self.zeta_nusim * loss_exp_constrain)
|
|

        ### TAKE GRADIENT STEP ###
        self.actor_optim.zero_grad()
        policy_loss.backward()
        self.actor_optim.step()

        ### AUTOMATIC ENTROPY TUNING ###
        log_pi = log_prob_bat
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()

        ### SOFT UPDATE OF CRITIC TARGET ###
        soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item()
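

# ---------------------------------------------------------------------------
# Typical usage (comment-only sketch; the rollout loop and the replay-memory
# API live outside this file, so the call shapes here are assumptions):
#     agent = SAC_Agent(num_inputs, env.action_space, hidden_size, lr, gamma, tau,
#                       alpha, automatic_entropy_tuning, model, multi_policy_loss,
#                       alpha_usim, beta_usim, gamma_usim, zeta_nusim, cuda)
#     ...collect episodes with agent.select_action(...) into a PolicyReplayMemory...
#     qf1_loss, qf2_loss, policy_loss = agent.update_parameters(memory, batch_size)
#
# The guarded block below is a self-contained sketch, not part of the training
# pipeline: it reproduces the _policy_loss_2 penalty on a toy torch.nn.RNN to
# show that it is the squared Frobenius norm of the tanh-RNN state-transition
# Jacobian diag(1 - h_t**2) @ W_hh. Run it with `python -m SAC.sac` (assuming
# SAC/ is a package), since the relative imports above prevent running the file
# directly.
if __name__ == "__main__":
    toy_rnn = torch.nn.RNN(input_size=4, hidden_size=8, batch_first=True)
    x = torch.randn(2, 5, 4)                        # (batch, time, features)
    h_seq, _ = toy_rnn(x)                           # tanh hidden states, shape (2, 5, 8)
    h_flat = h_seq.reshape(-1, h_seq.size(-1))      # flatten batch and time
    jac = (1 - h_flat.pow(2)).unsqueeze(-1) * toy_rnn.weight_hh_l0  # diag(1 - h**2) @ W_hh per step
    print("simple-dynamics penalty:", (torch.norm(jac) ** 2).item())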