|
a |
|
b/bin/metrics/roc.lua |
|
|
1 |
local roc = {} |
|
|
2 |
|
|
|
3 |
|
|
|
4 |
function roc.splits(responses, labels, neglabel, poslabel) |
|
|
5 |
|
|
|
6 |
local nsamples = responses:size()[1] |
|
|
7 |
|
|
|
8 |
-- sort by response value |
|
|
9 |
local responses_sorted, indexes_sorted = torch.sort(responses) |
|
|
10 |
local labels_sorted = labels:index(1, indexes_sorted) |
|
|
11 |
|
|
|
12 |
local true_negatives = 0 |
|
|
13 |
local false_negatives = 0 |
|
|
14 |
|
|
|
15 |
local epsilon = 0.01 |
|
|
16 |
|
|
|
17 |
-- a threshold divides the data as follows: |
|
|
18 |
-- if response[i] <= threshold: classify sample[i] as belonging to the negative class |
|
|
19 |
-- if response[i] > threshold: classify sample[i] as belonging to the positive class |
|
|
20 |
|
|
|
21 |
-- Base case, where the threshold is a bit lower than the minimum response for any sample. |
|
|
22 |
-- Here, all samples are classified as belonging to the positive class, therefore we have |
|
|
23 |
-- zero true negatives and zero false negatives (all are either true positives or false positives) |
|
|
24 |
local threshold = responses[1]-epsilon |
|
|
25 |
local splits = {} |
|
|
26 |
|
|
|
27 |
-- we are going to start moving through the samples and increasing the threshold |
|
|
28 |
local i = 0 |
|
|
29 |
while i<=nsamples do |
|
|
30 |
-- if a set of samples have *exactly* this response, we can't distinguish between them. |
|
|
31 |
-- Therefore, all samples with that response will be classified as negatives (since response == threshold) |
|
|
32 |
-- and depending on their true label, we need to increase either the TN or the FN counters |
|
|
33 |
while i+1 <= nsamples and responses_sorted[i+1] == threshold do |
|
|
34 |
if labels_sorted[i+1] == neglabel then |
|
|
35 |
true_negatives = true_negatives + 1 |
|
|
36 |
else |
|
|
37 |
false_negatives = false_negatives + 1 |
|
|
38 |
end |
|
|
39 |
i = i+1 |
|
|
40 |
end |
|
|
41 |
-- now that we dealt with the "degenerate" situation of having multiple samples with exactly the same response |
|
|
42 |
-- coinciding with the current threshold, lets store this threshold and the current TN and FN |
|
|
43 |
splits[#splits+1] = {threshold = threshold, true_negatives = true_negatives, false_negatives = false_negatives} |
|
|
44 |
|
|
|
45 |
-- We can now move on |
|
|
46 |
i = i + 1 |
|
|
47 |
if i<=nsamples and labels_sorted[i] == poslabel then |
|
|
48 |
false_negatives = false_negatives + 1 |
|
|
49 |
else |
|
|
50 |
true_negatives = true_negatives + 1 |
|
|
51 |
-- while we see only negative examples we can keep increasing the threshold, because there is no point in picking |
|
|
52 |
-- a threshold if we can pick a higher one that will increase the amount of true negatives (therefore decreasing the |
|
|
53 |
-- false positives), without causing any additional false negative. |
|
|
54 |
while i+1 <= nsamples and labels_sorted[i+1] == neglabel do |
|
|
55 |
true_negatives = true_negatives + 1 |
|
|
56 |
i = i+1 |
|
|
57 |
end |
|
|
58 |
end |
|
|
59 |
|
|
|
60 |
-- new "interesting" threshold |
|
|
61 |
if i<=nsamples then |
|
|
62 |
threshold = responses_sorted[i] |
|
|
63 |
end |
|
|
64 |
end |
|
|
65 |
|
|
|
66 |
-- we are now done, lets return the table with all the tuples of {thresholds, true negatives, false negatives} |
|
|
67 |
-- {{threshold_1, TN_1, FN_1}, ... , {threshold_k, TN_k, FP_k}} |
|
|
68 |
|
|
|
69 |
return splits |
|
|
70 |
end |
|
|
71 |
|
|
|
72 |
|
|
|
73 |
|
|
|
74 |
function roc.points(responses, labels, neglabel, poslabel) |
|
|
75 |
|
|
|
76 |
-- default values for arguments |
|
|
77 |
poslabel = poslabel or 1 |
|
|
78 |
neglabel = neglabel or -1 |
|
|
79 |
|
|
|
80 |
-- assertions about the data format expected |
|
|
81 |
assert(responses:size():size() == 1, "responses should be a 1D vector") |
|
|
82 |
assert(labels:size():size() == 1 , "labels should be a 1D vector") |
|
|
83 |
|
|
|
84 |
-- avoid degenerate class definitions |
|
|
85 |
assert(poslabel ~= neglabel, "positive and negative class can't have the same label") |
|
|
86 |
|
|
|
87 |
-- assuming labels { neglabel, poslabel } |
|
|
88 |
local npositives = torch.sum(torch.eq(labels, poslabel)) |
|
|
89 |
local nnegatives = torch.sum(torch.eq(labels, neglabel)) |
|
|
90 |
local nsamples = npositives + nnegatives |
|
|
91 |
|
|
|
92 |
assert(nsamples == responses:size()[1], "labels should contain only " .. neglabel .. " or " .. poslabel .. " values") |
|
|
93 |
|
|
|
94 |
local splits = roc.splits(responses, labels, neglabel, poslabel) |
|
|
95 |
|
|
|
96 |
local roc_points = torch.Tensor(#splits, 2) |
|
|
97 |
local thresholds = torch.Tensor(#splits, 1) |
|
|
98 |
|
|
|
99 |
for i=1,#splits do |
|
|
100 |
local false_positives = nnegatives - splits[i].true_negatives |
|
|
101 |
local true_positives = npositives - splits[i].false_negatives |
|
|
102 |
local false_positive_rate = 1.0*false_positives/nnegatives |
|
|
103 |
local true_positive_rate = 1.0*true_positives/npositives |
|
|
104 |
roc_points[#splits - i + 1][1] = false_positive_rate |
|
|
105 |
roc_points[#splits - i + 1][2] = true_positive_rate |
|
|
106 |
thresholds[#splits - i + 1][1] = splits[i].threshold |
|
|
107 |
end |
|
|
108 |
|
|
|
109 |
return roc_points, thresholds |
|
|
110 |
end |
|
|
111 |
|
|
|
112 |
function roc.area(roc_points) |
|
|
113 |
|
|
|
114 |
local area = 0.0 |
|
|
115 |
local npoints = roc_points:size()[1] |
|
|
116 |
|
|
|
117 |
for i=1,npoints-1 do |
|
|
118 |
local width = (roc_points[i+1][1] - roc_points[i][1]) |
|
|
119 |
local avg_height = (roc_points[i][2]+roc_points[i+1][2])/2.0 |
|
|
120 |
area = area + width*avg_height |
|
|
121 |
end |
|
|
122 |
|
|
|
123 |
return area |
|
|
124 |
end |
|
|
125 |
|
|
|
126 |
|
|
|
127 |
return roc |