a b/tests/methods/test_crf_labeler.py
1
from deidentify.methods.crf.crf_labeler import (SentenceFilterCRF, Token,
2
                                                collapse_word_shape,
3
                                                has_unmatched_bracket,
4
                                                list_window,
5
                                                liu_feature_extractor, ngrams,
6
                                                word_shape)
7
8
9
def test_list_window():
    """list_window slices around ``center`` and pads with None past either sentence edge."""
    tokens = ['a', 'b', 'w', 'c', 'd']

    cases = [
        # (center, window, expected)
        (2, (0, 0), ['w']),
        (2, (1, 1), ['b', 'w', 'c']),
        (2, (2, 2), ['a', 'b', 'w', 'c', 'd']),
        (2, (3, 3), [None, 'a', 'b', 'w', 'c', 'd', None]),
        (0, (3, 3), [None, None, None, 'a', 'b', 'w', 'c']),
        (0, (3, 0), [None, None, None, 'a']),
    ]
    for center, window, expected in cases:
        assert list_window(tokens, center=center, window=window) == expected
18
19
20
def test_ngrams():
    """ngrams yields every contiguous N-token window as a tuple, in sentence order."""
    tokens = ['a', 'b', 'w', 'c', 'd']

    expected_by_n = {
        1: [('a',), ('b',), ('w',), ('c',), ('d',)],
        2: [('a', 'b'), ('b', 'w'), ('w', 'c'), ('c', 'd')],
        3: [('a', 'b', 'w'), ('b', 'w', 'c'), ('w', 'c', 'd')],
    }
    for n, expected in expected_by_n.items():
        assert ngrams(tokens, N=n) == expected
25
26
27
def test_unmatched_bracket():
    """A sentence with an opening '(' is flagged until the matching ')' is appended."""
    words_and_tags = [
        ('De', 'DET'),
        ('patient', 'NOUN'),
        ('Ingmar', 'NOUN'),
        ('Koopal', 'PROPN'),
        ('(', 'PUNCT'),
    ]
    sentence = [
        Token(text=word, pos_tag=tag, label='O', ner_tag=None)
        for word, tag in words_and_tags
    ]

    assert has_unmatched_bracket(sentence)

    closing = Token(text=')', pos_tag='PUNCT', label='O', ner_tag=None)
    sentence.append(closing)
    assert not has_unmatched_bracket(sentence)
39
40
41
def test_word_shape():
    """Uppercase maps to 'A', lowercase to 'a', digits to '#'; '-' is kept as-is."""
    expectations = {
        'IngmAr-12a': 'AaaaAa-##a',
        '1234': '####',
        'ömar': 'aaaa',  # a non-ASCII lowercase letter also maps to 'a'
    }
    for word, shape in expectations.items():
        assert word_shape(word) == shape
45
46
47
def test_collapse_word_shape():
    """Consecutive repeats of a shape character collapse to a single character."""
    for long_shape, short_shape in [('AaaaAa-##a', 'AaAa-#a'), ('####', '#')]:
        assert collapse_word_shape(long_shape) == short_shape
