Diff of /code/trainSimple.lua [000000] .. [b758a2]

Switch to unified view

a b/code/trainSimple.lua
1
require 'torch';
2
require 'nn';
3
require 'optim';
4
5
cudaFlag = true
6
7
if cudaFlag then
8
    require 'cutorch';
9
    require 'cunn';
10
end
11
12
--torch.setnumthreads(4)
13
14
-- parameters
15
local batchSize = 100
16
local learningRate = 0.01
17
local weightDecay = 0.000
18
local maxIteration = 10
19
20
local folder  = '/home/andrew/mitosis/data/mitosis-train-old/'
21
22
dofile("data.lua")
23
local classes, classList, imagePaths = getImagePaths(folder)
24
25
-- split dataset into training and test sets
26
local trainClassList = {}
27
local testClassList = {}
28
--[
29
trainClassList[1] = classList[1][{{1,math.ceil(classList[1]:size(1)/2)}}]
30
trainClassList[2] = classList[2][{{1,math.ceil(classList[2]:size(1)/2)}}]
31
testClassList[1] = classList[1][{{math.ceil(classList[1]:size(1)/2)+1,classList[1]:size(1)}}]
32
testClassList[2] = classList[2][{{math.ceil(classList[2]:size(1)/2)+1,classList[2]:size(1)}}]
33
--]]
34
--[[
35
trainClassList[1] = classList[1][{{1,math.ceil(classList[1]:size(1)/20)}}]
36
trainClassList[2] = classList[2][{{1,math.ceil(classList[2]:size(1)/20)}}]
37
testClassList[1] = classList[1][{{math.ceil(classList[1]:size(1)/20)+1,math.ceil(classList[1]:size(1)/10)}}]
38
testClassList[2] = classList[2][{{math.ceil(classList[2]:size(1)/20)+1,math.ceil(classList[1]:size(1)/10)}}]
39
--]]
40
--trainClassList = classList
41
42
local classRatio = trainClassList[2]:size(1)/trainClassList[1]:size(1)
43
44
-- define the model
45
dofile("/home/andrew/mitosis/models/model.lua")
46
47
-- load the pre-trained model
48
torch.load('/home/andrew/mitosis/data/nets/model-pretrained.t7')
49
50
-- train the network
51
dofile("train.lua")
52
train(net, criterion, classes, trainClassList, imagePaths, batchSize, learningRate, weightDecay, maxIteration, classRatio, false)
53
54
-- save the model
55
torch.save('/home/andrew/mitosis/data/nets/testNet.t7', net)
56
--net = torch.load('/home/andrew/mitosis/nets/testNet.t7')
57
58
-- test the network
59
dofile("test.lua")
60
--test(net, classes, testClassList, imagePaths, batchSize)