|
a |
|
b/test/features/test_position_feature.py |
|
|
1 |
# Base Dependencies |
|
|
2 |
# ---------------- |
|
|
3 |
import pytest |
|
|
4 |
import numpy as np |
|
|
5 |
from random import randrange |
|
|
6 |
|
|
|
7 |
# Local Dependencies |
|
|
8 |
# ------------------ |
|
|
9 |
from models import RelationCollection |
|
|
10 |
from features.position_feature import PositionFeature |
|
|
11 |
|
|
|
12 |
# Constants |
|
|
13 |
# --------- |
|
|
14 |
from constants import ( |
|
|
15 |
N2C2_ATTR_ENTITY_CANDIDATES, |
|
|
16 |
DDI_ATTR_ENTITY_CANDIDATES, |
|
|
17 |
) |
|
|
18 |
|
|
|
19 |
|
|
|
20 |
# Tests |
|
|
21 |
# ------ |
|
|
22 |
def test_position_init(): |
|
|
23 |
position = PositionFeature("n2c2") |
|
|
24 |
assert isinstance(position, PositionFeature) |
|
|
25 |
assert position.dataset == "n2c2" |
|
|
26 |
assert position.attr_entity_candidates == N2C2_ATTR_ENTITY_CANDIDATES |
|
|
27 |
|
|
|
28 |
position = PositionFeature("ddi") |
|
|
29 |
assert isinstance(position, PositionFeature) |
|
|
30 |
assert position.dataset == "ddi" |
|
|
31 |
assert position.attr_entity_candidates == DDI_ATTR_ENTITY_CANDIDATES |
|
|
32 |
|
|
|
33 |
|
|
|
34 |
def test_position_init_unsupported_dataset_raises(): |
|
|
35 |
with pytest.raises(ValueError): |
|
|
36 |
position = PositionFeature("i2b2") |
|
|
37 |
|
|
|
38 |
|
|
|
39 |
def test_position_get_feature_names(): |
|
|
40 |
position = PositionFeature("n2c2") |
|
|
41 |
assert len(position.get_feature_names()) == 2 |
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def test_position_fit(n2c2_small_collection): |
|
|
45 |
position = PositionFeature("n2c2") |
|
|
46 |
position = position.fit(n2c2_small_collection) |
|
|
47 |
assert isinstance(position, PositionFeature) |
|
|
48 |
|
|
|
49 |
|
|
|
50 |
def test_position_create_position_feature_(n2c2_small_collection, relation): |
|
|
51 |
|
|
|
52 |
position = PositionFeature("n2c2") |
|
|
53 |
position_feature = position.create_position_feature(n2c2_small_collection) |
|
|
54 |
assert position_feature.shape == (len(n2c2_small_collection), 2) |
|
|
55 |
|
|
|
56 |
for i in range(len(n2c2_small_collection)): |
|
|
57 |
assert (position_feature[i][0] == 0 and position_feature[i][1] >= 0) or ( |
|
|
58 |
position_feature[i][0] <= 0 and position_feature[i][1] == 0 |
|
|
59 |
) |
|
|
60 |
|
|
|
61 |
one_collection = RelationCollection(relation) |
|
|
62 |
|
|
|
63 |
position_feature = position.create_position_feature(one_collection) |
|
|
64 |
|
|
|
65 |
assert (position_feature == np.array([[-1, 0]])).all() |
|
|
66 |
|
|
|
67 |
|
|
|
68 |
|