Diff of /code/train.lua [000000] .. [b758a2]

Switch to unified view

a b/code/train.lua
1
--- Train `net` against `criterion` with class-balanced mini-batches via optim.sgd.
-- Shuffles each class's sample-index list, slices it into per-batch chunks
-- (sized by getBatchSizes), optionally augments each batch, and runs one SGD
-- step per batch for `maxIteration` epochs, printing timing/loss as it goes.
-- @param net               network to train (modified in place)
-- @param criterion         loss criterion (forward/backward interface)
-- @param classes           list of class names
-- @param classList         table of 1D LongTensors of sample indices, one per class
-- @param imagePaths        sample-index -> image path mapping consumed by getSample
-- @param batchSize         total samples per mini-batch
-- @param learningRate      SGD learning rate
-- @param learningRateDecay SGD learning-rate decay
-- @param weightDecay       L2 regularization coefficient
-- @param momentum          SGD momentum
-- @param maxIteration      number of epochs
-- @param classRatio        not referenced in this function -- NOTE(review): presumably used by a caller or an older revision; confirm before removing
-- @param augment           if true, apply random rotation/mirroring to each batch
function train(net, criterion, classes, classList, imagePaths, batchSize, learningRate, learningRateDecay, weightDecay, momentum, maxIteration, classRatio, augment)
    -- FIX: every working variable below was an accidental global; all are now local.
    local c = os.clock()
    local t = os.time()

    dofile("randRotateMirror.lua")

    -- compute the per-class size of each batch
    local batchSizes, numBatches = getBatchSizes(classes, classList, batchSize)

    -- shuffle the sample indices within each class
    local shuffled = {}
    for i = 1, #classes do
        local perm = torch.randperm(classList[i]:size(1)):long()
        -- index() gathers elements in permuted order; replaces the former element-wise copy loop
        shuffled[i] = classList[i]:index(1, perm)
    end
    classList = shuffled

    print("# StochasticGradient: training")

    net:training()

    local params, gradParams = net:getParameters()
    local optimState = {
        learningRate = learningRate,
        learningRateDecay = learningRateDecay,
        weightDecay = weightDecay,
        momentum = momentum,
    }

    -- timings of the previous epoch; nil during the first epoch so its report is skipped
    local epochClock, epochTime

    for epoch = 1, maxIteration do
        local c1 = os.clock()
        local t1 = os.time()

        -- samples of each class consumed so far this epoch
        local sampleSum = {}
        for i = 1, #classes do
            sampleSum[i] = 0
        end

        for i = 1, numBatches do
            local t2 = os.time()
            local c2 = os.clock()

            -- take the next contiguous slice of each class's shuffled index list
            local sampleList = {}
            for j = 1, #classes do
                sampleList[j] = classList[j][{{sampleSum[j] + 1, sampleSum[j] + batchSizes[j][i]}}]
                sampleSum[j] = sampleSum[j] + batchSizes[j][i]
            end

            -- load the images/labels for this batch
            local dataset = getSample(classes, sampleList, imagePaths)

            -- alternative: draw a random batch instead of a scheduled slice
            -- local dataset = getRandomSample(classes, batchSize, classList, imagePaths)

            -- augment the training set with random rotations and mirroring
            if augment then
                dataset = randRotateMirror(dataset)
            end

            if cudaFlag then
                dataset.data = dataset.data:cuda()
                dataset.label = dataset.label:cuda()
            end

            local input = dataset.data
            local target = dataset.label

            -- closure for optim.sgd.
            -- FIX: was a global function whose parameter shadowed the outer `params`;
            -- it never used that parameter (net's params ARE the flat tensor), so the
            -- argument is now explicitly ignored.
            local function feval(_)
                gradParams:zero()

                local outputs = net:forward(input)
                local loss = criterion:forward(outputs, target)
                local dloss_doutputs = criterion:backward(outputs, target)
                net:backward(input, dloss_doutputs)

                return loss, gradParams
            end
            local _, fs = optim.sgd(feval, params, optimState)

            print('Epoch = ' .. epoch .. ' of ' .. maxIteration)
            print('Batch = ' .. i .. ' of ' .. numBatches)
            print('Error = ' .. fs[1])
            print('CPU batch time = ' .. os.clock() - c2 .. ' seconds')
            print('Actual batch time (rounded) = ' .. os.time() - t2 .. ' seconds')
            if epochClock then
                print('CPU epoch time = ' .. epochClock .. ' seconds')
                print('Actual epoch time (rounded) = ' .. epochTime .. ' seconds')
            end
            print('')
        end

        epochClock = os.clock() - c1
        epochTime = os.time() - t1
    end

    local totalClock = os.clock() - c
    local totalTime = os.time() - t
    print('Total CPU time = ' .. totalClock .. ' seconds')
    print('Total actual time (rounded) ' .. totalTime .. ' seconds')
end