|
a |
|
b/tests/methods/test_crf_labeler.py |
|
|
1 |
from deidentify.methods.crf.crf_labeler import (SentenceFilterCRF, Token, |
|
|
2 |
collapse_word_shape, |
|
|
3 |
has_unmatched_bracket, |
|
|
4 |
list_window, |
|
|
5 |
liu_feature_extractor, ngrams, |
|
|
6 |
word_shape) |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
def test_list_window():
    """list_window returns the tokens around `center`, padding with None past the edges."""
    tokens = ['a', 'b', 'w', 'c', 'd']

    # (center, (left, right) window, expected slice)
    cases = [
        (2, (0, 0), ['w']),
        (2, (1, 1), ['b', 'w', 'c']),
        (2, (2, 2), ['a', 'b', 'w', 'c', 'd']),
        (2, (3, 3), [None, 'a', 'b', 'w', 'c', 'd', None]),
        (0, (3, 3), [None, None, None, 'a', 'b', 'w', 'c']),
        (0, (3, 0), [None, None, None, 'a']),
    ]
    for center, window, expected in cases:
        assert list_window(tokens, center=center, window=window) == expected
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def test_ngrams():
    """ngrams yields every contiguous N-tuple over the token sequence."""
    sequence = ['a', 'b', 'w', 'c', 'd']
    expected_by_order = {
        1: [('a',), ('b',), ('w',), ('c',), ('d',)],
        2: [('a', 'b'), ('b', 'w'), ('w', 'c'), ('c', 'd')],
        3: [('a', 'b', 'w'), ('b', 'w', 'c'), ('w', 'c', 'd')],
    }
    for order, grams in expected_by_order.items():
        assert ngrams(sequence, N=order) == grams
|
|
25 |
|
|
|
26 |
|
|
|
27 |
def test_unmatched_bracket():
    """A sentence with an opening '(' and no ')' is reported as unmatched."""
    token_specs = [
        ('De', 'DET'),
        ('patient', 'NOUN'),
        ('Ingmar', 'NOUN'),
        ('Koopal', 'PROPN'),
        ('(', 'PUNCT'),
    ]
    sentence = [
        Token(text=text, pos_tag=pos, label='O', ner_tag=None)
        for text, pos in token_specs
    ]

    assert has_unmatched_bracket(sentence)
    # Appending the closing bracket balances the sentence again.
    sentence.append(Token(text=')', pos_tag='PUNCT', label='O', ner_tag=None))
    assert not has_unmatched_bracket(sentence)
|
|
39 |
|
|
|
40 |
|
|
|
41 |
def test_word_shape():
    """word_shape maps uppercase -> 'A', lowercase (incl. non-ASCII) -> 'a', digits -> '#'."""
    cases = {
        'IngmAr-12a': 'AaaaAa-##a',
        '1234': '####',
        'ömar': 'aaaa',
    }
    for word, shape in cases.items():
        assert word_shape(word) == shape
|
|
45 |
|
|
|
46 |
|
|
|
47 |
def test_collapse_word_shape():
    """collapse_word_shape squashes runs of identical shape characters to one."""
    for raw, collapsed in [('AaaaAa-##a', 'AaAa-#a'), ('####', '#')]:
        assert collapse_word_shape(raw) == collapsed
