Diff of /bin/weight-init.lua [000000] .. [868c5d]

Switch to side-by-side view

--- a
+++ b/bin/weight-init.lua
@@ -0,0 +1,75 @@
+--
+-- Different weight initialization methods
+--
+-- > model = require('weight-init')(model, 'heuristic')
+--
+require("nn")
+
+
+-- "Efficient backprop"
+-- Yann Lecun, 1998
+local function w_init_heuristic(fan_in, fan_out)
+   return math.sqrt(1/(3*fan_in))
+end
+
+
+-- "Understanding the difficulty of training deep feedforward neural networks"
+-- Xavier Glorot, 2010
+local function w_init_xavier(fan_in, fan_out)
+   return math.sqrt(2/(fan_in + fan_out))
+end
+
+
+-- "Understanding the difficulty of training deep feedforward neural networks"
+-- Xavier Glorot, 2010
+local function w_init_xavier_caffe(fan_in, fan_out)
+   return math.sqrt(1/fan_in)
+end
+
+
+-- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"
+-- Kaiming He, 2015
+local function w_init_kaiming(fan_in, fan_out)
+   return math.sqrt(4/(fan_in + fan_out))
+end
+
+
+local function w_init(net, arg)
+   -- choose initialization method
+   local method = nil
+   if     arg == 'heuristic'    then method = w_init_heuristic
+   elseif arg == 'xavier'       then method = w_init_xavier
+   elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe
+   elseif arg == 'kaiming'      then method = w_init_kaiming
+   else
+      assert(false)
+   end
+
+   -- loop over all convolutional modules
+   for i = 1, #net.modules do
+      local m = net.modules[i]
+      if m.__typename == 'nn.SpatialConvolution' then
+         m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
+      elseif m.__typename == 'nn.SpatialConvolutionMM' then
+         m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
+      elseif m.__typename == 'nn.LateralConvolution' then
+         m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1))
+      elseif m.__typename == 'nn.VerticalConvolution' then
+         m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
+      elseif m.__typename == 'nn.HorizontalConvolution' then
+         m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
+      elseif m.__typename == 'nn.Linear' then
+         m:reset(method(m.weight:size(2), m.weight:size(1)))
+      elseif m.__typename == 'nn.TemporalConvolution' then
+         m:reset(method(m.weight:size(2), m.weight:size(1)))            
+      end
+
+      if m.bias then
+         m.bias:zero()
+      end
+   end
+   return net
+end
+
+
+return w_init
\ No newline at end of file