-- bin/metrics/test/test.lua
--
-- Smoke test for metrics.roc: builds ROC curves for two datasets, asserts
-- that the area under each curve lies in an expected range, prints the
-- results and plots each curve with gfx.js.
--
-- NOTE(review): gfx.chart presumably needs a running gfx.js server to show
-- anything -- confirm before running this in a headless environment.

require 'torch'
local metrics = require 'metrics'
local gfx = require 'gfx.js'

-- Directory of this script (with trailing separator) so the .dat fixtures
-- are found no matter which directory the test is launched from. Falls back
-- to '' (the current working directory) when the chunk was not loaded from
-- a file or its path has no directory component.
local script_dir = (debug.getinfo(1, 'S').source:match('^@(.*[/\\])')) or ''

--- Compute the ROC curve of (resp, labels), check its area, report and plot.
-- Raises an error (via assert) if the area is outside [lo, hi].
-- @param name     short label for the dataset, used in the failure message
-- @param resp     classifier responses, passed to metrics.roc.points
-- @param labels   ground-truth labels aligned with resp
-- @param lo       inclusive lower bound on the expected area
-- @param hi       inclusive upper bound on the expected area
-- @param verbose  when true, also print the ROC points and thresholds
-- @return roc_points, thresholds, area
local function check_roc(name, resp, labels, lo, hi, verbose)
   local roc_points, thresholds = metrics.roc.points(resp, labels)
   local area = metrics.roc.area(roc_points)

   -- Report the measured value and dataset so a failure is diagnosable
   -- without re-running the script.
   assert(area >= lo and area <= hi,
      string.format("unexpected area under ROC for %s: %s (expected [%s, %s])",
         name, tostring(area), tostring(lo), tostring(hi)))

   if verbose then
      print(roc_points)
      print(thresholds)
   end
   print(area)

   gfx.chart(roc_points)
   return roc_points, thresholds, area
end

-- Hand-made example: 5 positives and 6 negatives, with tied responses at
-- -0.8 and 0.2 carrying opposite labels. Counting positive/negative pairs
-- (ties as 1/2) gives an AUC of 22/30 ~= 0.733, inside the asserted range.
local toy_resp   = torch.DoubleTensor { -0.9, -0.8, -0.8, -0.5, -0.1, 0.0, 0.2, 0.2, 0.51, 0.74, 0.89}
local toy_labels = torch.IntTensor    {   -1,   -1,    1,   -1,   -1,   1,   1,  -1,   -1,    1,    1}

check_roc('hand-made example', toy_resp, toy_labels, 0.7, 0.75, true)

-- Fixture loaded from disk. The narrow range around 0.5 suggests responses
-- that are uninformative about the labels.
-- NOTE(review): confirm how resp_1.dat / labels.dat were generated.
-- Points and thresholds are not printed here (verbose = false); only the
-- area is reported.
local file_resp   = torch.load(script_dir .. 'resp_1.dat')
local file_labels = torch.load(script_dir .. 'labels.dat')

check_roc('resp_1.dat', file_resp, file_labels, 0.49, 0.51, false)