Diff of /code/expand.lua [000000] .. [b758a2]

Switch to unified view

a b/code/expand.lua
1
function expand(net)
2
    convCount = 0
3
    poolCount = 0
4
    for i=1,net:size() do
5
        if torch.typename(net:get(i)) == 'nn.SpatialConvolution' then
6
            convCount = convCount + 1
7
            nInputPlane = net:get(i).nInputPlane
8
            nOutputPlane = net:get(i).nOutputPlane
9
            kW = net:get(i).kW
10
            kH = net:get(i).kH
11
            dilationW = 2^(convCount-1)
12
            dilationH = 2^(convCount-1)
13
            net:insert(nn.SpatialDilatedConvolution(nInputPlane,nOutputPlane,kW,kH,1,1,0,0,dilationW,dilationH), i+1)
14
            net:get(i+1).weight = net:get(i).weight
15
            net:get(i+1).bias = net:get(i).bias
16
            net:remove(i)
17
        elseif torch.typename(net:get(i)) == 'nn.SpatialMaxPooling' then
18
            poolCount = poolCount + 1
19
            kW = net:get(i).kW
20
            kH = net:get(i).kH
21
            dilationW = 2^(poolCount-1)
22
            dilationH = 2^(poolCount-1)
23
            net:insert(nn.SpatialDilatedMaxPooling(kW,kH,1,1,0,0,dilationW,dilationH), i+1)
24
            net:get(i+1).weight = net:get(i).weight
25
            net:get(i+1).bias = net:get(i).bias
26
            net:remove(i)
27
        elseif torch.typename(net:get(i)) == 'nn.View' then
28
            net:insert(nn.Identity(),i+1)
29
            net:remove(i)
30
        elseif torch.typename(net:get(i)) == 'nn.Linear' then
31
            convCount = convCount + 1
32
            j = i - 1
33
            while true do
34
                if torch.typename(net:get(j)) == 'nn.SpatialDilatedConvolution' then
35
                    break
36
                end
37
                j = j - 1
38
            end
39
            local nInputPlane = net:get(j).nOutputPlane
40
41
            local outputSize = net:get(i).weight:size(1)
42
            local inputSize = net:get(i).weight:size(2)
43
44
            local nOutputPlane = outputSize
45
            kW = torch.sqrt(inputSize/nInputPlane)
46
            kH = kW
47
            dilationW = 2^(convCount-1)
48
            dilationH = 2^(convCount-1)
49
50
            net:insert(nn.SpatialDilatedConvolution(nInputPlane,nOutputPlane,kW,kH,1,1,0,0,dilationW,dilationH), i+1)
51
            net:get(i+1).weight = net:get(i).weight:resize(nOutputPlane,nInputPlane,kH,kW)
52
            net:get(i+1).bias = net:get(i).bias
53
            net:remove(i)
54
        elseif torch.typename(net:get(i)) == 'nn.LogSoftMax' then
55
            net:insert(nn.SpatialLogSoftMax(), i+1)
56
            net:remove(i)
57
        end
58
    end
59
    for i=net:size(),1,-1 do
60
        if torch.typename(net:get(i)) == 'nn.Identity' then
61
            net:remove(i)
62
        end
63
    end
64
65
    return net
66
end