|
a |
|
b/util/OneHot.lua |
|
|
1 |
local OneHot, parent = torch.class('OneHot', 'nn.Module') |
|
|
2 |
-- adapted from https://github.com/karpathy/char-rnn/blob/master/util/OneHot.lua |
|
|
3 |
|
|
|
4 |
function OneHot:__init(outputSize) |
|
|
5 |
parent.__init(self) |
|
|
6 |
self.outputSize = outputSize |
|
|
7 |
-- We'll construct one-hot encodings by using the index method to |
|
|
8 |
-- reshuffle the rows of an identity matrix. To avoid recreating |
|
|
9 |
-- it every iteration we'll cache it. |
|
|
10 |
self._eye = torch.zeros(outputSize+1,outputSize) |
|
|
11 |
self._eye[{{1,outputSize},{1,outputSize}}] = torch.eye(outputSize) |
|
|
12 |
self._eye[outputSize+1] = torch.zeros(outputSize) |
|
|
13 |
end |
|
|
14 |
|
|
|
15 |
function OneHot:updateOutput(input) |
|
|
16 |
self.output:resize(input:size(1), input:size(2), self.outputSize):zero() |
|
|
17 |
for i = 1,input:size(1) do |
|
|
18 |
self._eye = self._eye:float() |
|
|
19 |
local longInput = input[i]:long() |
|
|
20 |
self.output[i]:copy(self._eye:index(1, longInput)) |
|
|
21 |
end |
|
|
22 |
return self.output |
|
|
23 |
end |