Switch to side-by-side view

--- a
+++ b/mouse_scripts/SAC/model.py
@@ -0,0 +1,330 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Normal
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
+import numpy as np
+import colorednoise as cn
+import math
+
+LOG_SIG_MAX = 2
+LOG_SIG_MIN = -20
+epsilon = 1e-6
+
+# Initialize Policy weights
+def weights_init_(m):
+    if isinstance(m, nn.Linear):
+        torch.nn.init.xavier_uniform_(m.weight, gain=1)
+        torch.nn.init.constant_(m.bias, 0)
+
+class ValueNetwork(nn.Module):
+    def __init__(self, num_inputs, hidden_dim):
+        super(ValueNetwork, self).__init__()
+
+        self.linear1 = nn.Linear(num_inputs, hidden_dim)
+        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
+        self.linear3 = nn.Linear(hidden_dim, 1)
+
+        self.apply(weights_init_)
+
+    def forward(self, state):
+        x = F.relu(self.linear1(state))
+        x = F.relu(self.linear2(x))
+        x = self.linear3(x)
+        return x
+
+class QNetworkFF(nn.Module):
+    def __init__(self, num_inputs, num_actions, hidden_dim):
+        super(QNetworkFF, self).__init__()
+
+        # Q1 architecture
+        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
+        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
+        self.linear3 = nn.Linear(hidden_dim, 1)
+
+        # Q2 architecture
+        self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
+        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
+        self.linear6 = nn.Linear(hidden_dim, 1)
+
+        self.apply(weights_init_)
+
+    def forward(self, state, action):
+
+        xu = torch.cat([state, action], 1)
+        
+        x1 = F.relu(self.linear1(xu))
+        x1 = F.relu(self.linear2(x1))
+        x1 = self.linear3(x1)
+
+        x2 = F.relu(self.linear4(xu))
+        x2 = F.relu(self.linear5(x2))
+        x2 = self.linear6(x2)
+
+        return x1, x2
+
+class QNetworkLSTM(nn.Module):
+    def __init__(self, num_inputs, num_actions, hidden_dim):
+        super(QNetworkLSTM, self).__init__()
+
+        # Q1 architecture
+        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
+        nn.init.xavier_normal_(self.linear1.weight)
+        self.linear2 = nn.Linear(num_inputs + num_actions, hidden_dim)
+        nn.init.xavier_normal_(self.linear2.weight)
+        self.lstm1 = nn.LSTM(hidden_dim, hidden_dim, num_layers= 1, batch_first= True)
+        self.linear3 = nn.Linear(2 * hidden_dim, hidden_dim)
+        nn.init.xavier_normal_(self.linear3.weight)
+        self.linear4 = nn.Linear(hidden_dim, 1)
+        nn.init.xavier_normal_(self.linear4.weight)
+
+        # Q2 architecture
+        self.linear5 = nn.Linear(num_inputs + num_actions, hidden_dim)
+        nn.init.xavier_normal_(self.linear5.weight)
+        self.linear6 = nn.Linear(num_inputs + num_actions, hidden_dim)
+        nn.init.xavier_normal_(self.linear6.weight)
+        self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, num_layers= 1, batch_first= True)
+        self.linear7 = nn.Linear(2 * hidden_dim, hidden_dim)
+        nn.init.xavier_normal_(self.linear7.weight)
+        self.linear8 = nn.Linear(hidden_dim, 1)
+        nn.init.xavier_normal_(self.linear8.weight)
+
+        self.apply(weights_init_)
+        # notes: weights_init for the LSTM layer
+
+    def forward(self, state_action_packed, hidden):
+
+        xu = state_action_packed
+        xu_p, seq_lens = pad_packed_sequence(xu, batch_first= True)
+
+        fc_branch_1 = F.relu(self.linear1(xu_p))
+
+        lstm_branch_1 = F.relu(self.linear2(xu_p))
+        lstm_branch_1 = pack_padded_sequence(lstm_branch_1, seq_lens, batch_first= True, enforce_sorted= False)
+        lstm_branch_1, hidden_out_1 = self.lstm1(lstm_branch_1, hidden)
+        lstm_branch_1, _ = pad_packed_sequence(lstm_branch_1, batch_first= True)
+
+        x1 = torch.cat([fc_branch_1, lstm_branch_1], dim=-1)
+        x1 = F.relu(self.linear3(x1))
+        x1 = F.relu(self.linear4(x1))
+
+        fc_branch_2 = F.relu(self.linear5(xu_p))
+
+        lstm_branch_2 = F.relu(self.linear6(xu_p))
+        lstm_branch_2 = pack_padded_sequence(lstm_branch_2, seq_lens, batch_first= True, enforce_sorted= False)
+        lstm_branch_2, hidden_out_2 = self.lstm2(lstm_branch_2, hidden)
+        lstm_branch_2, _ = pad_packed_sequence(lstm_branch_2, batch_first= True)
+
+        x2 = torch.cat([fc_branch_2, lstm_branch_2], dim=-1)
+        x2 = F.relu(self.linear7(x2))
+        x2 = F.relu(self.linear8(x2))
+
+        return x1, x2
+
+class GaussianPolicyRNN(nn.Module):
+    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
+        super(GaussianPolicyRNN, self).__init__()
+
+        self.linear1 = nn.Linear(num_inputs, hidden_dim)
+        self.lstm = nn.RNN(hidden_dim, hidden_dim, batch_first=True)
+
+        self.mean_linear = nn.Linear(hidden_dim, num_actions)
+        self.log_std_linear = nn.Linear(hidden_dim, num_actions)
+
+        self.apply(weights_init_)
+
+        # action rescaling
+        # Pass none action space and adjust the action scale and bias manually
+        if action_space is None:
+            self.action_scale = torch.tensor(0.5)
+            self.action_bias = torch.tensor(0.5)
+        else:
+            self.action_scale = torch.FloatTensor(
+                (action_space.high - action_space.low) / 2.)
+            self.action_bias = torch.FloatTensor(
+                (action_space.high + action_space.low) / 2.)
+
+    def forward(self, state, h_prev, c_prev, sampling, len_seq= None):
+
+        #x = F.relu(F.tanh(self.linear1(state)))
+        #x = F.tanh(self.linear1(state))
+        x = F.relu(self.linear1(state))
+
+        if sampling == False:
+            assert len_seq!=None, "Proved the len_seq"
+            x = pack_padded_sequence(x, len_seq, batch_first= True, enforce_sorted= False)
+
+        x, (h_current) = self.lstm(x, (h_prev))
+
+        if sampling == False:
+           x, len_x_seq = pad_packed_sequence(x, batch_first= True)
+
+        if sampling == True:
+            x = x.squeeze(1)
+
+        x = F.relu(x)
+        mean = self.mean_linear(x)
+        log_std = self.log_std_linear(x)
+        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
+
+        c_current= torch.tensor(0., requires_grad= True)
+        return mean, log_std, h_current, c_current, x
+
+    def sample(self, state, h_prev, c_prev, sampling, len_seq=None):
+
+        mean, log_std, h_current, c_current, x = self.forward(state, h_prev, c_prev, sampling, len_seq)
+        #if sampling == False; then mask the mean and log_std using len_seq
+        if sampling == False:
+            assert mean.size()[1] == log_std.size()[1], "There is a mismatch between and mean and sigma Sl_max"
+            sl_max = mean.size()[1]
+            with torch.no_grad():
+                for seq_idx, k in enumerate(len_seq):
+                    for j in range(1, sl_max + 1):
+                        if j <= k:
+                            if seq_idx == 0 and j == 1:
+                                mask_seq = torch.tensor([True], dtype=bool)
+                            else:
+                                mask_seq = torch.cat((mask_seq, torch.tensor([True])), dim=0)
+                        else:
+                            mask_seq = torch.cat((mask_seq, torch.tensor([False])), dim=0)
+        #The mask has been created, Now filter the mean and sigma using this mask
+            print(mask_seq)
+            mean = mean.reshape(-1, mean.size()[-1])[mask_seq]
+            log_std = log_std.reshape(-1, log_std.size()[-1])[mask_seq]
+        if sampling == True:
+            mask_seq = [] #If sampling is True return a dummy mask seq
+
+        std = log_std.exp()
+
+        # white noise
+        normal = Normal(mean, std)
+        noise = normal.rsample()
+
+        # pink noise
+        #samples = math.prod(mean.squeeze().shape)
+        #noise = cn.powerlaw_psd_gaussian(1, samples)
+        #noise = torch.Tensor(noise).view(mean.shape).to(mean.device)
+
+        y_t = torch.tanh(noise) # reparameterization trick
+        action = y_t * self.action_scale + self.action_bias
+        log_prob = normal.log_prob(noise)
+        # Enforce the action_bounds
+        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
+        log_prob = log_prob.sum(1, keepdim=True)
+        mean = torch.tanh(mean) * self.action_scale + self.action_bias
+
+        return action, log_prob, mean, h_current, c_current, mask_seq, x
+
+    def forward_for_simple_dynamics(self, state, h_prev, c_prev, sampling, len_seq= None):
+
+        #x = F.relu(F.tanh(self.linear1(state)))
+        #x = F.tanh(self.linear1(state))
+        x = F.relu(self.linear1(state))
+
+        #Tap the output of the first linear layer
+        x_l1 = x
+        # x = state
+
+        if sampling == False:
+            assert len_seq!=None, "Proved the len_seq"
+
+            x = pack_padded_sequence(x, len_seq, batch_first= True, enforce_sorted= False)
+
+        x, (h_current) = self.lstm(x, (h_prev))
+
+        if sampling == False:
+           x, len_x_seq = pad_packed_sequence(x, batch_first= True)
+
+        x = F.relu(x)
+        return x, x_l1
+
+    def to(self, device):
+        self.action_scale = self.action_scale.to(device)
+        self.action_bias = self.action_bias.to(device)
+        return super(GaussianPolicyRNN, self).to(device)
+
+class GaussianPolicyLSTM(nn.Module):
+    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
+        super(GaussianPolicyLSTM, self).__init__()
+
+        self.linear1 = nn.Linear(num_inputs, hidden_dim)
+        nn.init.xavier_normal_(self.linear1.weight)
+        self.lstm = nn.LSTM(num_inputs, hidden_dim, num_layers=1, batch_first=True)
+        self.linear2 = nn.Linear(2*hidden_dim, hidden_dim)
+        nn.init.xavier_normal_(self.linear2.weight)
+
+        self.mean_linear = nn.Linear(hidden_dim, num_actions)
+        self.log_std_linear = nn.Linear(hidden_dim, num_actions)
+
+        self.apply(weights_init_)
+        # Adjust the initial weights of the recurrent LSTM layer
+
+        # action rescaling
+        # Pass none action space and adjust the action scale and bias manually
+        if action_space is None:
+            # Try different scales to see what works best
+            self.action_scale = torch.tensor(0.5)
+            self.action_bias = torch.tensor(0.5)
+        else:
+            self.action_scale = torch.FloatTensor(
+                (action_space.high - action_space.low) / 2.)
+            self.action_bias = torch.FloatTensor(
+                (action_space.high + action_space.low) / 2.)
+
+    def forward(self, state, h_prev, c_prev, sampling):
+
+        if sampling == True:
+            fc_branch = F.relu(self.linear1(state))
+            lstm_branch, (h_current, c_current) = self.lstm(state, (h_prev, c_prev))
+        else:
+            state_pad, _ = pad_packed_sequence(state, batch_first= True)
+            fc_branch = F.relu(self.linear1(state_pad))
+            lstm_branch, (h_current, c_current) = self.lstm(state, (h_prev, c_prev))
+            lstm_branch, seq_lens = pad_packed_sequence(lstm_branch, batch_first= True)
+
+        x = torch.cat([fc_branch, lstm_branch], dim=-1)
+        x = F.relu(self.linear2(x))
+        
+        mean = self.mean_linear(x)
+        log_std = self.log_std_linear(x)
+        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
+
+        return mean, log_std, h_current, c_current, lstm_branch
+
+    def sample(self, state, h_prev, c_prev, sampling):
+
+        mean, log_std, h_current, c_current, lstm_branch = self.forward(state, h_prev, c_prev, sampling)
+        #if sampling == False; then reshape mean and log_std from (B, L_max, A) to (B*Lmax, A)
+
+        mean_size = mean.size()
+        log_std_size = log_std.size()
+
+        mean = mean.reshape(-1, mean.size()[-1])
+        log_std = log_std.reshape(-1, log_std.size()[-1])
+
+        std = log_std.exp()
+        normal = Normal(mean, std)
+        x_t = normal.rsample()
+
+        y_t = torch.tanh(x_t)
+        action = y_t * self.action_scale + self.action_bias
+
+        log_prob = normal.log_prob(x_t)
+
+        # Enforce the action_bounds
+        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
+        log_prob = log_prob.sum(1, keepdim=True)
+
+        mean = torch.tanh(mean) * self.action_scale + self.action_bias
+
+        if sampling == False:
+            action = action.reshape(mean_size[0], mean_size[1], mean_size[2])
+            mean = mean.reshape(mean_size[0], mean_size[1], mean_size[2])
+            log_prob = log_prob.reshape(log_std_size[0], log_std_size[1], 1) 
+
+        return action, log_prob, mean, h_current, c_current, lstm_branch
+    
+    def to(self, device):
+        self.action_scale = self.action_scale.to(device)
+        self.action_bias = self.action_bias.to(device)
+        return super(GaussianPolicyLSTM, self).to(device)