Diff of /SAC/sac.py [000000] .. [9f010e]

import os
import torch
import torch.nn.functional as F
from torch.optim import Adam
from .utils import soft_update, hard_update
from .model import Actor, Critic
from torch.nn.utils.rnn import pad_sequence
from typing import Tuple
import numpy as np
from .replay_memory import PolicyReplayMemory
import ipdb

class SAC_Agent():
    def __init__(self,
                 num_inputs: int,
                 action_space,  # action-space object with a .shape attribute (e.g. a gym Box)
                 hidden_size: int,
                 lr: float,
                 gamma: float,
                 tau: float,
                 alpha: float,
                 automatic_entropy_tuning: bool,
                 model: str,
                 multi_policy_loss: bool,
                 alpha_usim: float,
                 beta_usim: float,
                 gamma_usim: float,
                 zeta_nusim: float,
                 cuda: bool):

        if cuda:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Save the regularization parameters
        self.alpha_usim = alpha_usim
        self.beta_usim = beta_usim
        self.gamma_usim = gamma_usim
        self.zeta_nusim = zeta_nusim

        ### SET CRITIC NETWORKS ###
        self.critic = Critic(num_inputs, action_space.shape[0], hidden_size).to(self.device)
        self.critic_target = Critic(num_inputs, action_space.shape[0], hidden_size).to(self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=lr)
        hard_update(self.critic_target, self.critic)

        ### SET ACTOR NETWORK ###
        self.actor = Actor(num_inputs, action_space.shape[0], hidden_size, model, action_space=None).to(self.device)
        self.actor_optim = Adam(self.actor.parameters(), lr=lr)

        ### SET TRAINING VARIABLES ###
        self.model = model
        self.multi_policy_loss = multi_policy_loss
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.hidden_size = hidden_size
        self.automatic_entropy_tuning = automatic_entropy_tuning

        # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2), as given in the SAC paper
        if automatic_entropy_tuning:
            self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=lr)

    # This loss encourages simple, low-dimensional dynamics in the RNN activity
    def _policy_loss_2(self, policy_state_batch, h0, len_seq, mask_seq):

        # Sample the hidden (recurrent) weights of the RNN
        J_rnn_w = self.actor.rnn.weight_hh_l0  # these weights have shape (hidden_dim, hidden_dim)

        # Sample the output of the RNN for the policy_state_batch
        rnn_out_r, _ = self.actor.forward_for_simple_dynamics(policy_state_batch, h0, sampling=False, len_seq=len_seq)
        rnn_out_r = rnn_out_r.reshape(-1, rnn_out_r.size()[-1])[mask_seq]

        # Reshape the policy hidden weights vector
        J_rnn_w = J_rnn_w.unsqueeze(0).repeat(rnn_out_r.size()[0], 1, 1)
        rnn_out_r = 1 - torch.pow(rnn_out_r, 2)

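        # Assuming the recurrent layer uses tanh units (consistent with the 1 - r^2 factor above),
        # d tanh(x)/dx = 1 - tanh(x)^2, so scaling row i of the recurrent weight matrix by (1 - r_i^2)
        # gives the Jacobian of the hidden-state update at the current activity r; penalizing its
        # squared Frobenius norm below encourages slow, low-dimensional recurrent dynamics.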
        R_j = torch.mul(J_rnn_w, rnn_out_r.unsqueeze(-1))

        policy_loss_2 = torch.norm(R_j)**2

        return policy_loss_2

    # This loss encourages the minimization of the firing rates of the linear and RNN layers.
    def _policy_loss_3(self, policy_state_batch, h0, len_seq, mask_seq):

        # Find the loss encouraging the minimization of the firing rates of the linear and RNN layers
        # Sample the output of the RNN for the policy_state_batch
        rnn_out_r, linear_out = self.actor.forward_for_simple_dynamics(policy_state_batch, h0, sampling=False, len_seq=len_seq)
        rnn_out_r = rnn_out_r.reshape(-1, rnn_out_r.size()[-1])[mask_seq]
        linear_out = linear_out.reshape(-1, linear_out.size()[-1])[mask_seq]

        policy_loss_3 = torch.norm(rnn_out_r)**2 + torch.norm(linear_out)**2

        return policy_loss_3

    # This loss encourages the minimization of the input and output weights of the RNN and of the
    # layers downstream/upstream of the RNN.
    def _policy_loss_4(self):

        # Find the loss encouraging the minimization of the input and output weights of the RNN and
        # the layers downstream and upstream of the RNN
        # Sample the input weights of the RNN
        J_rnn_i = self.actor.rnn.weight_ih_l0
        J_in1 = self.actor.linear1.weight

        # Sample the output weights
        J_out1 = self.actor.mean_linear.weight
        J_out2 = self.actor.log_std_linear.weight

        policy_loss_4 = torch.norm(J_in1)**2 + torch.norm(J_rnn_i)**2 + torch.norm(J_out1)**2 + torch.norm(J_out2)**2

        return policy_loss_4

    # Define a loss function that constrains a subset of the RNN nodes to the experimental neural data
    def _policy_loss_exp_neural_constrain(self, policy_state_batch, h0, len_seq, neural_activity_batch, na_idx_batch, mask_seq):
        # Find the loss for the neural-activity constraint
        lstm_out = self.actor.forward_lstm(policy_state_batch, h0, sampling=False, len_seq=len_seq)
        lstm_out = lstm_out.reshape(-1, lstm_out.size()[-1])[mask_seq]
        lstm_activity = lstm_out[:, 0:neural_activity_batch.shape[-1]]
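        # The first neural_activity_batch.shape[-1] RNN units form the subset constrained to the
        # recorded experimental activity; the remaining units are left unconstrained.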

        # Now filter the neural activity batch and the lstm activity batch using the na_idx batch
        with torch.no_grad():
            na_idx_batch = na_idx_batch.squeeze(-1) > 0

        lstm_activity = lstm_activity[na_idx_batch]
        neural_activity_batch = neural_activity_batch[na_idx_batch]

        policy_loss_exp_c = F.mse_loss(lstm_activity, neural_activity_batch)

        return policy_loss_exp_c


    def select_action(self, state: np.ndarray, h_prev: torch.Tensor, evaluate: bool = False) -> Tuple[np.ndarray, torch.Tensor, np.ndarray, np.ndarray]:

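        # Add batch and time dimensions: a 1-D observation becomes (1, 1, num_inputs) for the recurrent actor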
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0).unsqueeze(0)
        h_prev = h_prev.to(self.device)

        ### IF TRAINING ###
        if not evaluate:
            # get an action sampled from the policy's Gaussian
            action, _, _, h_current, _, rnn_out, rnn_in = self.actor.sample(state, h_prev, sampling=True, len_seq=None)
        ### IF TESTING ###
        else:
            # get the deterministic (mean) action, without noise
            _, _, action, h_current, _, rnn_out, rnn_in = self.actor.sample(state, h_prev, sampling=True, len_seq=None)

        return action.detach().cpu().numpy()[0], h_current.detach(), rnn_out.detach().cpu().numpy(), rnn_in.detach().cpu().numpy()

    def update_parameters(self, policy_memory: PolicyReplayMemory, policy_batch_size: int) -> Tuple[float, float, float]:

        ### SAMPLE FROM REPLAY ###
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch, h_batch, policy_state_batch, neural_activity_batch, na_idx_batch = policy_memory.sample(batch_size=policy_batch_size)

        ### CONVERT DATA TO TENSORS ###
        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
        h_batch = torch.FloatTensor(h_batch).to(self.device).permute(1, 0, 2)
        neural_activity_batch = torch.FloatTensor(neural_activity_batch).to(self.device)
        na_idx_batch = torch.FloatTensor(na_idx_batch).to(self.device)

        h0 = torch.zeros(size=(1, next_state_batch.shape[0], self.hidden_size)).to(self.device)
        ### SAMPLE NEXT Q VALUE FOR CRITIC LOSS ###
        with torch.no_grad():
            next_state_action, next_state_log_pi, _, _, _, _, _ = self.actor.sample(next_state_batch.unsqueeze(1), h_batch, sampling=True)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
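            # Soft Bellman target: y = r + mask * gamma * (min(Q1', Q2') - alpha * log_pi(a'|s'))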
            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)

        ### CALCULATE CRITIC LOSS ###
        qf1, qf2 = self.critic(state_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(qf1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(qf2, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q2(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf_loss = qf1_loss + qf2_loss

        ### TAKE GRADIENT STEP ###
        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        ### SAMPLE FROM ACTOR NETWORK ###
        h0 = torch.zeros(size=(1, len(policy_state_batch), self.hidden_size)).to(self.device)
        len_seq = list(map(len, policy_state_batch))
        policy_state_batch = torch.FloatTensor(pad_sequence(policy_state_batch, batch_first=True)).to(self.device)
        pi_action_bat, log_prob_bat, _, _, mask_seq, _, _ = self.actor.sample(policy_state_batch, h0, sampling=False, len_seq=len_seq)

        ### MASK POLICY STATE BATCH ###
        policy_state_batch_pi = policy_state_batch.reshape(-1, policy_state_batch.size()[-1])[mask_seq]

        ### GET VALUE OF CURRENT STATE AND ACTION PAIRS ###
        qf1_pi, qf2_pi = self.critic(policy_state_batch_pi, pi_action_bat)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        ### CALCULATE POLICY LOSS ###
        task_loss = ((self.alpha * log_prob_bat) - min_qf_pi).mean()  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
        policy_loss = task_loss

        ############################
        # ADDITIONAL POLICY LOSSES #
        ############################

        if self.multi_policy_loss:

            loss_simple_dynamics = self._policy_loss_2(policy_state_batch, h0, len_seq, mask_seq)
            loss_activations_min = self._policy_loss_3(policy_state_batch, h0, len_seq, mask_seq)
            loss_weights_min = self._policy_loss_4()
            loss_exp_constrain = self._policy_loss_exp_neural_constrain(policy_state_batch, h0, len_seq, neural_activity_batch, na_idx_batch, mask_seq)

            ### CALCULATE FINAL POLICY LOSS ###
            # To implement nuSim training, use a weighting of 1e+04 with loss_exp_constrain
            policy_loss += (self.alpha_usim * loss_simple_dynamics) \
                           + (self.beta_usim * loss_activations_min) \
                           + (self.gamma_usim * loss_weights_min) \
                           + (self.zeta_nusim * loss_exp_constrain)

        ### TAKE GRADIENT STEP ###
        self.actor_optim.zero_grad()
        policy_loss.backward()
        self.actor_optim.step()

        ### AUTOMATIC ENTROPY TUNING ###
        log_pi = log_prob_bat
        if self.automatic_entropy_tuning:
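            # Temperature (alpha) loss: -(log_alpha * (log_pi + target_entropy).detach()).mean();
            # this raises alpha when the policy's entropy drops below the target and lowers it otherwise.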
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()

        ### SOFT UPDATE OF CRITIC TARGET ###
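        # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target
        # (assuming the usual soft_update convention from .utils)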
        soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item()
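

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only; names such as `env`, the "rnn" model string and
# all hyperparameter values below are hypothetical, not taken from this repository):
#
#   env = gym.make("SomeContinuousEnv-v0")
#   agent = SAC_Agent(num_inputs=env.observation_space.shape[0],
#                     action_space=env.action_space,
#                     hidden_size=256, lr=3e-4, gamma=0.99, tau=5e-3, alpha=0.2,
#                     automatic_entropy_tuning=True, model="rnn",
#                     multi_policy_loss=True, alpha_usim=1e-3, beta_usim=1e-3,
#                     gamma_usim=1e-3, zeta_nusim=0.0, cuda=True)
#
#   h = torch.zeros(1, 1, 256)   # recurrent hidden state, carried across steps
#   action, h, rnn_out, rnn_in = agent.select_action(obs, h, evaluate=False)
#
#   # ...store episodes in a PolicyReplayMemory, then:
#   q1_loss, q2_loss, pi_loss = agent.update_parameters(memory, policy_batch_size=8)
# -----------------------------------------------------------------------------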