Switch to unified view

a b/test/features/test_iob_feature.py
1
# Base Dependencies
2
# ----------------
3
import pytest
4
import numpy as np
5
from random import randrange
6
7
# Local Dependencies
8
# ------------------
9
from features.iob_feature import IOBFeature
10
from constants import DDI_IOB_TAGS, N2C2_IOB_TAGS
11
from models import RelationCollection, Entity, Relation
12
13
14
# Tests
15
# ------
16
def test_iob_init():
    """Check the attributes set by the constructor for each supported dataset."""
    expected_tags = {"n2c2": N2C2_IOB_TAGS, "ddi": DDI_IOB_TAGS}

    for dataset, tags in expected_tags.items():
        feature = IOBFeature(dataset)
        assert isinstance(feature, IOBFeature)
        assert feature.dataset == dataset
        assert feature.iob_tags == tags
        # No padding_idx is passed here, and the attribute is expected to be None.
        assert feature.padding_idx is None
28
29
30
def test_iob_init_unsupported_dataset_raises():
    """Constructing the feature with the name "i2b2" must raise ValueError."""
    with pytest.raises(ValueError):
        # The instance is never used, so it is not bound to a name.
        IOBFeature("i2b2")
33
34
35
def test_iob_get_feature_names():
    """The feature reports a single name, "IOB"."""
    feature_names = IOBFeature("n2c2").get_feature_names()
    assert feature_names == ["IOB"]
38
39
40
def test_iob_fit(n2c2_small_collection):
    """Calling fit on a collection hands back an IOBFeature instance."""
    fitted = IOBFeature("n2c2").fit(n2c2_small_collection)
    assert isinstance(fitted, IOBFeature)
44
45
46
def test_iob_index():
    """No tag may be mapped to the index reserved for padding."""
    for pad in [0, 1, 5]:
        feature = IOBFeature("n2c2", padding_idx=pad)

        # Every tag index must differ from the configured padding index.
        assert all(feature.iob_index(tag) != pad for tag in feature.iob_tags)
53
54
55
def test_iob_index_unknown_tag_raises():
    """Looking up the tag "B-UNKN" must raise.

    "B-UNKN" is presumably absent from the n2c2 tag set (not verifiable from
    this file). The concrete exception type is not visible here either, so
    the broad ``Exception`` expectation is kept.
    """
    iob = IOBFeature("n2c2", padding_idx=0)
    with pytest.raises(Exception):
        # Call the real lookup method: a misspelled attribute would raise
        # AttributeError and satisfy pytest.raises(Exception) without ever
        # exercising the unknown-tag path.
        iob.iob_index("B-UNKN")
59
60
61
def test_iob_create_iob_feature(n2c2_small_collection, relation):
    """Check the IOB sequences produced for a small collection and a single relation."""
    feature = IOBFeature("n2c2")
    sequences = feature.create_iob_feature(n2c2_small_collection)
    n_relations = len(n2c2_small_collection)
    assert len(sequences) == n_relations

    for idx in range(n_relations):
        first_entity = n2c2_small_collection.entities1_tokens[idx]
        second_entity = n2c2_small_collection.entities2_tokens[idx]
        sentence = n2c2_small_collection.tokens[idx]

        # One label per sentence token.
        assert len(sequences[idx]) == len(sentence)

        # Three distinct labels are expected as a baseline, plus one more for
        # each entity that spans more than one token (presumably its "inside"
        # tag).
        expected_unique = 3 + int(len(first_entity) > 1) + int(len(second_entity) > 1)

        assert len(np.unique(sequences[idx])) == expected_unique

    # A collection built from a single relation yields exactly one sequence.
    single = RelationCollection(relation)
    sequences = feature.create_iob_feature(single)
    assert len(sequences) == 1
    assert list(sequences[0]) == [0, 0, 0, 1, 0, 0, 13, 14, 0, 0, 0]
87