code/train.lua

require 'torch'
require 'optim'   -- provides optim.sgd, used below

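-- train: runs mini-batch SGD (optim.sgd) over the images referenced by
-- imagePaths for maxIteration epochs.  classList holds one index list per
-- class; the lists are reshuffled on every call and sliced into per-class
-- batches.  Helper functions such as getBatchSizes and getSample, the script
-- randRotateMirror.lua and the global cudaFlag are expected to be provided by
-- the surrounding project.  The classRatio argument is accepted but not used
-- inside this function.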
function train(net, criterion, classes, classList, imagePaths, batchSize, learningRate, learningRateDecay, weightDecay, momentum, maxIteration, classRatio, augment)
  local c = os.clock()
  local t = os.time()

  dofile("randRotateMirror.lua")

  -- compute size of each batch
  local batchSizes, numBatches = getBatchSizes(classes, classList, batchSize)
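  -- (getBatchSizes is defined elsewhere; judging from its use below,
  --  batchSizes[j][i] is the number of class-j samples assigned to batch i)
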
  -- shuffle the images
  local temp = {}
  for i=1,#classes do
    local perm = torch.randperm(classList[i]:size(1))
    temp[i] = torch.LongTensor(classList[i]:size(1))
    for j=1,classList[i]:size(1) do
      temp[i][j] = classList[i][perm[j]]
    end
  end
  classList = temp
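  -- (equivalent shortcut, assuming classList[i] is a torch LongTensor:
  --  temp[i] = classList[i]:index(1, perm:long()))
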
  -- train
  print("# StochasticGradient: training")

  net:training()
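  -- (training mode enables the train-time behaviour of layers such as dropout
  --  and batch normalization; net:evaluate() switches them back for testing)
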
  local params, gradParams = net:getParameters()
  local optimState = {}
  optimState.learningRate = learningRate
  optimState.learningRateDecay = learningRateDecay
  optimState.weightDecay = weightDecay
  optimState.momentum = momentum
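  -- (params/gradParams are flat views over all learnable weights and their
  --  gradients; optim.sgd updates params in place and keeps its running state,
  --  such as the momentum buffer and evaluation counter, inside optimState)
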
  local epochClock, epochTime   -- timings of the previous epoch, printed inside the batch loop
  --while true do
  for epoch = 1, maxIteration do
    local c1 = os.clock()
    local t1 = os.time()

    local currentError = 0   -- currently unused

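    -- number of samples of each class consumed so far in this epoch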
    local sampleSum = {}
    for i=1,#classes do
      sampleSum[i] = 0
    end

    for i=1,numBatches do
      local t2 = os.time()
      local c2 = os.clock()

      -- split classList into batches
      local sampleList = {}
      for j=1,#classes do
        sampleList[j] = classList[j][{{sampleSum[j] + 1, sampleSum[j] + batchSizes[j][i]}}]
        sampleSum[j] = sampleSum[j] + batchSizes[j][i]
      end
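      -- (each batch therefore contains a fixed, precomputed number of samples
      --  from every class rather than a uniformly random draw)
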
      -- get dataset from sampleList
      local dataset = getSample(classes, sampleList, imagePaths)

      -- or get a random batch
      -- local dataset = getRandomSample(classes, batchSize, classList, imagePaths)

      -- augment the training set with random rotations and mirroring
      if augment then
        dataset = randRotateMirror(dataset)
      end

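      -- move the batch to the GPU when the caller has set the global cudaFlag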
      if cudaFlag then
        dataset.data = dataset.data:cuda()
        dataset.label = dataset.label:cuda()
      end

      local input = dataset.data
      local target = dataset.label

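      -- feval returns the batch loss and the gradient of the loss w.r.t. the
      -- flattened parameters; optim.sgd calls it with the current parameter
      -- tensor (the same storage as params above, so the shadowing is harmless)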
      local function feval(params)
        gradParams:zero()

        local outputs = net:forward(input)
        local loss = criterion:forward(outputs, target)
        local dloss_doutputs = criterion:backward(outputs, target)
        net:backward(input, dloss_doutputs)

        return loss, gradParams
      end
      local _, fs = optim.sgd(feval, params, optimState)
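      -- (optim.sgd returns the updated parameters and {f(x)}; fs[1] is the
      --  batch loss measured before the update)
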
      print('Epoch = ' .. epoch .. ' of ' .. maxIteration)
      print('Batch = ' .. i .. ' of ' .. numBatches)
      print('Error = ' .. fs[1])
      print('CPU batch time = ' .. os.clock()-c2 .. ' seconds')
      print('Actual batch time (rounded) = ' .. os.time()-t2 .. ' seconds')
      if epochClock then
        print('CPU epoch time = ' .. epochClock .. ' seconds')
        print('Actual epoch time (rounded) = ' .. epochTime .. ' seconds')
      end
      print('')
    end

    epochClock = os.clock()-c1
    epochTime = os.time()-t1
  end

  local totalClock = os.clock()-c
  local totalTime = os.time()-t
  print('Total CPU time = ' .. totalClock .. ' seconds')
  print('Total actual time (rounded) = ' .. totalTime .. ' seconds')
end
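
-- Hypothetical usage from a driver script (illustrative values and names only;
-- the real caller is not part of this file):
--
--   require 'nn'
--   require 'optim'
--   dofile('train.lua')
--
--   local net = buildModel()                 -- assumed helper, defined elsewhere
--   local criterion = nn.ClassNLLCriterion()
--   train(net, criterion, classes, classList, imagePaths,
--         128,    -- batchSize
--         0.01,   -- learningRate
--         1e-7,   -- learningRateDecay
--         5e-4,   -- weightDecay
--         0.9,    -- momentum
--         30,     -- maxIteration
--         nil,    -- classRatio (unused by train)
--         true)   -- augment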