Diff of /util/LSTM.lua [000000] .. [6d0c6b]


require 'torch'
require 'nn'


local layer, parent = torch.class('nn.LSTM', 'nn.Module')

-- Adapted from https://github.com/jcjohnson/torch-rnn

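-- Parameters are packed into a single weight matrix and bias vector:
-- rows 1..D of self.weight hold the input-to-hidden weights (Wx) and rows
-- D+1..D+H hold the hidden-to-hidden weights (Wh); the 4H columns (and the
-- 4H bias entries) are laid out gate-by-gate as i, f, o, g.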
function layer:__init(input_dim, hidden_dim)
  parent.__init(self)

  local D, H = input_dim, hidden_dim
  self.input_dim, self.hidden_dim = D, H

  self.weight = torch.Tensor(D + H, 4 * H)
  self.gradWeight = torch.Tensor(D + H, 4 * H):zero()
  self.bias = torch.Tensor(4 * H)
  self.gradBias = torch.Tensor(4 * H):zero()
  self:reset()

  self.cell = torch.Tensor()    -- This will be (N, T, H)
  self.gates = torch.Tensor()   -- This will be (N, T, 4H)
  self.buffer1 = torch.Tensor() -- This will be (N, H)
  self.buffer2 = torch.Tensor() -- This will be (N, H)
  self.buffer3 = torch.Tensor() -- This will be (1, 4H)
  self.grad_a_buffer = torch.Tensor() -- This will be (N, 4H)

  self.h0 = torch.Tensor()
  self.c0 = torch.Tensor()
  self.remember_states = false

  self.grad_c0 = torch.Tensor()
  self.grad_h0 = torch.Tensor()
  self.grad_x = torch.Tensor()
  self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x}
end


function layer:reset(std)
  if not std then
    std = 1.0 / math.sqrt(self.hidden_dim + self.input_dim)
  end
  self.bias:zero()
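  -- Initialize the forget-gate biases (entries H+1..2H) to 1 so the forget
  -- gate starts mostly open and the cell state is carried through early on.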
  self.bias[{{self.hidden_dim + 1, 2 * self.hidden_dim}}]:fill(1)
  self.weight:normal(0, std)
  return self
end


function layer:resetStates()
  self.h0 = self.h0.new()
  self.c0 = self.c0.new()
end


local function check_dims(x, dims)
  assert(x:dim() == #dims)
  for i, d in ipairs(dims) do
    assert(x:size(i) == d)
  end
end


function layer:_unpack_input(input)
  local c0, h0, x = nil, nil, nil
  if torch.type(input) == 'table' and #input == 3 then
    c0, h0, x = unpack(input)
  elseif torch.type(input) == 'table' and #input == 2 then
    h0, x = unpack(input)
  elseif torch.isTensor(input) then
    x = input
  else
    assert(false, 'invalid input')
  end
  return c0, h0, x
end


function layer:_get_sizes(input, gradOutput)
  local c0, h0, x = self:_unpack_input(input)
  local N, T = x:size(1), x:size(2)
  local H, D = self.hidden_dim, self.input_dim
  check_dims(x, {N, T, D})
  if h0 then
    check_dims(h0, {N, H})
  end
  if c0 then
    check_dims(c0, {N, H})
  end
  if gradOutput then
    check_dims(gradOutput, {N, T, H})
  end
  return N, T, D, H
end


--[[
Input:
- c0: Initial cell state, (N, H)
- h0: Initial hidden state, (N, H)
- x: Input sequence, (N, T, D)

Output:
- h: Sequence of hidden states, (N, T, H)
--]]
function layer:updateOutput(input)
  self.recompute_backward = true
  local c0, h0, x = self:_unpack_input(input)
  local N, T, D, H = self:_get_sizes(input)

  self._return_grad_c0 = (c0 ~= nil)
  self._return_grad_h0 = (h0 ~= nil)
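  -- If no initial states are given, use zeros, or carry over the last
  -- timestep of the previous forward pass when remember_states is true.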
  if not c0 then
    c0 = self.c0
    if c0:nElement() == 0 or not self.remember_states then
      c0:resize(N, H):zero()
    elseif self.remember_states then
      local prev_N, prev_T = self.cell:size(1), self.cell:size(2)
      assert(prev_N == N, 'batch sizes must be constant to remember states')
      c0:copy(self.cell[{{}, prev_T}])
    end
  end
  if not h0 then
    h0 = self.h0
    if h0:nElement() == 0 or not self.remember_states then
      h0:resize(N, H):zero()
    elseif self.remember_states then
      local prev_N, prev_T = self.output:size(1), self.output:size(2)
      assert(prev_N == N, 'batch sizes must be the same to remember states')
      h0:copy(self.output[{{}, prev_T}])
    end
  end

  local bias_expand = self.bias:view(1, 4 * H):expand(N, 4 * H)
  local Wx = self.weight[{{1, D}}]
  local Wh = self.weight[{{D + 1, D + H}}]

  local h, c = self.output, self.cell
  h:resize(N, T, H):zero()
  c:resize(N, T, H):zero()
  local prev_h, prev_c = h0, c0
  self.gates:resize(N, T, 4 * H):zero()
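  -- For each timestep: compute the pre-activations a = x_t * Wx + h_{t-1} * Wh + b,
  -- split them into the gates i, f, o (sigmoid) and the candidate g (tanh), then
  -- update c_t = f * c_{t-1} + i * g and h_t = o * tanh(c_t) elementwise.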
  for t = 1, T do
    local cur_x = x[{{}, t}]
    local next_h = h[{{}, t}]
    local next_c = c[{{}, t}]
    local cur_gates = self.gates[{{}, t}]
    cur_gates:addmm(bias_expand, cur_x, Wx)
    cur_gates:addmm(prev_h, Wh)
    cur_gates[{{}, {1, 3 * H}}]:sigmoid()
    cur_gates[{{}, {3 * H + 1, 4 * H}}]:tanh()
    local i = cur_gates[{{}, {1, H}}]
    local f = cur_gates[{{}, {H + 1, 2 * H}}]
    local o = cur_gates[{{}, {2 * H + 1, 3 * H}}]
    local g = cur_gates[{{}, {3 * H + 1, 4 * H}}]
    next_h:cmul(i, g)
    next_c:cmul(f, prev_c):add(next_h)
    next_h:tanh(next_c):cmul(o)
    prev_h, prev_c = next_h, next_c
  end

  return self.output
end


function layer:backward(input, gradOutput, scale)
  self.recompute_backward = false
  scale = scale or 1.0
  assert(scale == 1.0, 'must have scale=1')
  local c0, h0, x = self:_unpack_input(input)
  if not c0 then c0 = self.c0 end
  if not h0 then h0 = self.h0 end

  local grad_c0, grad_h0, grad_x = self.grad_c0, self.grad_h0, self.grad_x
  local h, c = self.output, self.cell
  local grad_h = gradOutput

  local N, T, D, H = self:_get_sizes(input, gradOutput)
  local Wx = self.weight[{{1, D}}]
  local Wh = self.weight[{{D + 1, D + H}}]
  local grad_Wx = self.gradWeight[{{1, D}}]
  local grad_Wh = self.gradWeight[{{D + 1, D + H}}]
  local grad_b = self.gradBias

  grad_h0:resizeAs(h0):zero()
  grad_c0:resizeAs(c0):zero()
  grad_x:resizeAs(x):zero()
  local grad_next_h = self.buffer1:resizeAs(h0):zero()
  local grad_next_c = self.buffer2:resizeAs(c0):zero()
  for t = T, 1, -1 do
    local next_h, next_c = h[{{}, t}], c[{{}, t}]
    local prev_h, prev_c = nil, nil
    if t == 1 then
      prev_h, prev_c = h0, c0
    else
      prev_h, prev_c = h[{{}, t - 1}], c[{{}, t - 1}]
    end
    grad_next_h:add(grad_h[{{}, t}])

    local i = self.gates[{{}, t, {1, H}}]
    local f = self.gates[{{}, t, {H + 1, 2 * H}}]
    local o = self.gates[{{}, t, {2 * H + 1, 3 * H}}]
    local g = self.gates[{{}, t, {3 * H + 1, 4 * H}}]

    local grad_a = self.grad_a_buffer:resize(N, 4 * H):zero()
    local grad_ai = grad_a[{{}, {1, H}}]
    local grad_af = grad_a[{{}, {H + 1, 2 * H}}]
    local grad_ao = grad_a[{{}, {2 * H + 1, 3 * H}}]
    local grad_ag = grad_a[{{}, {3 * H + 1, 4 * H}}]

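    -- Gradients w.r.t. the pre-activations follow from the forward equations:
    -- grad_next_c accumulates o * (1 - tanh(c_t)^2) * grad_next_h,
    -- grad_ao = tanh(c_t) * o * (1 - o) * grad_next_h,
    -- grad_ag = i * (1 - g^2) * grad_next_c,
    -- grad_ai = g * i * (1 - i) * grad_next_c,
    -- grad_af = c_{t-1} * f * (1 - f) * grad_next_c.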
    -- We will use grad_ai, grad_af, and grad_ao as temporary buffers
    -- to compute grad_next_c. We will need tanh_next_c (stored in grad_ai)
    -- to compute grad_ao; the other values can be overwritten after we compute
    -- grad_next_c.
    local tanh_next_c = grad_ai:tanh(next_c)
    local tanh_next_c2 = grad_af:cmul(tanh_next_c, tanh_next_c)
    local my_grad_next_c = grad_ao
    my_grad_next_c:fill(1):add(-1, tanh_next_c2):cmul(o):cmul(grad_next_h)
    grad_next_c:add(my_grad_next_c)

    -- We need tanh_next_c (currently in grad_ai) to compute grad_ao; after
    -- that we can overwrite it.
    grad_ao:fill(1):add(-1, o):cmul(o):cmul(tanh_next_c):cmul(grad_next_h)

    -- Use grad_ai as a temporary buffer for computing grad_ag
    local g2 = grad_ai:cmul(g, g)
    grad_ag:fill(1):add(-1, g2):cmul(i):cmul(grad_next_c)

    -- We don't need any temporary storage for these, so do them last
    grad_ai:fill(1):add(-1, i):cmul(i):cmul(g):cmul(grad_next_c)
    grad_af:fill(1):add(-1, f):cmul(f):cmul(prev_c):cmul(grad_next_c)

    grad_x[{{}, t}]:mm(grad_a, Wx:t())
    grad_Wx:addmm(scale, x[{{}, t}]:t(), grad_a)
    grad_Wh:addmm(scale, prev_h:t(), grad_a)
    local grad_a_sum = self.buffer3:resize(1, 4 * H):sum(grad_a, 1)
    grad_b:add(scale, grad_a_sum)

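    -- Propagate to the previous timestep: the hidden-state gradient flows back
    -- through Wh, and the cell-state gradient is scaled by the forget gate.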
    grad_next_h:mm(grad_a, Wh:t())
    grad_next_c:cmul(f)
  end
  grad_h0:copy(grad_next_h)
  grad_c0:copy(grad_next_c)

  if self._return_grad_c0 and self._return_grad_h0 then
    self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x}
  elseif self._return_grad_h0 then
    self.gradInput = {self.grad_h0, self.grad_x}
  else
    self.gradInput = self.grad_x
  end

  return self.gradInput
end


function layer:clearState()
  self.cell:set()
  self.gates:set()
  self.buffer1:set()
  self.buffer2:set()
  self.buffer3:set()
  self.grad_a_buffer:set()

  self.grad_c0:set()
  self.grad_h0:set()
  self.grad_x:set()
  self.output:set()
end

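-- backward() computes the input, weight, and bias gradients in a single pass;
-- updateGradInput and accGradParameters simply reuse that result, calling
-- backward only when recompute_backward indicates it has not run since the
-- last forward pass.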
function layer:updateGradInput(input, gradOutput)
  if self.recompute_backward then
    self:backward(input, gradOutput, 1.0)
  end
  return self.gradInput
end


function layer:accGradParameters(input, gradOutput, scale)
  if self.recompute_backward then
    self:backward(input, gradOutput, scale)
  end
end
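
--[[
Usage sketch (illustrative values only; shapes follow the module above):

  local N, T, D, H = 2, 5, 10, 20        -- batch size, sequence length, input dim, hidden dim
  local lstm = nn.LSTM(D, H)
  local x = torch.randn(N, T, D)
  local h = lstm:forward(x)              -- (N, T, H)
  local grad_h = torch.randn(N, T, H)
  local grad_x = lstm:backward(x, grad_h)

Initial states may also be passed as {h0, x} or {c0, h0, x}; gradInput then
comes back as a matching table of gradients.
--]]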