[868c5d]: / bin / metrics / roc_thresholds.lua

Download this file

202 lines (151 with data), 6.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
require 'torch'
-- Rounds num to idp digits after the decimal point (idp defaults to 0).
-- Halves are rounded towards positive infinity, because the value is shifted
-- by 0.5 and then floored.
function round(num, idp)
  local scale = 10 ^ (idp or 0)
  local shifted = num * scale + 0.5
  return math.floor(shifted) / scale
end
-- Returns the position of the first entry of tensorChosen equal to element.
--
-- Parameters:
--   tensorChosen: any value indexable with integers 1..size
--   element:      the value to look for
--   size:         number of leading entries to inspect
--
-- Returns the smallest index i in 1..size with tensorChosen[i] == element,
-- or nil when there is no such entry.
function getIndex(tensorChosen, element, size)
  -- Scan in increasing index order so that "first match" is well defined;
  -- no intermediate copy and no global state is needed.
  for i = 1, size do
    local value = tensorChosen[i]
    if value ~= nil and value == element then
      return i
    end
  end
  return nil
end
-- Sweeps decision thresholds over classifier responses and tallies the
-- confusion counts obtained at each "interesting" threshold.
--
-- A split threshold divides the data as follows:
--   response[i] <= split : sample i is classified as negative
--   response[i] >  split : sample i is classified as positive
--
-- Parameters:
--   responses: 1D torch tensor with one score per sample
--   labels:    1D torch tensor with one label per sample; only -1 and 1 are accepted
--
-- Returns a two-element table {splits, best_threshold}:
--   splits:         array of {split, TN, FN, TP, FP} tuples, in increasing split order
--   best_threshold: the split with the smallest FP + FN (the lowest such split on ties)
-- Rates are not stored; callers can derive them as TP / (TP + FN) and FP / (FP + TN).
--
-- Side effect: prints a one-line summary of the selected threshold to stdout.
function roc_thresholds(responses, labels)
  -- assertions about the data format expected
  assert(responses:size():size() == 1, "responses should be a 1D vector")
  assert(labels:size():size() == 1 , "labels should be a 1D vector")

  -- assuming labels {-1, 1}
  local npositives = torch.sum(torch.eq(labels, 1))
  local nnegatives = torch.sum(torch.eq(labels, -1))
  local nsamples = npositives + nnegatives
  assert(nsamples == responses:size()[1], "labels should contain only -1 or 1 values")
  assert(nsamples > 0, "responses should contain at least one sample")

  -- sort by response value and carry the labels along
  local responses_sorted, indexes_sorted = torch.sort(responses)
  local labels_sorted = labels:index(1, indexes_sorted)

  local true_negatives = 0
  local false_negatives = 0
  local epsilon = 0.01

  -- Base case: a threshold a bit lower than the minimum response. Every sample
  -- is classified as positive, so there are no true or false negatives yet.
  -- The minimum must come from the sorted tensor; the caller's tensor may be
  -- in any order.
  local split = responses_sorted[1] - epsilon

  local splitsExtended = {}
  local minError, minErrorIndex, globalThreshold

  -- Walk through the sorted samples, raising the threshold as we go. The
  -- counter i is the number of samples already classified as negative.
  local i = 0
  while i <= nsamples do
    -- Samples whose response is exactly the current split cannot be told apart
    -- from it: they are all classified as negative (response == split), and
    -- their true label decides whether TN or FN grows.
    while i + 1 <= nsamples and responses_sorted[i + 1] == split do
      if labels_sorted[i + 1] == -1 then
        true_negatives = true_negatives + 1
      else
        false_negatives = false_negatives + 1
      end
      i = i + 1
    end

    -- The counters are now consistent for this threshold: record the tuple.
    local true_positives = npositives - false_negatives
    local false_positives = nnegatives - true_negatives
    splitsExtended[#splitsExtended + 1] =
      {split, true_negatives, false_negatives, true_positives, false_positives}

    -- Track the threshold with the fewest misclassifications. The strict
    -- comparison keeps the earliest (lowest) split when several tie.
    local errors = false_positives + false_negatives
    if minError == nil or errors < minError then
      minError = errors
      minErrorIndex = #splitsExtended
      globalThreshold = split
    end

    -- Move on to the next sample, if there is one.
    i = i + 1
    if i <= nsamples then
      if labels_sorted[i] == 1 then
        false_negatives = false_negatives + 1
      else
        true_negatives = true_negatives + 1
        -- While we only see negative examples we can keep raising the
        -- threshold: a higher one gains true negatives (fewer false
        -- positives) without causing any additional false negative.
        while i + 1 <= nsamples and labels_sorted[i + 1] == -1 do
          true_negatives = true_negatives + 1
          i = i + 1
        end
      end
      -- new "interesting" split threshold
      split = responses_sorted[i]
    end
  end

  print("minError="..minError.."\tminErrorIndex="..minErrorIndex .."\tglobalThreshold ="..globalThreshold);

  return {splitsExtended, globalThreshold}
end
-- Applies the trapezoidal rule to the curve given by the two parallel arrays
-- and returns 1 plus the resulting signed sum.
--
-- Parameters:
--   newFPrateArray: x coordinates of the curve points
--   newTPrateArray: y coordinates of the curve points, same length
--
-- Each consecutive pair of points contributes (x[i] - x[i-1]) times the mean of
-- y[i] and y[i-1]; with fewer than two points the sum is empty and 1 is returned.
-- NOTE(review): the +1 offset presumably compensates for x values that are
-- listed in decreasing order (negative signed sum) -- confirm against callers.
function areaNew(newFPrateArray, newTPrateArray)
  local total = 0.0
  for idx = 2, #newFPrateArray do
    local width = newFPrateArray[idx] - newFPrateArray[idx - 1]
    local mean_height = (newTPrateArray[idx] + newTPrateArray[idx - 1]) / 2.0
    total = total + width * mean_height
  end
  return tonumber(1 + total)
end
-- #######################################
--[[
resp = torch.DoubleTensor { -0.9, -0.8, -0.8, -0.5, -0.1, 0.0, 0.2, 0.2, 0.51, 0.74, 0.89}
labels = torch.IntTensor { -1, -1, 1, -1, -1, 1, 1, -1, -1, 1, 1}
tp_rate = {}
fp_rate = {}
local roc_points = torch.Tensor((#resp)[1], 2)
splits = roc_thresholds(resp, labels)[1] -- roc_thresholds returns {splits, threshold}; the loop below needs only the splits
print("th.\tTNs\tFNs\tTPs\tFPs\t#\tTPrate\tFPrate");
for i = 1, #splits do
thresholds = splits[i][1]
tn = splits[i][2]
fn = splits[i][3]
tp = splits[i][4]
fp = splits[i][5]
tp_rate[i] = round(tp / (tp+fn),2)
fp_rate[i] = round(fp / (fp+tn),2)
roc_points[i][1] = tp_rate[i]
roc_points[i][2] = fp_rate[i]
print(thresholds.."\t"..tn.."\t"..fn.."\t"..tp.."\t"..fp.."\t#\t"..tp_rate[i].."\t"..fp_rate[i])
end
areaNew(tp_rate,fp_rate)
]]