|
a |
|
b/util/auRoc.lua |
|
|
1 |
--[[ An auRoc class |
|
|
2 |
-- Same form as the torch optim.ConfusionMatrix class |
|
|
3 |
-- output is assumed to be a ranking (e.g. probability value) |
|
|
4 |
-- label is assumed to be 1 or -1 |
|
|
5 |
Example: |
|
|
6 |
auRoc = auRoc.new() -- new matrix |
|
|
7 |
conf:zero() -- reset matrix |
|
|
8 |
for i = 1,N do |
|
|
9 |
conf:add( output, label ) -- accumulate errors |
|
|
10 |
end |
|
|
11 |
print(auRoc:calculateAuc()) |
|
|
12 |
]] |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
local auRoc = torch.class("auRoc") |
|
|
16 |
|
|
|
17 |
function auRoc:__init() |
|
|
18 |
self.target = {} |
|
|
19 |
self.pred = {} |
|
|
20 |
self.roc = 0 |
|
|
21 |
self.auc = 0 |
|
|
22 |
end |
|
|
23 |
|
|
|
24 |
function auRoc:add(prediction, target) |
|
|
25 |
if target == 2 or target == 0 or target == -1 then |
|
|
26 |
target = -1 |
|
|
27 |
elseif target == 1 then |
|
|
28 |
target = 1 |
|
|
29 |
else |
|
|
30 |
print('Incorrect target for auRoc:add(). Exiting') |
|
|
31 |
os.exit() |
|
|
32 |
end |
|
|
33 |
table.insert(self.pred,prediction) |
|
|
34 |
table.insert(self.target,target) |
|
|
35 |
end |
|
|
36 |
|
|
|
37 |
|
|
|
38 |
function auRoc:zero() |
|
|
39 |
self.target = {} |
|
|
40 |
self.pred = {} |
|
|
41 |
self.roc = 0 |
|
|
42 |
self.auc = 0 |
|
|
43 |
end |
|
|
44 |
|
|
|
45 |
|
|
|
46 |
local function tableToTensor(table) |
|
|
47 |
local tensor = torch.Tensor(#table) |
|
|
48 |
for i = 1,#table do |
|
|
49 |
tensor[i] = table[i] |
|
|
50 |
end |
|
|
51 |
return tensor |
|
|
52 |
end |
|
|
53 |
|
|
|
54 |
|
|
|
55 |
local function get_rates(responses, labels) |
|
|
56 |
torch.setdefaulttensortype('torch.FloatTensor') |
|
|
57 |
|
|
|
58 |
responses = torch.Tensor(responses:size()):copy(responses) |
|
|
59 |
labels = torch.Tensor(labels:size()):copy(labels) |
|
|
60 |
|
|
|
61 |
-- assertions about the data format expected |
|
|
62 |
assert(responses:size():size() == 1, "responses should be a 1D vector") |
|
|
63 |
assert(labels:size():size() == 1 , "labels should be a 1D vector") |
|
|
64 |
|
|
|
65 |
-- assuming labels {-1, 1} |
|
|
66 |
local npositives = torch.sum(torch.eq(labels, 1)) |
|
|
67 |
local nnegatives = torch.sum(torch.eq(labels, -1)) |
|
|
68 |
local nsamples = npositives + nnegatives |
|
|
69 |
|
|
|
70 |
assert(nsamples == responses:size()[1], "labels should contain only -1 or 1 values") |
|
|
71 |
|
|
|
72 |
-- sort by response value |
|
|
73 |
local responses_sorted, indexes_sorted = torch.sort(responses,1,true) |
|
|
74 |
local labels_sorted = labels:index(1, indexes_sorted) |
|
|
75 |
|
|
|
76 |
|
|
|
77 |
local found_positives = 0 |
|
|
78 |
local found_negatives = 0 |
|
|
79 |
|
|
|
80 |
local tpr = {0} -- true pos rate |
|
|
81 |
local fpr = {0} -- false pos rate |
|
|
82 |
|
|
|
83 |
for i = 1,nsamples-1 do |
|
|
84 |
if labels_sorted[i] == -1 then |
|
|
85 |
found_negatives = found_negatives + 1 |
|
|
86 |
else |
|
|
87 |
found_positives = found_positives + 1 |
|
|
88 |
end |
|
|
89 |
|
|
|
90 |
table.insert(tpr, found_positives/npositives) |
|
|
91 |
table.insert(fpr, found_negatives/nnegatives) |
|
|
92 |
end |
|
|
93 |
|
|
|
94 |
table.insert(tpr, 1.0) |
|
|
95 |
table.insert(fpr, 1.0) |
|
|
96 |
|
|
|
97 |
|
|
|
98 |
return tpr, fpr |
|
|
99 |
end |
|
|
100 |
|
|
|
101 |
local function find_auc(tpr,fpr) |
|
|
102 |
local area = 0.0 |
|
|
103 |
for i = 2,#tpr do |
|
|
104 |
local xdiff = fpr[i] - fpr[i-1] |
|
|
105 |
local ydiff = tpr[i] - tpr[i-1] |
|
|
106 |
area = area + (xdiff * tpr[i]) |
|
|
107 |
end |
|
|
108 |
return area |
|
|
109 |
end |
|
|
110 |
|
|
|
111 |
|
|
|
112 |
function auRoc:calculateAuc() |
|
|
113 |
local aucPredTens = tableToTensor(self.pred) |
|
|
114 |
local aucTargetTens = tableToTensor(self.target) |
|
|
115 |
|
|
|
116 |
local tpr = nil |
|
|
117 |
local fpr = nil |
|
|
118 |
|
|
|
119 |
tpr,fpr = get_rates(aucPredTens,aucTargetTens) |
|
|
120 |
self.auc = find_auc(tpr,fpr) |
|
|
121 |
|
|
|
122 |
return self.auc |
|
|
123 |
end |