|
SAC/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
epsilon = 1e-6

# Initialize Policy weights
def weights_init_(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        torch.nn.init.constant_(m.bias, 0)

class Actor(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_dim, model, action_space=None):
        super(Actor, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)

        if model == "rnn":
            self.rnn = nn.RNN(hidden_dim, hidden_dim, batch_first=True)
        elif model == "gru":
            self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        else:
            raise NotImplementedError

        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)

        self.apply(weights_init_)

        # Action rescaling: if no action space is passed, set the action scale and bias manually
        if action_space is None:
            self.action_scale = torch.tensor(0.5)
            self.action_bias = torch.tensor(0.5)
        else:
            self.action_scale = torch.FloatTensor(
                (action_space.high - action_space.low) / 2.)
            self.action_bias = torch.FloatTensor(
                (action_space.high + action_space.low) / 2.)
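
        # With a = action_scale * tanh(u) + action_bias (see sample() below),
        # scale = (high - low) / 2 and bias = (high + low) / 2 map the squashed
        # Gaussian sample onto [low, high]; the manual defaults above give [0, 1].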
|
|
|
    def forward(self, state, h_prev, sampling, len_seq=None):

        x = torch.tanh(self.linear1(state))

        if not sampling:
            assert len_seq is not None, "Provide len_seq"
            x = pack_padded_sequence(x, len_seq, batch_first=True, enforce_sorted=False)

        # Tap the RNN input for fixed-point analysis
        rnn_in = x

        x, h_current = self.rnn(x, h_prev)

        if not sampling:
            x, _ = pad_packed_sequence(x, batch_first=True)

        if sampling:
            x = x.squeeze(1)

        # x = F.relu(x)
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)

        return mean, log_std, h_current, x, rnn_in

    def sample(self, state, h_prev, sampling, len_seq=None):

        mean, log_std, h_current, x, rnn_in = self.forward(state, h_prev, sampling, len_seq)

        # If not sampling, mask the padded time steps of mean and log_std using len_seq
        if not sampling:
            assert mean.size()[1] == log_std.size()[1], "Mismatch between the mean and log_std sequence lengths"
            sl_max = mean.size()[1]
            with torch.no_grad():
                # Build a boolean mask that is True for valid time steps and False for padding
                for seq_idx, k in enumerate(len_seq):
                    for j in range(1, sl_max + 1):
                        if j <= k:
                            if seq_idx == 0 and j == 1:
                                mask_seq = torch.tensor([True], dtype=torch.bool)
                            else:
                                mask_seq = torch.cat((mask_seq, torch.tensor([True])), dim=0)
                        else:
                            mask_seq = torch.cat((mask_seq, torch.tensor([False])), dim=0)
            # The mask has been created; now filter the mean and log_std with it
            mean = mean.reshape(-1, mean.size()[-1])[mask_seq]
            log_std = log_std.reshape(-1, log_std.size()[-1])[mask_seq]

        if sampling:
            mask_seq = []  # If sampling is True, return a dummy mask_seq

        std = log_std.exp()

        normal = Normal(mean, std)
        noise = normal.rsample()  # reparameterization trick

        # Squash to (-1, 1) and rescale to the action bounds
        y_t = torch.tanh(noise)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(noise)

        # Enforce the action bounds: correct the log-probability for the tanh squashing
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias

        return action, log_prob, mean, h_current, mask_seq, x, rnn_in
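
    # The log-probability correction above follows the change of variables used in SAC:
    # for a = action_scale * tanh(u) + action_bias with u ~ N(mean, std),
    #   log pi(a|s) = log N(u; mean, std) - sum_i log(action_scale * (1 - tanh(u_i)^2) + epsilon),
    # where epsilon keeps the logarithm finite when tanh saturates.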
|
|
|
    def forward_for_simple_dynamics(self, state, h_prev, sampling, len_seq=None):

        x = torch.tanh(self.linear1(state))

        # Tap the output of the first linear layer
        x_l1 = x

        if not sampling:
            assert len_seq is not None, "Provide len_seq"
            x = pack_padded_sequence(x, len_seq, batch_first=True, enforce_sorted=False)

        x, _ = self.rnn(x, h_prev)

        if not sampling:
            x, _ = pad_packed_sequence(x, batch_first=True)

        # x = F.relu(x)

        return x, x_l1

    def forward_lstm(self, state, h_prev, sampling, len_seq=None):
        # Note: despite the name, this runs whichever recurrent layer (RNN or GRU) was built in __init__

        x = torch.tanh(self.linear1(state))

        if not sampling:
            assert len_seq is not None, "Provide len_seq"
            x = pack_padded_sequence(x, len_seq, batch_first=True, enforce_sorted=False)

        x, h_current = self.rnn(x, h_prev)

        if not sampling:
            x, len_x_seq = pad_packed_sequence(x, batch_first=True)

        if sampling:
            x = x.squeeze(1)

        return x

    def forward_for_neural_pert(self, state, h_prev, neural_pert=None):

        x = torch.tanh(self.linear1(state))

        # Tap the RNN input for fixed-point analysis
        rnn_in = x

        x, h_current = self.rnn(x, h_prev)

        # Add the neural perturbation to the RNN output
        x = x + neural_pert

        x = x.squeeze(1)

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)

        action = torch.tanh(mean) * self.action_scale + self.action_bias

        return action.detach().cpu().numpy()[0], h_current.detach(), x.detach().cpu().numpy(), rnn_in.detach().cpu().numpy()

177 |
class Critic(nn.Module): |
|
|
178 |
def __init__(self, num_inputs, num_actions, hidden_dim): |
|
|
179 |
super(Critic, self).__init__() |
|
|
180 |
|
|
|
181 |
# Q1 architecture |
|
|
182 |
self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim) |
|
|
183 |
self.linear2 = nn.Linear(hidden_dim, hidden_dim) |
|
|
184 |
self.linear3 = nn.Linear(hidden_dim, 1) |
|
|
185 |
|
|
|
186 |
# Q2 architecture |
|
|
187 |
self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim) |
|
|
188 |
self.linear5 = nn.Linear(hidden_dim, hidden_dim) |
|
|
189 |
self.linear6 = nn.Linear(hidden_dim, 1) |
|
|
190 |
|
|
|
191 |
self.apply(weights_init_) |
|
|
192 |
|
|
|
193 |
def forward(self, state, action): |
|
|
194 |
|
|
|
195 |
xu = torch.cat([state, action], 1) |
|
|
196 |
|
|
|
197 |
x1 = F.relu(self.linear1(xu)) |
|
|
198 |
x1 = F.relu(self.linear2(x1)) |
|
|
199 |
x1 = self.linear3(x1) |
|
|
200 |
|
|
|
201 |
x2 = F.relu(self.linear4(xu)) |
|
|
202 |
x2 = F.relu(self.linear5(x2)) |
|
|
203 |
x2 = self.linear6(x2) |
|
|
204 |
|
|
|
205 |
return x1, x2 |