Diff of /bin/metrics/roc.lua [000000] .. [868c5d]

Switch to side-by-side view

--- a
+++ b/bin/metrics/roc.lua
@@ -0,0 +1,127 @@
+-- Module table collecting the ROC helpers defined in this file
+-- (roc.splits, roc.points and roc.area); it is returned at the end of the file.
+local roc = {}
+
+
+--- Enumerate the "interesting" decision thresholds for a set of scored samples.
+--
+-- A threshold divides the data as follows:
+--   if response[i] <= threshold: sample[i] is classified as negative
+--   if response[i] >  threshold: sample[i] is classified as positive
+--
+-- @param responses 1D tensor of classifier scores (must hold at least one sample,
+--                  since the first sorted response is read unconditionally)
+-- @param labels    1D tensor with the true label of each sample, same order as responses
+-- @param neglabel  value used in labels for the negative class (default -1)
+-- @param poslabel  value used in labels for the positive class (default 1)
+-- @return array of tables {threshold = ..., true_negatives = ..., false_negatives = ...},
+--         in increasing threshold order; the first entry is the base case where
+--         every sample is classified as positive (zero TN and zero FN)
+function roc.splits(responses, labels, neglabel, poslabel)
+
+	-- same defaults as roc.points, so calling roc.splits directly behaves consistently
+	poslabel = poslabel or 1
+	neglabel = neglabel or -1
+
+	local nsamples = responses:size()[1]
+
+	-- sort by response value (torch.sort sorts in ascending order by default)
+	local responses_sorted, indexes_sorted = torch.sort(responses)
+	local labels_sorted = labels:index(1, indexes_sorted)
+
+	local true_negatives = 0
+	local false_negatives = 0
+
+	local epsilon = 0.01
+
+	-- Base case, where the threshold is a bit lower than the minimum response for any sample.
+	-- Here, all samples are classified as belonging to the positive class, therefore we have
+	-- zero true negatives and zero false negatives (all are either true positives or false positives).
+	-- The minimum is the first element of the *sorted* responses; the unsorted responses[1]
+	-- can be any sample and would give a base threshold that is not below every response.
+	local threshold = responses_sorted[1] - epsilon
+	local splits = {}
+
+	-- we are going to start moving through the samples and increasing the threshold
+	local i = 0
+	while i <= nsamples do
+		-- if a set of samples have *exactly* this response, we can't distinguish between them.
+		-- Therefore, all samples with that response will be classified as negatives (since response == threshold)
+		-- and depending on their true label, we need to increase either the TN or the FN counters
+		while i+1 <= nsamples and responses_sorted[i+1] == threshold do
+			if labels_sorted[i+1] == neglabel then
+				true_negatives = true_negatives + 1
+			else
+				false_negatives = false_negatives + 1
+			end
+			i = i + 1
+		end
+		-- now that we dealt with the "degenerate" situation of having multiple samples with exactly the same response
+		-- coinciding with the current threshold, lets store this threshold and the current TN and FN
+		splits[#splits+1] = {threshold = threshold, true_negatives = true_negatives, false_negatives = false_negatives}
+
+		-- We can now move on to the next sample, if there is one left.
+		-- (Once i exceeds nsamples nothing is counted any more and the outer loop terminates.)
+		i = i + 1
+		if i <= nsamples then
+			if labels_sorted[i] == poslabel then
+				false_negatives = false_negatives + 1
+			else
+				true_negatives = true_negatives + 1
+				-- while we see only negative examples we can keep increasing the threshold, because there is no point in picking
+				-- a threshold if we can pick a higher one that will increase the amount of true negatives (therefore decreasing the
+				-- false positives), without causing any additional false negative.
+				while i+1 <= nsamples and labels_sorted[i+1] == neglabel do
+					true_negatives = true_negatives + 1
+					i = i + 1
+				end
+			end
+
+			-- new "interesting" threshold
+			threshold = responses_sorted[i]
+		end
+	end
+
+	-- we are now done, lets return the table with all the tuples of {threshold, true negatives, false negatives}
+	-- {{threshold_1, TN_1, FN_1},   ... , {threshold_k, TN_k, FN_k}}
+
+	return splits
+end
+
+
+
+--- Compute the points of the ROC curve for a set of scored, labelled samples.
+--
+-- @param responses 1D tensor of classifier scores
+-- @param labels    1D tensor of true labels, one per response; may only contain
+--                  neglabel and poslabel, and must contain at least one of each
+-- @param neglabel  value used in labels for the negative class (default -1)
+-- @param poslabel  value used in labels for the positive class (default 1)
+-- @return roc_points  (#splits x 2) tensor; column 1 is the false positive rate,
+--                     column 2 the true positive rate. Rows are stored in the
+--                     reverse order of roc.splits, i.e. highest threshold first.
+-- @return thresholds  (#splits x 1) tensor with the threshold of each row of roc_points
+function roc.points(responses, labels, neglabel, poslabel)
+
+	-- default values for arguments
+	poslabel = poslabel or 1
+	neglabel = neglabel or -1
+
+	-- assertions about the data format expected
+	assert(responses:size():size() == 1, "responses should be a 1D vector")
+	assert(labels:size():size() == 1 , "labels should be a 1D vector")
+
+	-- every response needs exactly one label; without this check a longer labels
+	-- vector could slip through the class-count check below
+	assert(labels:size()[1] == responses:size()[1], "responses and labels should have the same number of elements")
+
+	-- avoid degenerate class definitions
+	assert(poslabel ~= neglabel, "positive and negative class can't have the same label")
+
+	-- assuming labels { neglabel, poslabel }
+	local npositives = torch.sum(torch.eq(labels, poslabel))
+	local nnegatives = torch.sum(torch.eq(labels, neglabel))
+	local nsamples = npositives + nnegatives
+
+	assert(nsamples == responses:size()[1], "labels should contain only " .. neglabel .. " or " .. poslabel .. " values")
+
+	-- both rates below divide by a class count, so each class must be represented;
+	-- otherwise the curve would silently be filled with NaN/inf values
+	assert(npositives > 0 and nnegatives > 0, "labels should contain at least one sample of each class")
+
+	local splits = roc.splits(responses, labels, neglabel, poslabel)
+
+	local roc_points = torch.Tensor(#splits, 2)
+	local thresholds = torch.Tensor(#splits, 1)
+
+	for i=1,#splits do
+		local false_positives = nnegatives - splits[i].true_negatives
+		local true_positives = npositives - splits[i].false_negatives
+		local false_positive_rate = 1.0*false_positives/nnegatives
+		local true_positive_rate = 1.0*true_positives/npositives
+
+		-- splits come in increasing threshold order (decreasing false positive rate);
+		-- store them reversed so that the false positive rate increases with the row index
+		local row = #splits - i + 1
+		roc_points[row][1] = false_positive_rate
+		roc_points[row][2] = true_positive_rate
+		thresholds[row][1] = splits[i].threshold
+	end
+
+	return roc_points, thresholds
+end
+
+--- Integrate a curve given as a sequence of (x, y) points using the trapezoid rule.
+--
+-- @param roc_points object whose size()[1] gives the number of points and whose
+--                   rows are indexed as roc_points[i][1] (x) and roc_points[i][2] (y),
+--                   e.g. the first value returned by roc.points
+-- @return the summed area of the trapezoids between consecutive points
+--         (0.0 when there are fewer than two points)
+function roc.area(roc_points)
+
+	local npoints = roc_points:size()[1]
+	local total = 0.0
+
+	-- one trapezoid per pair of consecutive points (left, right)
+	for right_index = 2, npoints do
+		local left = roc_points[right_index - 1]
+		local right = roc_points[right_index]
+		local base = (right[1] - left[1])
+		local mean_height = (left[2] + right[2])/2.0
+		total = total + base*mean_height
+	end
+
+	return total
+end
+
+   
+-- Hand the module table back to whoever loads this file.
+return roc