Diff of /util/OneHot.lua [000000] .. [6d0c6b]

Switch to unified view

a b/util/OneHot.lua
1
local OneHot, parent = torch.class('OneHot', 'nn.Module')
2
-- adapted from https://github.com/karpathy/char-rnn/blob/master/util/OneHot.lua
3
4
function OneHot:__init(outputSize)
5
  parent.__init(self)
6
  self.outputSize = outputSize
7
  -- We'll construct one-hot encodings by using the index method to
8
  -- reshuffle the rows of an identity matrix. To avoid recreating
9
  -- it every iteration we'll cache it.
10
  self._eye = torch.zeros(outputSize+1,outputSize)
11
  self._eye[{{1,outputSize},{1,outputSize}}] = torch.eye(outputSize)
12
  self._eye[outputSize+1] = torch.zeros(outputSize)
13
end
14
15
function OneHot:updateOutput(input)
16
  self.output:resize(input:size(1), input:size(2), self.outputSize):zero()
17
  for i = 1,input:size(1) do
18
      self._eye = self._eye:float()
19
      local longInput = input[i]:long()
20
      self.output[i]:copy(self._eye:index(1, longInput))
21
  end
22
  return self.output
23
end