[868c5d]: / bin / metrics / roc.lua

Download this file

128 lines (96 with data), 4.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
-- Module table: the functions defined below are attached to it, and it is
-- returned at the end of the file.
local roc = {}
--- Enumerate the "interesting" decision thresholds for a set of scored samples.
--
-- A threshold divides the data as follows:
--   response[i] <= threshold : sample i is classified as negative
--   response[i] >  threshold : sample i is classified as positive
--
-- For each candidate threshold the number of true negatives and false
-- negatives is recorded. Thresholds are taken from the sorted responses, so
-- the returned list is ordered by increasing threshold.
--
-- @param responses tensor of classifier scores (roc.points asserts it is 1D
--                  before calling this function)
-- @param labels    tensor of ground-truth labels, indexed in parallel with
--                  responses
-- @param neglabel  label value of the negative class (defaults to -1)
-- @param poslabel  label value of the positive class (defaults to 1)
-- @return array of tables
--         { threshold = t, true_negatives = TN, false_negatives = FN }
function roc.splits(responses, labels, neglabel, poslabel)
  -- Same defaults as roc.points, so that a direct call without explicit
  -- labels does not silently compare every label against nil.
  poslabel = poslabel or 1
  neglabel = neglabel or -1

  local nsamples = responses:size()[1]

  -- sort by response value and reorder the labels the same way
  local responses_sorted, indexes_sorted = torch.sort(responses)
  local labels_sorted = labels:index(1, indexes_sorted)

  local true_negatives = 0
  local false_negatives = 0
  local epsilon = 0.01

  -- Base case: the threshold is a bit lower than the minimum response of any
  -- sample. Here all samples are classified as positive, therefore we have
  -- zero true negatives and zero false negatives.
  -- The minimum is the first element of the SORTED responses; using the first
  -- element of the unsorted input would give a wrong base threshold whenever
  -- the first sample is not the smallest one.
  local threshold = responses_sorted[1] - epsilon

  local splits = {}

  -- walk through the sorted samples, raising the threshold as we go
  local i = 0
  while i <= nsamples do
    -- Samples whose response is *exactly* the current threshold cannot be
    -- told apart: all of them are classified as negatives (response ==
    -- threshold), so depending on their true label either the TN or the FN
    -- counter grows.
    while i + 1 <= nsamples and responses_sorted[i + 1] == threshold do
      if labels_sorted[i + 1] == neglabel then
        true_negatives = true_negatives + 1
      else
        false_negatives = false_negatives + 1
      end
      i = i + 1
    end

    -- ties are resolved: record this threshold with the current TN and FN
    splits[#splits + 1] = {
      threshold = threshold,
      true_negatives = true_negatives,
      false_negatives = false_negatives,
    }

    -- move on to the next sample
    i = i + 1
    if i <= nsamples and labels_sorted[i] == poslabel then
      false_negatives = false_negatives + 1
    else
      -- Note: when i has run past nsamples this branch still bumps the
      -- counter once, but the outer loop exits right after, so that value is
      -- never recorded.
      true_negatives = true_negatives + 1
      -- While only negative examples follow we can keep raising the
      -- threshold: there is no point in picking a threshold if a higher one
      -- increases the true negatives (hence decreases the false positives)
      -- without causing any additional false negative.
      while i + 1 <= nsamples and labels_sorted[i + 1] == neglabel do
        true_negatives = true_negatives + 1
        i = i + 1
      end
    end

    -- new "interesting" threshold
    if i <= nsamples then
      threshold = responses_sorted[i]
    end
  end

  -- done: return the list of {threshold, true negatives, false negatives}
  -- {{threshold_1, TN_1, FN_1}, ... , {threshold_k, TN_k, FN_k}}
  return splits
end
--- Compute the points of the ROC curve for a set of scored samples.
--
-- @param responses 1D tensor of classifier scores
-- @param labels    1D tensor of ground-truth labels; every entry must be
--                  either neglabel or poslabel
-- @param neglabel  label value of the negative class (defaults to -1)
-- @param poslabel  label value of the positive class (defaults to 1)
-- @return roc_points tensor with one row per split:
--                    column 1 = false positive rate,
--                    column 2 = true positive rate
-- @return thresholds single-column tensor holding the threshold of each row
-- Rows are stored in the reverse of the order returned by roc.splits.
function roc.points(responses, labels, neglabel, poslabel)
  -- fall back to the default +1 / -1 labelling when labels are not given
  neglabel = neglabel or -1
  poslabel = poslabel or 1

  -- both inputs must be plain vectors
  assert(responses:size():size() == 1, "responses should be a 1D vector")
  assert(labels:size():size() == 1 , "labels should be a 1D vector")

  -- the two classes must be distinguishable
  assert(poslabel ~= neglabel, "positive and negative class can't have the same label")

  -- count the members of each class; every label has to belong to one of them
  local npositives = torch.sum(torch.eq(labels, poslabel))
  local nnegatives = torch.sum(torch.eq(labels, neglabel))
  local nlabelled = npositives + nnegatives
  assert(nlabelled == responses:size()[1], "labels should contain only " .. neglabel .. " or " .. poslabel .. " values")

  local splits = roc.splits(responses, labels, neglabel, poslabel)
  local nsplits = #splits

  local roc_points = torch.Tensor(nsplits, 2)
  local thresholds = torch.Tensor(nsplits, 1)

  for k = 1, nsplits do
    local split = splits[k]
    -- the output rows are filled back to front
    local row = nsplits - k + 1

    -- whatever is not a true/false negative is a false/true positive
    local fp = nnegatives - split.true_negatives
    local tp = npositives - split.false_negatives

    roc_points[row][1] = fp / nnegatives
    roc_points[row][2] = tp / npositives
    thresholds[row][1] = split.threshold
  end

  return roc_points, thresholds
end
--- Area under a curve given as a list of (x, y) points, using the
-- trapezoidal rule between each pair of consecutive rows.
--
-- @param roc_points 2D structure whose rows are {x, y}; its first dimension
--                   (roc_points:size()[1]) is the number of points
-- @return the accumulated area (0.0 when there are fewer than two points)
function roc.area(roc_points)
  local npoints = roc_points:size()[1]
  local total = 0.0

  for k = 2, npoints do
    local prev = roc_points[k - 1]
    local cur = roc_points[k]
    -- trapezoid: horizontal step times the mean of the two heights
    local dx = cur[1] - prev[1]
    local mean_y = (prev[2] + cur[2]) / 2.0
    total = total + dx * mean_y
  end

  return total
end
-- export the module table (splits, points, area)
return roc