--- a +++ b/bin/weight-init.lua @@ -0,0 +1,75 @@ +-- +-- Different weight initialization methods +-- +-- > model = require('weight-init')(model, 'heuristic') +-- +require("nn") + + +-- "Efficient backprop" +-- Yann Lecun, 1998 +local function w_init_heuristic(fan_in, fan_out) + return math.sqrt(1/(3*fan_in)) +end + + +-- "Understanding the difficulty of training deep feedforward neural networks" +-- Xavier Glorot, 2010 +local function w_init_xavier(fan_in, fan_out) + return math.sqrt(2/(fan_in + fan_out)) +end + + +-- "Understanding the difficulty of training deep feedforward neural networks" +-- Xavier Glorot, 2010 +local function w_init_xavier_caffe(fan_in, fan_out) + return math.sqrt(1/fan_in) +end + + +-- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" +-- Kaiming He, 2015 +local function w_init_kaiming(fan_in, fan_out) + return math.sqrt(4/(fan_in + fan_out)) +end + + +local function w_init(net, arg) + -- choose initialization method + local method = nil + if arg == 'heuristic' then method = w_init_heuristic + elseif arg == 'xavier' then method = w_init_xavier + elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe + elseif arg == 'kaiming' then method = w_init_kaiming + else + assert(false) + end + + -- loop over all convolutional modules + for i = 1, #net.modules do + local m = net.modules[i] + if m.__typename == 'nn.SpatialConvolution' then + m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW)) + elseif m.__typename == 'nn.SpatialConvolutionMM' then + m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW)) + elseif m.__typename == 'nn.LateralConvolution' then + m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1)) + elseif m.__typename == 'nn.VerticalConvolution' then + m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW)) + elseif m.__typename == 'nn.HorizontalConvolution' then + m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW)) + elseif m.__typename == 'nn.Linear' then + m:reset(method(m.weight:size(2), m.weight:size(1))) + elseif m.__typename == 'nn.TemporalConvolution' then + m:reset(method(m.weight:size(2), m.weight:size(1))) + end + + if m.bias then + m.bias:zero() + end + end + return net +end + + +return w_init \ No newline at end of file