[868c5d]: / bin / metrics_ROC_AUC_computer.lua

Download this file

159 lines (109 with data), 5.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
-- Global accumulator. Every call to metrics_ROC_AUC_computer() that reaches the
-- ROC computation appends the second element returned by roc_thresholds() here
-- (by its name, presumably the threshold minimising FP+FN -- TODO confirm).
-- Declared without `local`, presumably so other scripts can read it; verify
-- callers before making it local.
globalMinFPplusFN_vector = {}
--- Check whether every element of a sequence rounds to 1.
-- Each element is passed through the project's round(x, 0) helper (defined
-- elsewhere), so the exact notion of "rounds to 1" depends on that helper.
-- An empty vector yields true, since no element contradicts the condition.
-- Scanning stops at the first element that does not round to 1.
-- @param vector sequence of numbers (indices 1..#vector are inspected)
-- @treturn boolean true if every element rounds to 1, false otherwise
function checkAllOnes(vector)
  for i = 1, #vector do
    -- The answer cannot change after the first mismatch, so return right away.
    if round(vector[i], 0) ~= 1 then
      return false
    end
  end
  return true
end
--- Check whether every element of a sequence rounds to 0.
-- Each element is passed through the project's round(x, 0) helper (defined
-- elsewhere), so the exact notion of "rounds to 0" depends on that helper.
-- An empty vector yields true, since no element contradicts the condition.
-- Scanning stops at the first element that does not round to 0.
-- @param vector sequence of numbers (indices 1..#vector are inspected)
-- @treturn boolean true if every element rounds to 0, false otherwise
function checkAllZeros(vector)
  for i = 1, #vector do
    -- The answer cannot change after the first mismatch, so return right away.
    if round(vector[i], 0) ~= 0 then
      return false
    end
  end
  return true
end
--- Percentage of predictions strictly below 0.5, rounded to 3 decimals.
-- Used when every ground-truth label is 0, so "predicted < 0.5" counts as a hit.
-- NOTE(review): an empty prediction vector divides by zero here -- TODO decide
-- what callers expect in that case.
-- @param predictions sequence of predicted scores
-- @treturn number success rate in percent
local function successRateAllNegative(predictions)
  local countZeros = 0
  for u = 1, #predictions do
    if predictions[u] < 0.5 then countZeros = countZeros + 1 end
  end
  return round(countZeros * 100 / #predictions, 3)
end

--- Convert the per-threshold confusion counts from roc_thresholds() to rates.
-- Each split is read positionally as {threshold, TN, FN, TP, FP}. Any rate
-- whose denominator is 0 is reported as 0.
-- @param splits sequence of per-threshold count records
-- @return tp_rate, fp_rate, precision_vect, recall_vect (parallel sequences)
local function ratesFromSplits(splits)
  local tp_rate, fp_rate, precision_vect, recall_vect = {}, {}, {}, {}
  for i = 1, #splits do
    local split = splits[i]
    -- Locals on purpose: these counts must not leak into the global table.
    local tn, fn, tp, fp = split[2], split[3], split[4], split[5]
    tp_rate[i] = 0
    if (tp + fn) ~= 0 then tp_rate[i] = tp / (tp + fn) end
    fp_rate[i] = 0
    if (fp + tn) ~= 0 then fp_rate[i] = fp / (fp + tn) end
    precision_vect[i] = 0
    if (tp + fp) ~= 0 then precision_vect[i] = tp / (tp + fp) end
    recall_vect[i] = tp_rate[i] -- recall = TP / (TP + FN), i.e. the TP rate
  end
  return tp_rate, fp_rate, precision_vect, recall_vect
end

--- Compute the ROC AUC and the (beta) precision-recall AUC of a prediction set.
-- Special case: when checkAllZeros(truthVector) is true, no curve is built.
-- A warning is printed and a single number is returned: the percentage of
-- predictions below 0.5, rounded to 3 decimals.
-- Otherwise the second value from roc_thresholds() is appended to the global
-- globalMinFPplusFN_vector. Both areas are printed, the elapsed time is
-- reported via printTime(), and the table {area_roc, area_precision_recall}
-- is returned (percentages rounded to 2 decimals).
-- If either area falls outside [0, 100], an error line is written to stderr
-- and nil is returned.
-- @param completePredValueVector sequence of predicted scores
-- @param truthVector sequence of ground-truth labels, presumably 0/1 and the
--   same length as the predictions -- TODO confirm with callers
function metrics_ROC_AUC_computer(completePredValueVector, truthVector)
  if checkAllZeros(truthVector) == true then
    print("ATTENTION: all the ground-truth values area 0.0\t The metrics ROC area will be the success rate");
    return successRateAllNegative(completePredValueVector)
  end

  local timeNewAreaStart0 = os.time()

  -- Loading this file presumably defines the global roc_thresholds() used
  -- below. Its return value is kept in the global ROC for compatibility.
  ROC = require './metrics/roc_thresholds.lua';
  -- roc_thresholds() receives labels mapped from 0/1 to -1/+1.
  local newVect = fromZeroOneToMinusOnePlusOne(truthVector)
  local roc_thresholds_output = roc_thresholds(torch.DoubleTensor(completePredValueVector), torch.IntTensor(newVect))
  local splits = roc_thresholds_output[1]
  local thisThreshold = roc_thresholds_output[2]
  globalMinFPplusFN_vector[#globalMinFPplusFN_vector + 1] = thisThreshold

  local tp_rate, fp_rate, precision_vect, recall_vect = ratesFromSplits(splits)

  local area_roc = round(areaNew(tp_rate, fp_rate) * 100, 2)
  print("metrics area_roc = "..area_roc.."%");
  -- The '\n' keeps each error on its own line of stderr.
  if area_roc < 0 then io.stderr:write('ERROR: AUC < 0%, problem ongoing', '\n'); return; end
  if area_roc > 100 then io.stderr:write('ERROR: AUC > 100%, problem ongoing', '\n'); return; end

  require './sort_two_arrays_from_first.lua';
  local sortedPrecisionVett, sortedRecallVett = sort_two_arrays_from_first(precision_vect, recall_vect, #precision_vect)

  -- NOTE(review): the "- 1" offset has no documented justification. It is
  -- kept so reported values stay comparable. TODO: investigate areaNew's output.
  local area_precision_recall = round((areaNew(sortedPrecisionVett, sortedRecallVett) - 1) * 100, 2)
  print("(beta) metrics area_precision_recall = "..area_precision_recall.."%");
  if area_precision_recall < 0 then io.stderr:write('ERROR: PrecisionRecallArea < 0%, problem ongoing', '\n'); return; end
  if area_precision_recall > 100 then io.stderr:write('ERROR: PrecisionRecallArea > 100%, problem ongoing;', '\n'); return; end

  printTime(timeNewAreaStart0, " the new area_roc metrics ROC_AUC_computer function");
  return {area_roc, area_precision_recall};
end
--- Map a 0/1 label vector to the -1/+1 convention.
-- Every element equal to 0 becomes -1. Every other element is copied unchanged,
-- so 1 stays 1. The input table is not modified.
-- @param vector sequence of labels (indices 1..#vector are read)
-- @treturn table a new sequence of the same length
function fromZeroOneToMinusOnePlusOne(vector)
  -- Must be local: a global here would be overwritten on every call and
  -- would alias the table handed back to the previous caller.
  local newVector = {}
  for i = 1, #vector do
    local value = vector[i]
    if value == 0 then
      value = -1
    end
    newVector[i] = value
  end
  return newVector
end