|
|
50 |
|
|
|
51 |
|
|
|
52 |
def test_liu_feature_extractor():
    """Full feature dict for the token at index 2 ('Ingmar') of a small sentence."""
    # (text, pos_tag, ner_tag) for each token; all labels are 'O'.
    token_specs = [
        ('De', 'DET', None),
        ('patient', 'NOUN', None),
        ('Ingmar', 'NOUN', 'PER'),
        ('Koopal', 'PROPN', 'PER'),
        ('(', 'PUNCT', None),
        ('t', 'NOUN', None),
        (':', 'PUNCT', None),
        ('06', 'NUM', None),
        ('-', 'PUNCT', None),
        ('16769063', 'NUM', None),
        (')', 'PUNCT', None),
    ]
    sentence = [
        Token(text=text, pos_tag=pos, label='O', ner_tag=ner)
        for text, pos, ner in token_specs
    ]

    expected = {
        # Bag-of-words uni/bi/tri-grams in a [-2, +2] window around the token.
        'bow[-2:2].uni.0': 'de',
        'bow[-2:2].uni.1': 'patient',
        'bow[-2:2].uni.2': 'ingmar',
        'bow[-2:2].uni.3': 'koopal',
        'bow[-2:2].uni.4': '(',
        'bow[-2:2].bi.0': 'de|patient',
        'bow[-2:2].bi.1': 'patient|ingmar',
        'bow[-2:2].bi.2': 'ingmar|koopal',
        'bow[-2:2].bi.3': 'koopal|(',
        'bow[-2:2].tri.0': 'de|patient|ingmar',
        'bow[-2:2].tri.1': 'patient|ingmar|koopal',
        'bow[-2:2].tri.2': 'ingmar|koopal|(',
        # POS uni/bi/tri-grams over the same window.
        'pos[-2:2].uni.0': 'DET',
        'pos[-2:2].uni.1': 'NOUN',
        'pos[-2:2].uni.2': 'NOUN',
        'pos[-2:2].uni.3': 'PROPN',
        'pos[-2:2].uni.4': 'PUNCT',
        'pos[-2:2].bi.0': 'DET|NOUN',
        'pos[-2:2].bi.1': 'NOUN|NOUN',
        'pos[-2:2].bi.2': 'NOUN|PROPN',
        'pos[-2:2].bi.3': 'PROPN|PUNCT',
        'pos[-2:2].tri.0': 'DET|NOUN|NOUN',
        'pos[-2:2].tri.1': 'NOUN|NOUN|PROPN',
        'pos[-2:2].tri.2': 'NOUN|PROPN|PUNCT',
        # Combined word/POS features for the focus token.
        'bowpos.w0p-1': 'ingmar|NOUN',
        'bowpos.w0p-1p0': 'ingmar|NOUN|NOUN',
        'bowpos.w0p-1p0p1': 'ingmar|NOUN|NOUN|PROPN',
        'bowpos.w0p-1p1': 'ingmar|NOUN|PROPN',
        'bowpos.w0p0': 'ingmar|NOUN',
        'bowpos.w0p0p1': 'ingmar|NOUN|PROPN',
        'bowpos.w0p1': 'ingmar|PROPN',
        # Sentence-level features.
        'sent.end_mark': False,
        'sent.len(sent)': 11,
        'sent.has_unmatched_bracket': False,
        # Prefixes/suffixes of the lowercased token text.
        'prefix[:1]': 'i',
        'prefix[:2]': 'in',
        'prefix[:3]': 'ing',
        'prefix[:4]': 'ingm',
        'prefix[:5]': 'ingma',
        'suffix[-1:]': 'r',
        'suffix[-2:]': 'ar',
        'suffix[-3:]': 'mar',
        'suffix[-4:]': 'gmar',
        'suffix[-5:]': 'ngmar',
        # Orthographic/word-level features.
        'word.contains_digit': False,
        'word.has_digit_inside': False,
        'word.has_punct_inside': False,
        'word.has_upper_inside': False,
        'word.is_ascii': True,
        'word.isdigit()': False,
        'word.istitle()': True,
        'word.isupper()': False,
        'word.ner_tag': 'PER',
        'word.pos_tag': 'NOUN',
        # Word shapes: full and collapsed.
        'shape.long': 'Aaaaaa',
        'shape.short': 'Aa',
    }
    assert liu_feature_extractor(sentence, 2) == expected
|
|
127 |
|
|
|
128 |
|
|
|
129 |
def test_crf_labeler_marginals():
    """SentenceFilterCRF: filtered sentences get probability 1 for `ignored_label`.

    Sentence 1 matches the ignore predicate below, so it is excluded from
    training and receives degenerate marginals at prediction time; sentence 2
    is trained on normally.
    """
    sent1_features = [{'feat1': True, 'feat2': False}] * 4  # sentence will be ignored (see below)
    sent1_labels = ['O', 'B-Name', 'O', 'O']
    sent2_features = [{'feat1': False, 'feat2': True}] * 3
    sent2_labels = ['B-Date', 'I-Date', 'O']

    def ignore_sent(sent):
        # Ignore sentences whose first token has feat1 set.
        # (Return the boolean directly instead of comparing `== True`, PEP 8 E712.)
        return sent[0]['feat1']

    X = [sent1_features, sent2_features]
    y = [sent1_labels, sent2_labels]
    crf = SentenceFilterCRF(ignored_label='O', ignore_sentence=ignore_sent)
    crf.fit(X, y)

    # Labels seen only in the ignored sentence (B-Name) must not be learned.
    assert set(crf.classes_) == {'O', 'B-Date', 'I-Date'}

    y_pred = crf.predict_marginals([sent1_features, sent2_features])
    # Should have two sentences
    assert len(y_pred) == 2
    assert len(y_pred[0]) == len(sent1_features), "Number of marginals should match len tokens"
    assert len(y_pred[1]) == len(sent2_features), "Number of marginals should match len tokens"

    # all tokens should have marginals for all classes
    for sent in y_pred:
        for token in sent:
            assert set(token.keys()) == set(crf.classes_)

    # First sentence (ignored) should have marginal=1 for the ignored_label.
    ignored_marginals = {'O': 1, 'B-Date': 0, 'I-Date': 0}
    assert y_pred[0] == [ignored_marginals] * 4
    # Second sentence should have non-zero marginals for the other classes
    assert y_pred[1] != [ignored_marginals] * 3