-- bin/metrics/roc.lua
--- ROC (receiver operating characteristic) curve utilities for binary classifiers.
-- The functions in this module take torch tensors of classifier responses and labels.
local roc = {}
--- Enumerate the candidate decision thresholds of a binary classifier.
--
-- A threshold t divides the data as follows:
--   response[i] <= t : sample i is classified as the negative class
--   response[i] >  t : sample i is classified as the positive class
-- Samples with exactly the same response cannot be separated by any threshold,
-- so candidate thresholds are the distinct response values, plus one base
-- threshold slightly below the minimum response (every sample classified positive).
-- A threshold lying strictly inside a run of purely-negative response values is
-- skipped: its ROC point sits on the same horizontal segment as its neighbours,
-- so dropping it changes neither the curve's shape nor its area.
--
-- @param responses 1D tensor of classifier responses
-- @param labels 1D tensor of labels, same length as responses
-- @param neglabel label of the negative class (default -1)
-- @param poslabel label of the positive class (default 1); any label other than
--   poslabel is counted as negative here
-- @return array of {threshold=, true_negatives=, false_negatives=} tables in
--   ascending threshold order. The first entry has 0 TN and 0 FN; the last one
--   counts every sample as a negative prediction.
function roc.splits(responses, labels, neglabel, poslabel)
    -- same defaults as roc.points
    neglabel = neglabel or -1
    poslabel = poslabel or 1

    local nsamples = responses:size()[1]
    local epsilon = 0.01

    -- sort by response value so that identical responses become adjacent
    local responses_sorted, indexes_sorted = torch.sort(responses)
    local labels_sorted = labels:index(1, indexes_sorted)

    -- Pass 1: collapse runs of identical responses into groups, counting the
    -- negatives and positives that share each response value. Grouping by value
    -- (rather than walking sample by sample) makes the result independent of how
    -- the sort happens to order tied elements.
    local groups = {}
    for i = 1, nsamples do
        local value = responses_sorted[i]
        local group = groups[#groups]
        if group == nil or group.value ~= value then
            group = {value = value, negatives = 0, positives = 0}
            groups[#groups + 1] = group
        end
        if labels_sorted[i] == poslabel then
            group.positives = group.positives + 1
        else
            group.negatives = group.negatives + 1
        end
    end

    -- Base case: a threshold a bit lower than the minimum response (the first
    -- sorted value, not the first unsorted one). Every sample is classified as
    -- positive, so there are zero true negatives and zero false negatives.
    local splits = {
        {threshold = responses_sorted[1] - epsilon, true_negatives = 0, false_negatives = 0},
    }

    -- Pass 2: raise the threshold one response value at a time, accumulating
    -- the samples that switch to being classified as negative.
    local true_negatives, false_negatives = 0, 0
    for g, group in ipairs(groups) do
        true_negatives = true_negatives + group.negatives
        false_negatives = false_negatives + group.positives

        -- If both this group and the next one are all-negative, moving on only
        -- adds true negatives without any extra false negative, so this
        -- intermediate threshold is redundant.
        local next_group = groups[g + 1]
        local inside_negative_run = group.positives == 0
            and next_group ~= nil and next_group.positives == 0

        if not inside_negative_run then
            splits[#splits + 1] = {
                threshold = group.value,
                true_negatives = true_negatives,
                false_negatives = false_negatives,
            }
        end
    end

    -- {{threshold_1, TN_1, FN_1}, ..., {threshold_k, TN_k, FN_k}}
    return splits
end
--- Compute the ROC curve of a binary classifier.
--
-- @param responses 1D tensor of classifier responses (larger means "more positive")
-- @param labels 1D tensor with the same length as responses, containing only
--   neglabel or poslabel values
-- @param neglabel label of the negative class (default -1)
-- @param poslabel label of the positive class (default 1)
-- @return roc_points: (k x 2) tensor of {false positive rate, true positive rate}
--   rows, ordered from the highest threshold (FPR 0) to the lowest (FPR 1)
-- @return thresholds: (k x 1) tensor with the threshold of each roc_points row
function roc.points(responses, labels, neglabel, poslabel)
    -- default values for arguments
    poslabel = poslabel or 1
    neglabel = neglabel or -1

    -- assertions about the data format expected
    assert(responses:size():size() == 1, "responses should be a 1D vector")
    assert(labels:size():size() == 1, "labels should be a 1D vector")
    assert(labels:size()[1] == responses:size()[1],
        "responses and labels should have the same length")

    -- avoid degenerate class definitions
    assert(poslabel ~= neglabel, "positive and negative class can't have the same label")

    -- assuming labels { neglabel, poslabel }
    local npositives = torch.sum(torch.eq(labels, poslabel))
    local nnegatives = torch.sum(torch.eq(labels, neglabel))
    local nsamples = npositives + nnegatives

    assert(nsamples == responses:size()[1],
        "labels should contain only " .. neglabel .. " or " .. poslabel .. " values")
    -- the rates below divide by the class sizes; a missing class would give 0/0 (NaN)
    assert(npositives > 0 and nnegatives > 0,
        "labels should contain at least one " .. neglabel .. " and one " .. poslabel .. " value")

    local splits = roc.splits(responses, labels, neglabel, poslabel)
    local nsplits = #splits

    local roc_points = torch.Tensor(nsplits, 2)
    local thresholds = torch.Tensor(nsplits, 1)

    -- splits come in ascending threshold order (FPR decreasing); store them in
    -- reverse so that FPR is non-decreasing along the rows, as roc.area expects
    for i = 1, nsplits do
        local row = nsplits - i + 1
        local false_positives = nnegatives - splits[i].true_negatives
        local true_positives = npositives - splits[i].false_negatives
        roc_points[row][1] = false_positives / nnegatives
        roc_points[row][2] = true_positives / npositives
        thresholds[row][1] = splits[i].threshold
    end

    return roc_points, thresholds
end
--- Area under a piecewise-linear curve, computed with the trapezoidal rule.
--
-- @param roc_points (npoints x 2) tensor whose rows are {x, y} vertices, e.g. the
--   {false positive rate, true positive rate} rows returned by roc.points
-- @return sum over consecutive rows of (x[k] - x[k-1]) * (y[k-1] + y[k]) / 2.
--   Rows are used in the given order, so this is the usual AUC only when x is
--   non-decreasing; a single-row input yields 0.0.
function roc.area(roc_points)
    local npoints = roc_points:size()[1]
    local total = 0.0
    for k = 2, npoints do
        local left, right = roc_points[k - 1], roc_points[k]
        local width = right[1] - left[1]
        total = total + width * ((left[2] + right[2]) / 2.0)
    end
    return total
end
-- export the module table (roc.splits, roc.points, roc.area)
return roc