util/LSTM.lua

require 'torch'
require 'nn'


local layer, parent = torch.class('nn.LSTM', 'nn.Module')

-- Implemented from https://github.com/jcjohnson/torch-rnn

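-- Notation used throughout this file:
--   N = minibatch size, T = sequence length,
--   D = input_dim, H = hidden_dim
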
function layer:__init(input_dim, hidden_dim)
  parent.__init(self)

  local D, H = input_dim, hidden_dim
  self.input_dim, self.hidden_dim = D, H

  self.weight = torch.Tensor(D + H, 4 * H)
  self.gradWeight = torch.Tensor(D + H, 4 * H):zero()
  self.bias = torch.Tensor(4 * H)
  self.gradBias = torch.Tensor(4 * H):zero()
  self:reset()

  self.cell = torch.Tensor()          -- This will be (N, T, H)
  self.gates = torch.Tensor()         -- This will be (N, T, 4H)
  self.buffer1 = torch.Tensor()       -- This will be (N, H)
  self.buffer2 = torch.Tensor()       -- This will be (N, H)
  self.buffer3 = torch.Tensor()       -- This will be (1, 4H)
  self.grad_a_buffer = torch.Tensor() -- This will be (N, 4H)

  self.h0 = torch.Tensor()
  self.c0 = torch.Tensor()
  self.remember_states = false

  self.grad_c0 = torch.Tensor()
  self.grad_h0 = torch.Tensor()
  self.grad_x = torch.Tensor()
  self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x}
end

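-- Parameter layout: self.weight is (D + H) x 4H. Rows 1..D hold the
-- input-to-hidden weights Wx and rows D+1..D+H hold the hidden-to-hidden
-- weights Wh. The 4H columns are split into four H-wide blocks, one per
-- gate, in the order: input gate i, forget gate f, output gate o, and
-- candidate values g. reset() below samples the weights from N(0, std)
-- and fills the forget-gate slice of the bias with 1, a common
-- initialization that encourages the cell to retain information early
-- in training.
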
function layer:reset(std)
  if not std then
    std = 1.0 / math.sqrt(self.hidden_dim + self.input_dim)
  end
  self.bias:zero()
  self.bias[{{self.hidden_dim + 1, 2 * self.hidden_dim}}]:fill(1)
  self.weight:normal(0, std)
  return self
end

function layer:resetStates()
  self.h0 = self.h0.new()
  self.c0 = self.c0.new()
end

local function check_dims(x, dims)
  assert(x:dim() == #dims)
  for i, d in ipairs(dims) do
    assert(x:size(i) == d)
  end
end

function layer:_unpack_input(input)
  local c0, h0, x = nil, nil, nil
  if torch.type(input) == 'table' and #input == 3 then
    c0, h0, x = unpack(input)
  elseif torch.type(input) == 'table' and #input == 2 then
    h0, x = unpack(input)
  elseif torch.isTensor(input) then
    x = input
  else
    assert(false, 'invalid input')
  end
  return c0, h0, x
end

function layer:_get_sizes(input, gradOutput)
  local c0, h0, x = self:_unpack_input(input)
  local N, T = x:size(1), x:size(2)
  local H, D = self.hidden_dim, self.input_dim
  check_dims(x, {N, T, D})
  if h0 then
    check_dims(h0, {N, H})
  end
  if c0 then
    check_dims(c0, {N, H})
  end
  if gradOutput then
    check_dims(gradOutput, {N, T, H})
  end
  return N, T, D, H
end

--[[
Input: a tensor x, a table {h0, x}, or a table {c0, h0, x}, where
- c0: Initial cell state, (N, H); defaults to zeros, or to the final cell
  state of the previous forward pass when remember_states is true
- h0: Initial hidden state, (N, H); defaults in the same way
- x: Input sequence, (N, T, D)

Output:
- h: Sequence of hidden states, (N, T, H)
--]]
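-- At each timestep, with a = x_t * Wx + h_{t-1} * Wh + b split into four
-- H-wide blocks (ai, af, ao, ag), updateOutput computes:
--   i = sigmoid(ai), f = sigmoid(af), o = sigmoid(ao), g = tanh(ag)
--   c_t = f .* c_{t-1} + i .* g
--   h_t = o .* tanh(c_t)
-- (".*" denotes elementwise multiplication.)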
function layer:updateOutput(input)
  self.recompute_backward = true
  local c0, h0, x = self:_unpack_input(input)
  local N, T, D, H = self:_get_sizes(input)

  self._return_grad_c0 = (c0 ~= nil)
  self._return_grad_h0 = (h0 ~= nil)
  if not c0 then
    c0 = self.c0
    if c0:nElement() == 0 or not self.remember_states then
      c0:resize(N, H):zero()
    elseif self.remember_states then
      local prev_N, prev_T = self.cell:size(1), self.cell:size(2)
      assert(prev_N == N, 'batch sizes must be constant to remember states')
      c0:copy(self.cell[{{}, prev_T}])
    end
  end
  if not h0 then
    h0 = self.h0
    if h0:nElement() == 0 or not self.remember_states then
      h0:resize(N, H):zero()
    elseif self.remember_states then
      local prev_N, prev_T = self.output:size(1), self.output:size(2)
      assert(prev_N == N, 'batch sizes must be the same to remember states')
      h0:copy(self.output[{{}, prev_T}])
    end
  end

  local bias_expand = self.bias:view(1, 4 * H):expand(N, 4 * H)
  local Wx = self.weight[{{1, D}}]
  local Wh = self.weight[{{D + 1, D + H}}]

  local h, c = self.output, self.cell
  h:resize(N, T, H):zero()
  c:resize(N, T, H):zero()
  local prev_h, prev_c = h0, c0
  self.gates:resize(N, T, 4 * H):zero()
  for t = 1, T do
    local cur_x = x[{{}, t}]
    local next_h = h[{{}, t}]
    local next_c = c[{{}, t}]
    local cur_gates = self.gates[{{}, t}]
    cur_gates:addmm(bias_expand, cur_x, Wx)
    cur_gates:addmm(prev_h, Wh)
    cur_gates[{{}, {1, 3 * H}}]:sigmoid()
    cur_gates[{{}, {3 * H + 1, 4 * H}}]:tanh()
    local i = cur_gates[{{}, {1, H}}]
    local f = cur_gates[{{}, {H + 1, 2 * H}}]
    local o = cur_gates[{{}, {2 * H + 1, 3 * H}}]
    local g = cur_gates[{{}, {3 * H + 1, 4 * H}}]
    next_h:cmul(i, g)
    next_c:cmul(f, prev_c):add(next_h)
    next_h:tanh(next_c):cmul(o)
    prev_h, prev_c = next_h, next_c
  end

  return self.output
end

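-- backward() runs BPTT through the loop above. For each timestep it forms
-- the gradients of the loss with respect to the gate pre-activations:
--   grad_c_t  += grad_h_t .* o .* (1 - tanh(c_t)^2)
--   grad_ao    = grad_h_t .* tanh(c_t) .* o .* (1 - o)
--   grad_ag    = grad_c_t .* i .* (1 - g^2)
--   grad_ai    = grad_c_t .* g .* i .* (1 - i)
--   grad_af    = grad_c_t .* c_{t-1} .* f .* (1 - f)
-- and then propagates grad_h_{t-1} = grad_a * Wh^T and
-- grad_c_{t-1} = grad_c_t .* f while accumulating grad_Wx, grad_Wh, grad_b.
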
function layer:backward(input, gradOutput, scale)
  self.recompute_backward = false
  scale = scale or 1.0
  assert(scale == 1.0, 'must have scale=1')
  local c0, h0, x = self:_unpack_input(input)
  if not c0 then c0 = self.c0 end
  if not h0 then h0 = self.h0 end

  local grad_c0, grad_h0, grad_x = self.grad_c0, self.grad_h0, self.grad_x
  local h, c = self.output, self.cell
  local grad_h = gradOutput

  local N, T, D, H = self:_get_sizes(input, gradOutput)
  local Wx = self.weight[{{1, D}}]
  local Wh = self.weight[{{D + 1, D + H}}]
  local grad_Wx = self.gradWeight[{{1, D}}]
  local grad_Wh = self.gradWeight[{{D + 1, D + H}}]
  local grad_b = self.gradBias

  grad_h0:resizeAs(h0):zero()
  grad_c0:resizeAs(c0):zero()
  grad_x:resizeAs(x):zero()
  local grad_next_h = self.buffer1:resizeAs(h0):zero()
  local grad_next_c = self.buffer2:resizeAs(c0):zero()
  for t = T, 1, -1 do
    local next_h, next_c = h[{{}, t}], c[{{}, t}]
    local prev_h, prev_c = nil, nil
    if t == 1 then
      prev_h, prev_c = h0, c0
    else
      prev_h, prev_c = h[{{}, t - 1}], c[{{}, t - 1}]
    end
    grad_next_h:add(grad_h[{{}, t}])

    local i = self.gates[{{}, t, {1, H}}]
    local f = self.gates[{{}, t, {H + 1, 2 * H}}]
    local o = self.gates[{{}, t, {2 * H + 1, 3 * H}}]
    local g = self.gates[{{}, t, {3 * H + 1, 4 * H}}]

    local grad_a = self.grad_a_buffer:resize(N, 4 * H):zero()
    local grad_ai = grad_a[{{}, {1, H}}]
    local grad_af = grad_a[{{}, {H + 1, 2 * H}}]
    local grad_ao = grad_a[{{}, {2 * H + 1, 3 * H}}]
    local grad_ag = grad_a[{{}, {3 * H + 1, 4 * H}}]

    -- We will use grad_ai, grad_af, and grad_ao as temporary buffers
    -- to compute grad_next_c. We will need tanh_next_c (stored in grad_ai)
    -- to compute grad_ao; the other values can be overwritten after we
    -- compute grad_next_c.
    local tanh_next_c = grad_ai:tanh(next_c)
    local tanh_next_c2 = grad_af:cmul(tanh_next_c, tanh_next_c)
    local my_grad_next_c = grad_ao
    my_grad_next_c:fill(1):add(-1, tanh_next_c2):cmul(o):cmul(grad_next_h)
    grad_next_c:add(my_grad_next_c)

    -- We need tanh_next_c (currently in grad_ai) to compute grad_ao; after
    -- that we can overwrite it.
    grad_ao:fill(1):add(-1, o):cmul(o):cmul(tanh_next_c):cmul(grad_next_h)

    -- Use grad_ai as a temporary buffer for computing grad_ag
    local g2 = grad_ai:cmul(g, g)
    grad_ag:fill(1):add(-1, g2):cmul(i):cmul(grad_next_c)

    -- We don't need any temporary storage for these so do them last
    grad_ai:fill(1):add(-1, i):cmul(i):cmul(g):cmul(grad_next_c)
    grad_af:fill(1):add(-1, f):cmul(f):cmul(prev_c):cmul(grad_next_c)

    grad_x[{{}, t}]:mm(grad_a, Wx:t())
    grad_Wx:addmm(scale, x[{{}, t}]:t(), grad_a)
    grad_Wh:addmm(scale, prev_h:t(), grad_a)
    local grad_a_sum = self.buffer3:resize(1, 4 * H):sum(grad_a, 1)
    grad_b:add(scale, grad_a_sum)

    grad_next_h:mm(grad_a, Wh:t())
    grad_next_c:cmul(f)
  end
  grad_h0:copy(grad_next_h)
  grad_c0:copy(grad_next_c)

  if self._return_grad_c0 and self._return_grad_h0 then
    self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x}
  elseif self._return_grad_h0 then
    self.gradInput = {self.grad_h0, self.grad_x}
  else
    self.gradInput = self.grad_x
  end

  return self.gradInput
end

function layer:clearState()
  self.cell:set()
  self.gates:set()
  self.buffer1:set()
  self.buffer2:set()
  self.buffer3:set()
  self.grad_a_buffer:set()

  self.grad_c0:set()
  self.grad_h0:set()
  self.grad_x:set()
  self.output:set()
end

function layer:updateGradInput(input, gradOutput)
  if self.recompute_backward then
    self:backward(input, gradOutput, 1.0)
  end
  return self.gradInput
end

function layer:accGradParameters(input, gradOutput, scale)
  if self.recompute_backward then
    self:backward(input, gradOutput, scale)
  end
end
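
-- The sketch below is an illustrative usage example, not part of the module,
-- and is kept inside a comment block so requiring the file does not run it.
-- It assumes this file is reachable on the Lua path as 'util.LSTM'
-- (hypothetical; adjust the require to match your project layout).
--[[
require 'torch'
require 'nn'
require 'util.LSTM'

local N, T, D, H = 2, 5, 10, 20      -- batch size, sequence length, dims
local lstm = nn.LSTM(D, H)

-- Tensor input: h0 and c0 default to zeros (or to the remembered states).
local x = torch.randn(N, T, D)
local h = lstm:forward(x)            -- (N, T, H)

-- Table input {c0, h0, x} supplies explicit initial states; gradInput is
-- then the table {grad_c0, grad_h0, grad_x}.
local c0 = torch.zeros(N, H)
local h0 = torch.zeros(N, H)
h = lstm:forward({c0, h0, x})
local grad_h = torch.randn(N, T, H)
local grads = lstm:backward({c0, h0, x}, grad_h)
--]]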