Diff of /code/autoencoder.lua [000000] .. [b758a2]

Switch to unified view

a b/code/autoencoder.lua
1
function convnet2autoencoder(inNet)
2
    outNet = inNet:clone()
3
4
    for i=outNet:size(),1,-1 do
5
        if torch.typename(outNet:get(i)) == 'nn.View' then
6
            outNet:remove(i)
7
        elseif torch.typename(net:get(i)) == 'nn.Linear' then
8
            outNet:remove(i)
9
        elseif torch.typename(net:get(i)) == 'nn.LogSoftMax' then
10
            outNet:remove(i)
11
        end
12
    end
13
14
    for i=outNet:size(),1,-1 do
15
        if torch.typename(outNet:get(i)) == 'nn.SpatialMaxPooling' then
16
            local pool_layer = nn.SpatialMaxPooling(2,2,2,2)
17
            outNet:insert(pool_layer,i+1)
18
            outNet:remove(i)
19
            outNet:add(nn.SpatialMaxUnpooling(pool_layer))
20
        elseif torch.typename(outNet:get(i)) == 'nn.SpatialConvolution' then
21
            nInputPlane = outNet:get(i).nOutputPlane
22
            nOutputPlane = outNet:get(i).nInputPlane
23
            kW = outNet:get(i).kW
24
            kH = outNet:get(i).kH
25
            outNet:add(nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH))
26
            outNet:add(nn.ReLU())
27
        end
28
    end
29
30
    return outNet
31
end
32
33
function autoencoder2convnet(net1, net2)
34
    -- get indices for convolution layers for net1
35
    convList1 = {}
36
    j = 1
37
    for i=1,net1:size() do
38
        if torch.typename(net1:get(i)) == 'nn.SpatialConvolution' then
39
            convList1[j] = i
40
            j = j + 1
41
        end
42
    end
43
44
    -- get indices for convolution layers for net2
45
    convList2 = {}
46
    j=1
47
    for i=1,net2:size() do
48
        if torch.typename(net2:get(i)) == 'nn.SpatialConvolution' then
49
            convList2[j] = i
50
            j = j + 1
51
        end
52
    end
53
54
    -- copy parameters from net1 to net2
55
    for i=1,#convList1 do
56
        net2:get(convList2[i]).weight = net1:get(convList1[i]).weight
57
        net2:get(convList2[i]).bias = net1:get(convList1[i]).bias
58
    end
59
60
    return net2
61
end