[6d0c6b]: / util / OneHot.lua

Download this file

24 lines (21 with data), 884 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
local OneHot, parent = torch.class('OneHot', 'nn.Module')
-- adapted from https://github.com/karpathy/char-rnn/blob/master/util/OneHot.lua
function OneHot:__init(outputSize)
parent.__init(self)
self.outputSize = outputSize
-- We'll construct one-hot encodings by using the index method to
-- reshuffle the rows of an identity matrix. To avoid recreating
-- it every iteration we'll cache it.
self._eye = torch.zeros(outputSize+1,outputSize)
self._eye[{{1,outputSize},{1,outputSize}}] = torch.eye(outputSize)
self._eye[outputSize+1] = torch.zeros(outputSize)
end
function OneHot:updateOutput(input)
self.output:resize(input:size(1), input:size(2), self.outputSize):zero()
for i = 1,input:size(1) do
self._eye = self._eye:float()
local longInput = input[i]:long()
self.output[i]:copy(self._eye:index(1, longInput))
end
return self.output
end