50
51
52
def test_liu_feature_extractor():
    """Check the complete feature dict built for token index 2 ('Ingmar')."""
    rows = [
        # (text, pos_tag, ner_tag); every token carries label 'O'
        ('De', 'DET', None),
        ('patient', 'NOUN', None),
        ('Ingmar', 'NOUN', 'PER'),
        ('Koopal', 'PROPN', 'PER'),
        ('(', 'PUNCT', None),
        ('t', 'NOUN', None),
        (':', 'PUNCT', None),
        ('06', 'NUM', None),
        ('-', 'PUNCT', None),
        ('16769063', 'NUM', None),
        (')', 'PUNCT', None),
    ]
    sentence = [
        Token(text=text, pos_tag=pos, label='O', ner_tag=ner)
        for text, pos, ner in rows
    ]

    # Lower-cased word uni/bi/tri-grams over the window [-2, 2] around the target.
    bag_of_words = {
        'bow[-2:2].uni.0': 'de',
        'bow[-2:2].uni.1': 'patient',
        'bow[-2:2].uni.2': 'ingmar',
        'bow[-2:2].uni.3': 'koopal',
        'bow[-2:2].uni.4': '(',
        'bow[-2:2].bi.0': 'de|patient',
        'bow[-2:2].bi.1': 'patient|ingmar',
        'bow[-2:2].bi.2': 'ingmar|koopal',
        'bow[-2:2].bi.3': 'koopal|(',
        'bow[-2:2].tri.0': 'de|patient|ingmar',
        'bow[-2:2].tri.1': 'patient|ingmar|koopal',
        'bow[-2:2].tri.2': 'ingmar|koopal|(',
    }
    # POS-tag uni/bi/tri-grams over the same window.
    pos_ngrams = {
        'pos[-2:2].uni.0': 'DET',
        'pos[-2:2].uni.1': 'NOUN',
        'pos[-2:2].uni.2': 'NOUN',
        'pos[-2:2].uni.3': 'PROPN',
        'pos[-2:2].uni.4': 'PUNCT',
        'pos[-2:2].bi.0': 'DET|NOUN',
        'pos[-2:2].bi.1': 'NOUN|NOUN',
        'pos[-2:2].bi.2': 'NOUN|PROPN',
        'pos[-2:2].bi.3': 'PROPN|PUNCT',
        'pos[-2:2].tri.0': 'DET|NOUN|NOUN',
        'pos[-2:2].tri.1': 'NOUN|NOUN|PROPN',
        'pos[-2:2].tri.2': 'NOUN|PROPN|PUNCT',
    }
    # Target word (w0) joined with the POS tags of positions -1, 0 and +1.
    word_pos = {
        'bowpos.w0p-1': 'ingmar|NOUN',
        'bowpos.w0p-1p0': 'ingmar|NOUN|NOUN',
        'bowpos.w0p-1p0p1': 'ingmar|NOUN|NOUN|PROPN',
        'bowpos.w0p-1p1': 'ingmar|NOUN|PROPN',
        'bowpos.w0p0': 'ingmar|NOUN',
        'bowpos.w0p0p1': 'ingmar|NOUN|PROPN',
        'bowpos.w0p1': 'ingmar|PROPN',
    }
    # Sentence-level features (11 tokens, brackets balanced, no end mark).
    sentence_level = {
        'sent.end_mark': False,
        'sent.len(sent)': 11,
        'sent.has_unmatched_bracket': False,
    }
    # Lower-cased prefixes and suffixes of the target word, lengths 1..5.
    affixes = {
        'prefix[:1]': 'i',
        'prefix[:2]': 'in',
        'prefix[:3]': 'ing',
        'prefix[:4]': 'ingm',
        'prefix[:5]': 'ingma',
        'suffix[-1:]': 'r',
        'suffix[-2:]': 'ar',
        'suffix[-3:]': 'mar',
        'suffix[-4:]': 'gmar',
        'suffix[-5:]': 'ngmar',
    }
    # Orthographic properties and tags of the target word.
    word_level = {
        'word.contains_digit': False,
        'word.has_digit_inside': False,
        'word.has_punct_inside': False,
        'word.has_upper_inside': False,
        'word.is_ascii': True,
        'word.isdigit()': False,
        'word.istitle()': True,
        'word.isupper()': False,
        'word.ner_tag': 'PER',
        'word.pos_tag': 'NOUN',
    }
    # Long and short (collapsed) word shape of 'Ingmar'.
    shapes = {
        'shape.long': 'Aaaaaa',
        'shape.short': 'Aa',
    }

    # The groups share no keys, so merging them reproduces the full expected dict.
    expected = {
        **bag_of_words,
        **pos_ngrams,
        **word_pos,
        **sentence_level,
        **affixes,
        **word_level,
        **shapes,
    }
    assert liu_feature_extractor(sentence, 2) == expected
127
128
129
def test_crf_labeler_marginals():
    """Sentences rejected by ``ignore_sentence`` are left out of training and receive
    marginal 1 for ``ignored_label`` (0 for every other class) at prediction time."""
    # Build a distinct dict per token: ``[d] * n`` would alias a single dict n times,
    # so any in-place mutation downstream would silently leak across tokens.
    sent1_features = [{'feat1': True, 'feat2': False} for _ in range(4)]  # ignored (see below)
    sent1_labels = ['O', 'B-Name', 'O', 'O']
    sent2_features = [{'feat1': False, 'feat2': True} for _ in range(3)]
    sent2_labels = ['B-Date', 'I-Date', 'O']

    def ignore_sent(sent):
        # Filter out sentences whose first token has feat1 set (i.e. sent1).
        return bool(sent[0]['feat1'])

    X = [sent1_features, sent2_features]
    y = [sent1_labels, sent2_labels]
    crf = SentenceFilterCRF(ignored_label='O', ignore_sentence=ignore_sent)
    crf.fit(X, y)

    # 'B-Name' occurs only in the ignored sentence, so it must not become a class.
    assert set(crf.classes_) == {'O', 'B-Date', 'I-Date'}

    y_pred = crf.predict_marginals([sent1_features, sent2_features])
    # Should have two sentences
    assert len(y_pred) == 2
    assert len(y_pred[0]) == len(sent1_features), "Number of marginals should match len tokens"
    assert len(y_pred[1]) == len(sent2_features), "Number of marginals should match len tokens"

    # all tokens should have marginals for all classes
    for sent in y_pred:
        for token in sent:
            assert set(token) == set(crf.classes_)

    # First sentence (ignored) should have marginal=1 for the ignored_label.
    ignored_marginals = {'O': 1, 'B-Date': 0, 'I-Date': 0}
    assert y_pred[0] == [ignored_marginals] * len(sent1_features)

    # Second sentence is scored by the CRF itself: it must not carry the hard-coded
    # ignored marginals, and must put some probability mass on the non-ignored classes.
    assert y_pred[1] != [ignored_marginals] * len(sent2_features)
    assert any(token['B-Date'] > 0 or token['I-Date'] > 0 for token in y_pred[1])