Diff of /bin/weight-init.lua [000000] .. [868c5d]

Switch to unified view

a b/bin/weight-init.lua
1
--
2
-- Different weight initialization methods
3
--
4
-- > model = require('weight-init')(model, 'heuristic')
5
--
6
require("nn")
7
8
9
-- "Efficient backprop"
10
-- Yann Lecun, 1998
11
local function w_init_heuristic(fan_in, fan_out)
12
   return math.sqrt(1/(3*fan_in))
13
end
14
15
16
-- "Understanding the difficulty of training deep feedforward neural networks"
17
-- Xavier Glorot, 2010
18
local function w_init_xavier(fan_in, fan_out)
19
   return math.sqrt(2/(fan_in + fan_out))
20
end
21
22
23
-- "Understanding the difficulty of training deep feedforward neural networks"
24
-- Xavier Glorot, 2010
25
local function w_init_xavier_caffe(fan_in, fan_out)
26
   return math.sqrt(1/fan_in)
27
end
28
29
30
-- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"
31
-- Kaiming He, 2015
32
local function w_init_kaiming(fan_in, fan_out)
33
   return math.sqrt(4/(fan_in + fan_out))
34
end
35
36
37
local function w_init(net, arg)
38
   -- choose initialization method
39
   local method = nil
40
   if     arg == 'heuristic'    then method = w_init_heuristic
41
   elseif arg == 'xavier'       then method = w_init_xavier
42
   elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe
43
   elseif arg == 'kaiming'      then method = w_init_kaiming
44
   else
45
      assert(false)
46
   end
47
48
   -- loop over all convolutional modules
49
   for i = 1, #net.modules do
50
      local m = net.modules[i]
51
      if m.__typename == 'nn.SpatialConvolution' then
52
         m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
53
      elseif m.__typename == 'nn.SpatialConvolutionMM' then
54
         m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
55
      elseif m.__typename == 'nn.LateralConvolution' then
56
         m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1))
57
      elseif m.__typename == 'nn.VerticalConvolution' then
58
         m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
59
      elseif m.__typename == 'nn.HorizontalConvolution' then
60
         m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
61
      elseif m.__typename == 'nn.Linear' then
62
         m:reset(method(m.weight:size(2), m.weight:size(1)))
63
      elseif m.__typename == 'nn.TemporalConvolution' then
64
         m:reset(method(m.weight:size(2), m.weight:size(1)))            
65
      end
66
67
      if m.bias then
68
         m.bias:zero()
69
      end
70
   end
71
   return net
72
end
73
74
75
return w_init