|
a |
|
b/tests/test_models.py |
|
|
1 |
import pickle |
|
|
2 |
import numpy as np |
|
|
3 |
|
|
|
4 |
from numpy.testing import assert_array_almost_equal, assert_array_equal |
|
|
5 |
import pytest |
|
|
6 |
|
|
|
7 |
from oddt.scoring.models import classifiers, regressors |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
@pytest.mark.filterwarnings('ignore:Stochastic Optimizer') |
|
|
11 |
@pytest.mark.parametrize('cls', |
|
|
12 |
[classifiers.svm(probability=True), |
|
|
13 |
classifiers.neuralnetwork(random_state=42)]) |
|
|
14 |
def test_classifiers(cls): |
|
|
15 |
# toy data |
|
|
16 |
X = np.concatenate((np.zeros((5, 2)), np.ones((5, 2)))) |
|
|
17 |
Y = np.concatenate((np.ones(5), np.zeros(5))) |
|
|
18 |
|
|
|
19 |
np.random.seed(42) |
|
|
20 |
|
|
|
21 |
cls.fit(X, Y) |
|
|
22 |
|
|
|
23 |
assert_array_equal(cls.predict(X), Y) |
|
|
24 |
assert cls.score(X, Y) == 1.0 |
|
|
25 |
|
|
|
26 |
prob = cls.predict_proba(X) |
|
|
27 |
assert_array_almost_equal(prob, [[0, 1]] * 5 + [[1, 0]] * 5, decimal=1) |
|
|
28 |
log_prob = cls.predict_log_proba(X) |
|
|
29 |
assert_array_almost_equal(np.log(prob), log_prob) |
|
|
30 |
|
|
|
31 |
pickled = pickle.dumps(cls) |
|
|
32 |
reloaded = pickle.loads(pickled) |
|
|
33 |
prob_reloaded = reloaded.predict_proba(X) |
|
|
34 |
assert_array_almost_equal(prob, prob_reloaded) |
|
|
35 |
|
|
|
36 |
|
|
|
37 |
@pytest.mark.parametrize('reg', |
|
|
38 |
[regressors.svm(C=10), |
|
|
39 |
regressors.randomforest(random_state=42), |
|
|
40 |
regressors.neuralnetwork(solver='lbfgs', |
|
|
41 |
random_state=42, |
|
|
42 |
hidden_layer_sizes=(20, 20)), |
|
|
43 |
regressors.mlr()]) |
|
|
44 |
def test_regressors(reg): |
|
|
45 |
X = np.vstack((np.arange(30, 10, -2, dtype='float64'), |
|
|
46 |
np.arange(100, 90, -1, dtype='float64'))).T |
|
|
47 |
|
|
|
48 |
Y = np.arange(10, dtype='float64') |
|
|
49 |
|
|
|
50 |
np.random.seed(42) |
|
|
51 |
|
|
|
52 |
reg.fit(X, Y) |
|
|
53 |
|
|
|
54 |
pred = reg.predict(X) |
|
|
55 |
assert (np.abs(pred.flatten() - Y) < 1).all() |
|
|
56 |
assert reg.score(X, Y) > 0.9 |
|
|
57 |
|
|
|
58 |
pickled = pickle.dumps(reg) |
|
|
59 |
reloaded = pickle.loads(pickled) |
|
|
60 |
pred_reloaded = reloaded.predict(X) |
|
|
61 |
assert_array_almost_equal(pred, pred_reloaded) |