Diff of /code/test.lua [000000] .. [b758a2]

Switch to side-by-side view

--- a
+++ b/code/test.lua
@@ -0,0 +1,46 @@
+function test(net, classes, testClassList, imagePaths, batchSize)
+	net:evaluate()
+
+	class_performance = {0, 0}
+	correct = 0
+	class_number = {0, 0}
+
+	-- compute size of each batch
+	batchSizes, numBatches, numSamples = getBatchSizes(classes, testClassList, batchSize)
+
+	local sampleSum = {}
+	for i=1,#classes do
+		sampleSum[i] = 0
+	end
+
+	for i=1,numBatches do
+		-- split testClassList into batches
+		sampleList = {}
+		for j=1,#classes do
+			sampleList[j] = testClassList[j][{{sampleSum[j] + 1, sampleSum[j] + batchSizes[j][i]}}]
+			sampleSum[j] = sampleSum[j] + batchSizes[j][i]
+		end
+
+		local testset = getSample(classes, sampleList, imagePaths)
+		if cudaFlag then
+			testset.data = testset.data:cuda()
+			testset.label = testset.label:cuda()
+		end
+
+		for j=1,testset:size() do
+			local groundtruth = testset.label[j]
+			local prediction = net:forward(testset.data[j])
+			local confidences, indices = torch.sort(prediction, true)  -- true means sort in descending order
+			if groundtruth == indices[1] then
+				class_performance[groundtruth] = class_performance[groundtruth] + 1
+				correct = correct + 1
+			end
+			class_number[groundtruth] = class_number[groundtruth] + 1
+		end
+	end
+
+	for i=1,#classes do
+		print(classes[i], 100*class_performance[i]/class_number[i] .. ' %')
+	end
+	print(100*correct/numSamples .. ' %')
+end