Diff of /tests/test_metrics.py [000000] .. [3b722e]

import numpy as np
from numpy.testing import assert_almost_equal

from oddt.metrics import (roc_auc, roc_log_auc, random_roc_log_auc,
                          enrichment_factor, rie, bedroc,
                          rmse, standard_deviation_error)


np.random.seed(42)

# Generate test data for classification
classes = np.array([0] * 90000 + [1] * 10000)
# poorly separated
poor_classes = np.random.rand(100000) * 100

# well separated
good_classes = np.concatenate([np.random.rand(90000) * 10 + 100,
                               np.random.rand(10000) * 10 + 1000])

# Generate test data for regression
values = np.arange(100000)
poor_values = np.random.rand(100000) * 100    # poorly predicted
good_values = np.arange(100000) + np.random.rand(100000)  # correctly predicted
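# Actives make up 10% of the classification set (10000 of 100000 labels).
# In the "poor" data both classes share the same uniform score distribution,
# while in the "good" data every active (scores ~1000-1010) outranks every
# inactive (scores ~100-110), so ranking metrics should reach their extremes.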


def test_roc_auc():
    score = roc_auc(classes, poor_classes)
    assert score <= 0.55
    assert score >= 0.45
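
    # With fully separated scores the ranking is perfect; ascending_score=True
    # treats lower scores as better, which turns that perfect ranking into the
    # worst possible one.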
    assert roc_auc(classes, good_classes, ascending_score=True) == 0.0
    assert roc_auc(classes, good_classes, ascending_score=False) == 1.0


def test_roc_log_auc():
    random_score = random_roc_log_auc()
    score = roc_log_auc(classes, poor_classes)
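    # Uninformative scores should land close to the analytical expectation
    # for a random classifier returned by random_roc_log_auc().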
    assert np.abs(score - random_score) < 0.01

    assert roc_log_auc(classes, good_classes, ascending_score=True) == 0
    assert roc_log_auc(classes, good_classes, ascending_score=False) == 1


def test_enrichment():
    order = sorted(range(len(poor_classes)), key=lambda k: poor_classes[k],
                   reverse=True)
    ef = enrichment_factor(classes[order], poor_classes[order])
    assert ef <= 1.5

    order = sorted(range(len(good_classes)), key=lambda k: good_classes[k],
                   reverse=True)
    ef = enrichment_factor(classes[order], good_classes[order])
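    # With every active ranked at the top, the hit rate in the selected top
    # fraction is 1.0 as long as that fraction stays within the 10% actives
    # ratio, so the fold enrichment is 1.0 / 0.1 = 10 and the 'percentage'
    # kind (the hit rate itself) is 1.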
    assert ef == 10

    ef = enrichment_factor(classes[order], good_classes[order],
                           kind='percentage')
    assert ef == 1


def test_rmse():
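    # Uncorrelated predictions leave errors on the order of the spread of
    # `values` itself, while the "good" predictions are off by at most 1
    # (uniform noise in [0, 1), RMSE ~ 1/sqrt(3)).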
    assert rmse(values, poor_values) >= 30
    assert rmse(values, good_values) <= 1


def test_standard_deviation_error():
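    # The noise added to good_values has a standard deviation of about 0.29,
    # whereas predictions uncorrelated with the target leave a residual
    # spread comparable to std(values) ~ 2.9e4.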
    assert standard_deviation_error(values, good_values) < 1.1
    assert standard_deviation_error(values, poor_values) > 2e4


def test_rie():
    order = sorted(range(len(poor_classes)), key=lambda k: poor_classes[k],
                   reverse=True)
    rie_score = rie(classes[order], poor_classes[order])
    assert rie_score <= 1.1

    order = sorted(range(len(good_classes)), key=lambda k: good_classes[k],
                   reverse=True)
    rie_score = rie(classes[order], good_classes[order])
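    # The expected value is the theoretical RIE maximum for an actives ratio
    # Ra = 0.1 and (presumably the default) alpha = 20:
    # (1 - exp(-alpha * Ra)) / (Ra * (1 - exp(-alpha))) ~= 8.646647185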
    assert_almost_equal(rie_score, 8.646647185)


def test_bedroc():
    order = sorted(range(len(poor_classes)), key=lambda k: poor_classes[k],
                   reverse=True)
    bedroc_score = bedroc(classes[order], poor_classes[order])
    assert bedroc_score < 0.2

    order = sorted(range(len(good_classes)), key=lambda k: good_classes[k],
                   reverse=True)
    bedroc_score = bedroc(classes[order], good_classes[order])
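    # BEDROC rescales RIE to the [0, 1] interval (Truchon & Bayly, 2007), so
    # a perfect early ranking of the actives should score 1.0.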
    assert_almost_equal(bedroc_score, 1.0)