[735bb5]: / test / features / test_bag_of_entities_feature.py

Download this file

43 lines (30 with data), 1.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Base Dependencies
# ----------------
import pytest
# Local Dependencies
# ------------------
from features import BagOfEntitiesFeature
# Constants
# ---------
from constants import DDI_ENTITY_TYPES, N2C2_ENTITY_TYPES
# Tests
# ------
def test_bag_of_entities_init():
    """Check that construction for "n2c2" stores the dataset name and entity types."""
    extractor = BagOfEntitiesFeature("n2c2")

    # The dataset identifier passed to the constructor is exposed unchanged.
    assert extractor.dataset == "n2c2"
    # At least one entity type must be available after construction.
    assert len(extractor.entity_types) > 0
def test_bag_of_entities_init_unknown_dataset_raises():
    """Check that constructing the feature with the name "i2b2" raises ValueError."""
    with pytest.raises(ValueError):
        # Only the constructor call matters here; its result is never used.
        BagOfEntitiesFeature("i2b2")
def test_bag_of_entities_get_feature_names():
    """Check that each dataset yields one feature name per entity type constant."""
    # Same order as before: "ddi" first, then "n2c2".
    expectations = (
        ("ddi", DDI_ENTITY_TYPES),
        ("n2c2", N2C2_ENTITY_TYPES),
    )
    for dataset_name, entity_types in expectations:
        feature_names = BagOfEntitiesFeature(dataset_name).get_feature_names()
        assert len(feature_names) == len(entity_types)
def test_bag_of_entities_fit_transform(n2c2_small_collection):
    """Check the shape and per-row totals of the bag-of-entities feature matrix.

    The matrix returned by ``fit_transform`` must have one row per item of
    the collection and one column per entity type of the feature. Each row
    must also sum to the number of middle entities of the relation at the
    same index.

    Args:
        n2c2_small_collection: fixture supplying the collection under test;
            it must support ``len()`` and expose an indexable ``relations``
            attribute whose items have ``middle_entities``.
    """
    boe = BagOfEntitiesFeature("n2c2")
    feature = boe.fit_transform(n2c2_small_collection)

    assert feature.shape[0] == len(n2c2_small_collection)
    assert feature.shape[1] == len(boe.entity_types)

    for i in range(feature.shape[0]):
        expected_total = len(n2c2_small_collection.relations[i].middle_entities)
        # This comparison used to be a bare expression whose result was
        # discarded, so the per-row check could never fail. It is now asserted.
        assert feature[i, :].sum() == expected_total