Diff of /util/auRoc.lua [000000] .. [6d0c6b]

Switch to unified view

a b/util/auRoc.lua
1
--[[ An auRoc class
2
-- Same form as the torch optim.ConfusionMatrix class
3
-- output is assumed to be a ranking (e.g. probability value)
4
-- label is assumed to be 1 or -1
5
Example:
6
    auRoc = auRoc.new()   -- new matrix
7
    conf:zero()                                              -- reset matrix
8
    for i = 1,N do
9
        conf:add( output, label )         -- accumulate errors
10
    end
11
    print(auRoc:calculateAuc())
12
]]
13
14
15
local auRoc = torch.class("auRoc")
16
17
function auRoc:__init()
18
   self.target = {}
19
   self.pred = {}
20
   self.roc = 0
21
   self.auc = 0
22
end
23
24
function auRoc:add(prediction, target)
25
  if target == 2 or target == 0 or target == -1 then
26
    target = -1
27
  elseif target == 1 then
28
    target = 1
29
  else
30
    print('Incorrect target for auRoc:add(). Exiting')
31
    os.exit()
32
  end
33
  table.insert(self.pred,prediction)
34
  table.insert(self.target,target)
35
end
36
37
38
function auRoc:zero()
39
   self.target = {}
40
   self.pred = {}
41
   self.roc = 0
42
   self.auc = 0
43
end
44
45
46
local function tableToTensor(table)
47
  local tensor = torch.Tensor(#table)
48
  for i = 1,#table do
49
    tensor[i] = table[i]
50
  end
51
  return tensor
52
end
53
54
55
local function get_rates(responses, labels)
56
  torch.setdefaulttensortype('torch.FloatTensor')
57
58
  responses = torch.Tensor(responses:size()):copy(responses)
59
  labels = torch.Tensor(labels:size()):copy(labels)
60
61
   -- assertions about the data format expected
62
   assert(responses:size():size() == 1, "responses should be a 1D vector")
63
   assert(labels:size():size() == 1 , "labels should be a 1D vector")
64
65
   -- assuming labels {-1, 1}
66
   local npositives = torch.sum(torch.eq(labels,  1))
67
   local nnegatives = torch.sum(torch.eq(labels, -1))
68
   local nsamples = npositives + nnegatives
69
70
   assert(nsamples == responses:size()[1], "labels should contain only -1 or 1 values")
71
72
   -- sort by response value
73
   local responses_sorted, indexes_sorted = torch.sort(responses,1,true)
74
   local labels_sorted = labels:index(1, indexes_sorted)
75
76
77
   local found_positives = 0
78
   local found_negatives = 0
79
80
   local tpr = {0} -- true pos rate
81
   local fpr = {0} -- false pos rate
82
83
   for i = 1,nsamples-1 do
84
      if labels_sorted[i] == -1 then
85
         found_negatives = found_negatives + 1
86
      else
87
         found_positives = found_positives + 1
88
      end
89
90
      table.insert(tpr, found_positives/npositives)
91
      table.insert(fpr, found_negatives/nnegatives)
92
   end
93
94
   table.insert(tpr, 1.0)
95
   table.insert(fpr, 1.0)
96
97
98
   return tpr, fpr
99
end
100
101
local function find_auc(tpr,fpr)
102
   local area = 0.0
103
   for i = 2,#tpr do
104
      local xdiff = fpr[i] - fpr[i-1]
105
      local ydiff = tpr[i] - tpr[i-1]
106
      area = area + (xdiff * tpr[i])
107
   end
108
   return area
109
end
110
111
112
function auRoc:calculateAuc()
113
  local aucPredTens = tableToTensor(self.pred)
114
  local aucTargetTens = tableToTensor(self.target)
115
116
  local tpr = nil
117
  local fpr = nil
118
119
  tpr,fpr = get_rates(aucPredTens,aucTargetTens)
120
  self.auc = find_auc(tpr,fpr)
121
122
  return self.auc
123
end