-- Tests for metrics.roc (bin/metrics/test/test.lua).
--
-- Two cases are checked:
--   1. a small inline set of responses with labels in {-1, 1}, whose area
--      under the ROC curve must lie in [0.7, 0.75];
--   2. responses and labels loaded from 'resp_1.dat' and 'labels.dat', whose
--      area must lie in [0.49, 0.51] (i.e. about 0.5).
-- For each case the ROC points are also charted with gfx.js.

require 'torch'
local metrics = require 'metrics'
local gfx = require 'gfx.js'

-- Directory containing this script (with trailing separator), so the .dat
-- fixtures are found even when the test is launched from another working
-- directory. Falls back to '' (the current directory, as before) when the
-- chunk was not loaded from a file.
local script_dir = (debug.getinfo(1, 'S').source:match('^@(.*[/\\])')) or ''

-- Computes ROC points and the area under the curve for (resp, labels) and
-- asserts that the area lies in [lo, hi]. The failure message reports the
-- actual area so regressions are easier to diagnose.
-- Returns roc_points, thresholds, area.
local function check_roc_area(resp, labels, lo, hi)
   local roc_points, thresholds = metrics.roc.points(resp, labels)
   local area = metrics.roc.area(roc_points)
   assert(area >= lo and area <= hi,
      string.format('unexpected area under ROC: %g (expected within [%g, %g])',
                    area, lo, hi))
   return roc_points, thresholds, area
end

-- Case 1: hand-written responses with labels in {-1, 1}.
do
   local resp = torch.DoubleTensor { -0.9, -0.8, -0.8, -0.5, -0.1, 0.0, 0.2, 0.2, 0.51, 0.74, 0.89}
   local labels = torch.IntTensor { -1, -1, 1, -1, -1, 1, 1, -1, -1, 1, 1}

   local roc_points, thresholds, area = check_roc_area(resp, labels, 0.7, 0.75)
   print(roc_points)
   print(thresholds)
   print(area)
   gfx.chart(roc_points)
end

-- Case 2: serialized fixtures stored next to this script.
do
   local resp = torch.load(script_dir .. 'resp_1.dat')
   local labels = torch.load(script_dir .. 'labels.dat')

   local roc_points, _, area = check_roc_area(resp, labels, 0.49, 0.51)
   print(area)
   gfx.chart(roc_points)
end