|
a |
|
b/bin/weight-init.lua |
|
|
1 |
-- |
|
|
2 |
-- Different weight initialization methods |
|
|
3 |
-- |
|
|
4 |
-- > model = require('weight-init')(model, 'heuristic') |
|
|
5 |
-- |
|
|
6 |
require("nn") |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
-- "Efficient backprop" |
|
|
10 |
-- Yann Lecun, 1998 |
|
|
11 |
local function w_init_heuristic(fan_in, fan_out) |
|
|
12 |
return math.sqrt(1/(3*fan_in)) |
|
|
13 |
end |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
-- "Understanding the difficulty of training deep feedforward neural networks" |
|
|
17 |
-- Xavier Glorot, 2010 |
|
|
18 |
local function w_init_xavier(fan_in, fan_out) |
|
|
19 |
return math.sqrt(2/(fan_in + fan_out)) |
|
|
20 |
end |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
-- "Understanding the difficulty of training deep feedforward neural networks" |
|
|
24 |
-- Xavier Glorot, 2010 |
|
|
25 |
local function w_init_xavier_caffe(fan_in, fan_out) |
|
|
26 |
return math.sqrt(1/fan_in) |
|
|
27 |
end |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
-- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" |
|
|
31 |
-- Kaiming He, 2015 |
|
|
32 |
local function w_init_kaiming(fan_in, fan_out) |
|
|
33 |
return math.sqrt(4/(fan_in + fan_out)) |
|
|
34 |
end |
|
|
35 |
|
|
|
36 |
|
|
|
37 |
local function w_init(net, arg) |
|
|
38 |
-- choose initialization method |
|
|
39 |
local method = nil |
|
|
40 |
if arg == 'heuristic' then method = w_init_heuristic |
|
|
41 |
elseif arg == 'xavier' then method = w_init_xavier |
|
|
42 |
elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe |
|
|
43 |
elseif arg == 'kaiming' then method = w_init_kaiming |
|
|
44 |
else |
|
|
45 |
assert(false) |
|
|
46 |
end |
|
|
47 |
|
|
|
48 |
-- loop over all convolutional modules |
|
|
49 |
for i = 1, #net.modules do |
|
|
50 |
local m = net.modules[i] |
|
|
51 |
if m.__typename == 'nn.SpatialConvolution' then |
|
|
52 |
m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW)) |
|
|
53 |
elseif m.__typename == 'nn.SpatialConvolutionMM' then |
|
|
54 |
m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW)) |
|
|
55 |
elseif m.__typename == 'nn.LateralConvolution' then |
|
|
56 |
m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1)) |
|
|
57 |
elseif m.__typename == 'nn.VerticalConvolution' then |
|
|
58 |
m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW)) |
|
|
59 |
elseif m.__typename == 'nn.HorizontalConvolution' then |
|
|
60 |
m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW)) |
|
|
61 |
elseif m.__typename == 'nn.Linear' then |
|
|
62 |
m:reset(method(m.weight:size(2), m.weight:size(1))) |
|
|
63 |
elseif m.__typename == 'nn.TemporalConvolution' then |
|
|
64 |
m:reset(method(m.weight:size(2), m.weight:size(1))) |
|
|
65 |
end |
|
|
66 |
|
|
|
67 |
if m.bias then |
|
|
68 |
m.bias:zero() |
|
|
69 |
end |
|
|
70 |
end |
|
|
71 |
return net |
|
|
72 |
end |
|
|
73 |
|
|
|
74 |
|
|
|
75 |
return w_init |