Diff of /code/test.lua [000000] .. [b758a2]

Switch to unified view

a b/code/test.lua
1
function test(net, classes, testClassList, imagePaths, batchSize)
2
    net:evaluate()
3
4
    class_performance = {0, 0}
5
    correct = 0
6
    class_number = {0, 0}
7
8
    -- compute size of each batch
9
    batchSizes, numBatches, numSamples = getBatchSizes(classes, testClassList, batchSize)
10
11
    local sampleSum = {}
12
    for i=1,#classes do
13
        sampleSum[i] = 0
14
    end
15
16
    for i=1,numBatches do
17
        -- split testClassList into batches
18
        sampleList = {}
19
        for j=1,#classes do
20
            sampleList[j] = testClassList[j][{{sampleSum[j] + 1, sampleSum[j] + batchSizes[j][i]}}]
21
            sampleSum[j] = sampleSum[j] + batchSizes[j][i]
22
        end
23
24
        local testset = getSample(classes, sampleList, imagePaths)
25
        if cudaFlag then
26
            testset.data = testset.data:cuda()
27
            testset.label = testset.label:cuda()
28
        end
29
30
        for j=1,testset:size() do
31
            local groundtruth = testset.label[j]
32
            local prediction = net:forward(testset.data[j])
33
            local confidences, indices = torch.sort(prediction, true)  -- true means sort in descending order
34
            if groundtruth == indices[1] then
35
                class_performance[groundtruth] = class_performance[groundtruth] + 1
36
                correct = correct + 1
37
            end
38
            class_number[groundtruth] = class_number[groundtruth] + 1
39
        end
40
    end
41
42
    for i=1,#classes do
43
        print(classes[i], 100*class_performance[i]/class_number[i] .. ' %')
44
    end
45
    print(100*correct/numSamples .. ' %')
46
end