[868c5d]: / bin / metrics / test / test.lua

Download this file

36 lines (20 with data), 773 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
require 'torch'
metrics = require 'metrics'
gfx = require 'gfx.js'
-- Synthetic sanity check on 11 hand-picked scores with +1/-1 labels.
-- There are 5 positives and 6 negatives. Counting positive>negative score
-- pairs (ties counted as half) gives 22/30 ~= 0.733, hence the [0.7, 0.75]
-- window below. NOTE(review): scores -0.8 and 0.2 occur with both labels;
-- this assumes metrics.roc treats such ties like the pairwise estimate --
-- confirm against the metrics implementation.
-- Everything is scoped in do ... end so no globals leak into later sections.
do
  local resp = torch.DoubleTensor {
    -0.9, -0.8, -0.8, -0.5, -0.1, 0.0, 0.2, 0.2, 0.51, 0.74, 0.89
  }
  local labels = torch.IntTensor { -1, -1, 1, -1, -1, 1, 1, -1, -1, 1, 1 }

  local roc_points, thresholds = metrics.roc.points(resp, labels)
  local area = metrics.roc.area(roc_points)

  -- Report the actual value so a failure is diagnosable without rerunning.
  assert(area >= 0.7 and area <= 0.75,
         string.format("unexpected area under ROC: %s (expected in [0.7, 0.75])",
                       tostring(area)))

  print(roc_points)
  print(thresholds)
  print(area)
  gfx.chart(roc_points)
end
-- Fixture check: scores and labels are deserialized with torch.load from
-- the relative paths 'resp_1.dat' and 'labels.dat', so the script must be
-- run from the directory that contains them.
-- The assertion expects a chance-level AUC in [0.49, 0.51]; presumably
-- resp_1.dat holds scores uncorrelated with the labels -- TODO confirm.
-- Scoped in do ... end so no globals leak out of this section.
do
  local resp = torch.load('resp_1.dat')
  local labels = torch.load('labels.dat')

  -- Only the ROC points are needed here; the thresholds return value is ignored.
  local roc_points = metrics.roc.points(resp, labels)
  local area = metrics.roc.area(roc_points)

  -- Report the actual value so a failure is diagnosable without rerunning.
  assert(area >= 0.49 and area <= 0.51,
         string.format("unexpected area under ROC: %s (expected in [0.49, 0.51])",
                       tostring(area)))

  --print(roc_points)
  print(area)
  gfx.chart(roc_points)
end