Diff of /SAC/model.py [000000] .. [9f010e]

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
epsilon = 1e-6


# Initialize Policy weights
def weights_init_(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        torch.nn.init.constant_(m.bias, 0)

class Actor(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_dim, model, action_space=None):
        super(Actor, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)

        if model == "rnn":
            self.rnn = nn.RNN(hidden_dim, hidden_dim, batch_first=True)
        elif model == "gru":
            self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        else:
            raise NotImplementedError

        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)

        self.apply(weights_init_)

        # Action rescaling.
        # If no action space is passed, set the action scale and bias manually.
        if action_space is None:
            self.action_scale = torch.tensor(0.5)
            self.action_bias = torch.tensor(0.5)
        else:
            self.action_scale = torch.FloatTensor(
                (action_space.high - action_space.low) / 2.)
            self.action_bias = torch.FloatTensor(
                (action_space.high + action_space.low) / 2.)

    def forward(self, state, h_prev, sampling, len_seq=None):

        x = torch.tanh(self.linear1(state))

        if not sampling:
            assert len_seq is not None, "Provide len_seq"
            x = pack_padded_sequence(x, len_seq, batch_first=True, enforce_sorted=False)

        # Tap the RNN input for fixed-point analysis
        rnn_in = x

        x, h_current = self.rnn(x, h_prev)

        if not sampling:
            x, _ = pad_packed_sequence(x, batch_first=True)

        if sampling:
            x = x.squeeze(1)

        # x = F.relu(x)
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)

        return mean, log_std, h_current, x, rnn_in

    def sample(self, state, h_prev, sampling, len_seq=None):

        mean, log_std, h_current, x, rnn_in = self.forward(state, h_prev, sampling, len_seq)

        # When training on padded sequences, mask the padded time steps of mean
        # and log_std using len_seq.
        if not sampling:
            assert mean.size()[1] == log_std.size()[1], "Sequence-length mismatch between mean and log_std"
            sl_max = mean.size()[1]
            with torch.no_grad():
                for seq_idx, k in enumerate(len_seq):
                    for j in range(1, sl_max + 1):
                        if j <= k:
                            if seq_idx == 0 and j == 1:
                                mask_seq = torch.tensor([True], dtype=torch.bool)
                            else:
                                mask_seq = torch.cat((mask_seq, torch.tensor([True])), dim=0)
                        else:
                            mask_seq = torch.cat((mask_seq, torch.tensor([False])), dim=0)
            # The mask has been created; now filter mean and log_std with it.
            mean = mean.reshape(-1, mean.size()[-1])[mask_seq]
            log_std = log_std.reshape(-1, log_std.size()[-1])[mask_seq]

        if sampling:
            mask_seq = []  # If sampling is True, return a dummy mask_seq

        std = log_std.exp()

        normal = Normal(mean, std)
        noise = normal.rsample()  # reparameterization trick

        # Squash to (-1, 1) and rescale to the action range
        y_t = torch.tanh(noise)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(noise)

        # Enforce the action bounds: tanh change-of-variables correction to the log-density
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias

        return action, log_prob, mean, h_current, mask_seq, x, rnn_in

    def forward_for_simple_dynamics(self, state, h_prev, sampling, len_seq=None):

        x = torch.tanh(self.linear1(state))

        # Tap the output of the first linear layer
        x_l1 = x

        if not sampling:
            assert len_seq is not None, "Provide len_seq"
            x = pack_padded_sequence(x, len_seq, batch_first=True, enforce_sorted=False)

        x, _ = self.rnn(x, h_prev)

        if not sampling:
            x, _ = pad_packed_sequence(x, batch_first=True)

        # x = F.relu(x)

        return x, x_l1

    def forward_lstm(self, state, h_prev, sampling, len_seq=None):
        # Note: despite the name, this runs whichever recurrent layer (RNN or GRU)
        # was configured in __init__ and returns only its output.

        x = torch.tanh(self.linear1(state))

        if not sampling:
            assert len_seq is not None, "Provide len_seq"
            x = pack_padded_sequence(x, len_seq, batch_first=True, enforce_sorted=False)

        x, _ = self.rnn(x, h_prev)

        if not sampling:
            x, _ = pad_packed_sequence(x, batch_first=True)

        if sampling:
            x = x.squeeze(1)

        return x

    def forward_for_neural_pert(self, state, h_prev, neural_pert=None):

        x = torch.tanh(self.linear1(state))

        # Tap the RNN input for fixed-point analysis
        rnn_in = x

        x, h_current = self.rnn(x, h_prev)

        # Add the neural perturbation to the RNN output
        x = x + neural_pert

        x = x.squeeze(1)

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)

        action = torch.tanh(mean) * self.action_scale + self.action_bias

        return action.detach().cpu().numpy()[0], h_current.detach(), x.detach().cpu().numpy(), rnn_in.detach().cpu().numpy()
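
# --- Illustrative usage sketch (not part of the original file) ---------------
# A minimal sketch of how Actor.sample might be driven during an environment
# rollout, assuming batch size 1, a single time step, and an initial hidden
# state of shape (1, 1, hidden_dim). The helper name and shapes are assumptions
# for illustration only.
def _example_actor_rollout_step(actor, obs, h_prev):
    # obs: 1-D observation tensor; add batch and time dims -> (1, 1, obs_dim)
    state = obs.view(1, 1, -1)
    with torch.no_grad():
        action, _, _, h_next, _, _, _ = actor.sample(state, h_prev, sampling=True)
    # action: (1, num_actions); h_next: hidden state carried to the next step
    return action.squeeze(0), h_next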

class Critic(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_dim):
        super(Critic, self).__init__()

        # Q1 architecture
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Q2 architecture
        self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear6 = nn.Linear(hidden_dim, 1)

        self.apply(weights_init_)

    def forward(self, state, action):

        xu = torch.cat([state, action], 1)

        x1 = F.relu(self.linear1(xu))
        x1 = F.relu(self.linear2(x1))
        x1 = self.linear3(x1)

        x2 = F.relu(self.linear4(xu))
        x2 = F.relu(self.linear5(x2))
        x2 = self.linear6(x2)

        return x1, x2
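
# --- Smoke-test sketch (not part of the original file) -----------------------
# A rough, self-contained shape check for Actor and Critic, assuming arbitrary
# sizes (obs_dim=10, act_dim=4, hidden_dim=64) and the "gru" variant; a real
# training loop would instead feed padded sequences with len_seq.
if __name__ == "__main__":
    obs_dim, act_dim, hidden_dim, batch = 10, 4, 64, 8

    actor = Actor(obs_dim, act_dim, hidden_dim, model="gru")
    critic = Critic(obs_dim, act_dim, hidden_dim)

    state = torch.randn(batch, 1, obs_dim)   # 1-step sequences
    h0 = torch.zeros(1, batch, hidden_dim)   # initial GRU hidden state

    action, log_prob, mean, h1, _, _, _ = actor.sample(state, h0, sampling=True)
    q1, q2 = critic(state.squeeze(1), action)

    print(action.shape, log_prob.shape, q1.shape, q2.shape)
    # expected: (batch, act_dim), (batch, 1), (batch, 1), (batch, 1)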