--- a +++ b/util/OneHot.lua @@ -0,0 +1,23 @@ +local OneHot, parent = torch.class('OneHot', 'nn.Module') +-- adapted from https://github.com/karpathy/char-rnn/blob/master/util/OneHot.lua + +function OneHot:__init(outputSize) + parent.__init(self) + self.outputSize = outputSize + -- We'll construct one-hot encodings by using the index method to + -- reshuffle the rows of an identity matrix. To avoid recreating + -- it every iteration we'll cache it. + self._eye = torch.zeros(outputSize+1,outputSize) + self._eye[{{1,outputSize},{1,outputSize}}] = torch.eye(outputSize) + self._eye[outputSize+1] = torch.zeros(outputSize) +end + +function OneHot:updateOutput(input) + self.output:resize(input:size(1), input:size(2), self.outputSize):zero() + for i = 1,input:size(1) do + self._eye = self._eye:float() + local longInput = input[i]:long() + self.output[i]:copy(self._eye:index(1, longInput)) + end + return self.output +end