--- a/util/LSTM.lua
+++ b/util/LSTM.lua
@@ -0,0 +1,284 @@
+require 'torch'
+require 'nn'
+
+
+local layer, parent = torch.class('nn.LSTM', 'nn.Module')
+
+-- Implemented from https://github.com/jcjohnson/torch-rnn
+
+function layer:__init(input_dim, hidden_dim)
+  parent.__init(self)
+
+  local D, H = input_dim, hidden_dim
+  self.input_dim, self.hidden_dim = D, H
+
+  self.weight = torch.Tensor(D + H, 4 * H)
+  self.gradWeight = torch.Tensor(D + H, 4 * H):zero()
+  self.bias = torch.Tensor(4 * H)
+  self.gradBias = torch.Tensor(4 * H):zero()
+  self:reset()
+
+  self.cell = torch.Tensor()    -- This will be (N, T, H)
+  self.gates = torch.Tensor()   -- This will be (N, T, 4H)
+  self.buffer1 = torch.Tensor() -- This will be (N, H)
+  self.buffer2 = torch.Tensor() -- This will be (N, H)
+  self.buffer3 = torch.Tensor() -- This will be (1, 4H)
+  self.grad_a_buffer = torch.Tensor() -- This will be (N, 4H)
+
+  self.h0 = torch.Tensor()
+  self.c0 = torch.Tensor()
+  self.remember_states = false
+
+  self.grad_c0 = torch.Tensor()
+  self.grad_h0 = torch.Tensor()
+  self.grad_x = torch.Tensor()
+  self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x}
+end
+
+
+function layer:reset(std)
+  if not std then
+    std = 1.0 / math.sqrt(self.hidden_dim + self.input_dim)
+  end
+  self.bias:zero()
+  self.bias[{{self.hidden_dim + 1, 2 * self.hidden_dim}}]:fill(1)
+  self.weight:normal(0, std)
+  return self
+end
+
+
+function layer:resetStates()
+  self.h0 = self.h0.new()
+  self.c0 = self.c0.new()
+end
+
+
+local function check_dims(x, dims)
+  assert(x:dim() == #dims)
+  for i, d in ipairs(dims) do
+    assert(x:size(i) == d)
+  end
+end
+
+
+function layer:_unpack_input(input)
+  local c0, h0, x = nil, nil, nil
+  if torch.type(input) == 'table' and #input == 3 then
+    c0, h0, x = unpack(input)
+  elseif torch.type(input) == 'table' and #input == 2 then
+    h0, x = unpack(input)
+  elseif torch.isTensor(input) then
+    x = input
+  else
+    assert(false, 'invalid input')
+  end
+  return c0, h0, x
+end
+
+
+function layer:_get_sizes(input, gradOutput)
+  local c0, h0, x = self:_unpack_input(input)
+  local N, T = x:size(1), x:size(2)
+  local H, D = self.hidden_dim, self.input_dim
+  check_dims(x, {N, T, D})
+  if h0 then
+    check_dims(h0, {N, H})
+  end
+  if c0 then
+    check_dims(c0, {N, H})
+  end
+  if gradOutput then
+    check_dims(gradOutput, {N, T, H})
+  end
+  return N, T, D, H
+end
+
+
+--[[
+Input:
+- c0: Initial cell state, (N, H)
+- h0: Initial hidden state, (N, H)
+- x: Input sequence, (N, T, D)
+
+Output:
+- h: Sequence of hidden states, (N, T, H)
+--]]
+function layer:updateOutput(input)
+  self.recompute_backward = true
+  local c0, h0, x = self:_unpack_input(input)
+  local N, T, D, H = self:_get_sizes(input)
+
+  self._return_grad_c0 = (c0 ~= nil)
+  self._return_grad_h0 = (h0 ~= nil)
+  if not c0 then
+    c0 = self.c0
+    if c0:nElement() == 0 or not self.remember_states then
+      c0:resize(N, H):zero()
+    elseif self.remember_states then
+      local prev_N, prev_T = self.cell:size(1), self.cell:size(2)
+      assert(prev_N == N, 'batch sizes must be constant to remember states')
+      c0:copy(self.cell[{{}, prev_T}])
+    end
+  end
+  if not h0 then
+    h0 = self.h0
+    if h0:nElement() == 0 or not self.remember_states then
+      h0:resize(N, H):zero()
+    elseif self.remember_states then
+      local prev_N, prev_T = self.output:size(1), self.output:size(2)
+      assert(prev_N == N, 'batch sizes must be the same to remember states')
+      h0:copy(self.output[{{}, prev_T}])
+    end
+  end
+
+  local bias_expand = self.bias:view(1, 4 * H):expand(N, 4 * H)
+  local Wx = self.weight[{{1, D}}]
+  local Wh = self.weight[{{D + 1, D + H}}]
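+
+  -- The loop below computes the standard LSTM recurrence. At each step the
+  -- four gate pre-activations are formed jointly as x_t * Wx + h_{t-1} * Wh + b
+  -- and split into i, f, o (sigmoid) and g (tanh); the states are then updated as
+  --   c_t = f * c_{t-1} + i * g
+  --   h_t = o * tanh(c_t)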
+  local h, c = self.output, self.cell
+  h:resize(N, T, H):zero()
+  c:resize(N, T, H):zero()
+  local prev_h, prev_c = h0, c0
+  self.gates:resize(N, T, 4 * H):zero()
+  for t = 1, T do
+    local cur_x = x[{{}, t}]
+    local next_h = h[{{}, t}]
+    local next_c = c[{{}, t}]
+    local cur_gates = self.gates[{{}, t}]
+    cur_gates:addmm(bias_expand, cur_x, Wx)
+    cur_gates:addmm(prev_h, Wh)
+    cur_gates[{{}, {1, 3 * H}}]:sigmoid()
+    cur_gates[{{}, {3 * H + 1, 4 * H}}]:tanh()
+    local i = cur_gates[{{}, {1, H}}]
+    local f = cur_gates[{{}, {H + 1, 2 * H}}]
+    local o = cur_gates[{{}, {2 * H + 1, 3 * H}}]
+    local g = cur_gates[{{}, {3 * H + 1, 4 * H}}]
+    next_h:cmul(i, g)
+    next_c:cmul(f, prev_c):add(next_h)
+    next_h:tanh(next_c):cmul(o)
+    prev_h, prev_c = next_h, next_c
+  end
+
+  return self.output
+end
+
+
+function layer:backward(input, gradOutput, scale)
+  self.recompute_backward = false
+  scale = scale or 1.0
+  assert(scale == 1.0, 'must have scale=1')
+  local c0, h0, x = self:_unpack_input(input)
+  if not c0 then c0 = self.c0 end
+  if not h0 then h0 = self.h0 end
+
+  local grad_c0, grad_h0, grad_x = self.grad_c0, self.grad_h0, self.grad_x
+  local h, c = self.output, self.cell
+  local grad_h = gradOutput
+
+  local N, T, D, H = self:_get_sizes(input, gradOutput)
+  local Wx = self.weight[{{1, D}}]
+  local Wh = self.weight[{{D + 1, D + H}}]
+  local grad_Wx = self.gradWeight[{{1, D}}]
+  local grad_Wh = self.gradWeight[{{D + 1, D + H}}]
+  local grad_b = self.gradBias
+
+  grad_h0:resizeAs(h0):zero()
+  grad_c0:resizeAs(c0):zero()
+  grad_x:resizeAs(x):zero()
+  local grad_next_h = self.buffer1:resizeAs(h0):zero()
+  local grad_next_c = self.buffer2:resizeAs(c0):zero()
+  for t = T, 1, -1 do
+    local next_h, next_c = h[{{}, t}], c[{{}, t}]
+    local prev_h, prev_c = nil, nil
+    if t == 1 then
+      prev_h, prev_c = h0, c0
+    else
+      prev_h, prev_c = h[{{}, t - 1}], c[{{}, t - 1}]
+    end
+    grad_next_h:add(grad_h[{{}, t}])
+
+    local i = self.gates[{{}, t, {1, H}}]
+    local f = self.gates[{{}, t, {H + 1, 2 * H}}]
+    local o = self.gates[{{}, t, {2 * H + 1, 3 * H}}]
+    local g = self.gates[{{}, t, {3 * H + 1, 4 * H}}]
+
+    local grad_a = self.grad_a_buffer:resize(N, 4 * H):zero()
+    local grad_ai = grad_a[{{}, {1, H}}]
+    local grad_af = grad_a[{{}, {H + 1, 2 * H}}]
+    local grad_ao = grad_a[{{}, {2 * H + 1, 3 * H}}]
+    local grad_ag = grad_a[{{}, {3 * H + 1, 4 * H}}]
+
+    -- We will use grad_ai, grad_af, and grad_ao as temporary buffers to
+    -- compute grad_next_c. We will need tanh_next_c (stored in grad_ai)
+    -- to compute grad_ao; the other values can be overwritten after we
+    -- compute grad_next_c.
+    local tanh_next_c = grad_ai:tanh(next_c)
+    local tanh_next_c2 = grad_af:cmul(tanh_next_c, tanh_next_c)
+    local my_grad_next_c = grad_ao
+    my_grad_next_c:fill(1):add(-1, tanh_next_c2):cmul(o):cmul(grad_next_h)
+    grad_next_c:add(my_grad_next_c)
+
+    -- We need tanh_next_c (currently in grad_ai) to compute grad_ao; after
+    -- that we can overwrite it.
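+    -- This and the remaining gate gradients apply the elementwise derivatives
+    -- sigmoid'(a) = s * (1 - s) and tanh'(a) = 1 - tanh(a)^2 to the gate
+    -- activations saved in self.gates during the forward pass.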
+    grad_ao:fill(1):add(-1, o):cmul(o):cmul(tanh_next_c):cmul(grad_next_h)
+
+    -- Use grad_ai as a temporary buffer for computing grad_ag
+    local g2 = grad_ai:cmul(g, g)
+    grad_ag:fill(1):add(-1, g2):cmul(i):cmul(grad_next_c)
+
+    -- We don't need any temporary storage for these so do them last
+    grad_ai:fill(1):add(-1, i):cmul(i):cmul(g):cmul(grad_next_c)
+    grad_af:fill(1):add(-1, f):cmul(f):cmul(prev_c):cmul(grad_next_c)
+
+    grad_x[{{}, t}]:mm(grad_a, Wx:t())
+    grad_Wx:addmm(scale, x[{{}, t}]:t(), grad_a)
+    grad_Wh:addmm(scale, prev_h:t(), grad_a)
+    local grad_a_sum = self.buffer3:resize(1, 4 * H):sum(grad_a, 1)
+    grad_b:add(scale, grad_a_sum)
+
+    grad_next_h:mm(grad_a, Wh:t())
+    grad_next_c:cmul(f)
+  end
+  grad_h0:copy(grad_next_h)
+  grad_c0:copy(grad_next_c)
+
+  if self._return_grad_c0 and self._return_grad_h0 then
+    self.gradInput = {self.grad_c0, self.grad_h0, self.grad_x}
+  elseif self._return_grad_h0 then
+    self.gradInput = {self.grad_h0, self.grad_x}
+  else
+    self.gradInput = self.grad_x
+  end
+
+  return self.gradInput
+end
+
+
+function layer:clearState()
+  self.cell:set()
+  self.gates:set()
+  self.buffer1:set()
+  self.buffer2:set()
+  self.buffer3:set()
+  self.grad_a_buffer:set()
+
+  self.grad_c0:set()
+  self.grad_h0:set()
+  self.grad_x:set()
+  self.output:set()
+end
+
+
+function layer:updateGradInput(input, gradOutput)
+  if self.recompute_backward then
+    self:backward(input, gradOutput, 1.0)
+  end
+  return self.gradInput
+end
+
+
+function layer:accGradParameters(input, gradOutput, scale)
+  if self.recompute_backward then
+    self:backward(input, gradOutput, scale)
+  end
+end
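
For reference, a minimal usage sketch (not part of the patch). It assumes the file above is loadable as require 'util.LSTM' and uses arbitrary sizes; the layer accepts either a plain (N, T, D) input tensor or a table {c0, h0, x} (or {h0, x}) carrying initial states, and returns the hidden-state sequence of shape (N, T, H).

require 'torch'
require 'nn'
require 'util.LSTM'                      -- assumed module path for the file added above

local N, T, D, H = 2, 5, 10, 16          -- arbitrary batch, sequence, input, hidden sizes
local lstm = nn.LSTM(D, H)

local x = torch.randn(N, T, D)           -- input sequence
local h = lstm:forward(x)                -- hidden states, (N, T, H)

local grad_h = torch.randn(N, T, H)      -- upstream gradient w.r.t. h
local grad_x = lstm:backward(x, grad_h)  -- gradient w.r.t. x, (N, T, D)

-- Initial states can also be passed explicitly; after backward, gradInput is
-- then the table {grad_c0, grad_h0, grad_x}.
local c0, h0 = torch.zeros(N, H), torch.zeros(N, H)
local h2 = lstm:forward({c0, h0, x})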