Switch to side-by-side view

--- a
+++ b/tests/methods/test_crf_labeler.py
@@ -0,0 +1,160 @@
+from deidentify.methods.crf.crf_labeler import (SentenceFilterCRF, Token,
+                                                collapse_word_shape,
+                                                has_unmatched_bracket,
+                                                list_window,
+                                                liu_feature_extractor, ngrams,
+                                                word_shape)
+
+
+def test_list_window():  # list_window returns the items around `center`, padded with None past either end.
+    sent = ['a', 'b', 'w', 'c', 'd']
+
+    assert list_window(sent, center=2, window=(0, 0)) == ['w']  # empty window: only the center item
+    assert list_window(sent, center=2, window=(1, 1)) == ['b', 'w', 'c']
+    assert list_window(sent, center=2, window=(2, 2)) == ['a', 'b', 'w', 'c', 'd']
+    assert list_window(sent, center=2, window=(3, 3)) == [None, 'a', 'b', 'w', 'c', 'd', None]
+    assert list_window(sent, center=0, window=(3, 3)) == [None, None, None, 'a', 'b', 'w', 'c']
+    assert list_window(sent, center=0, window=(3, 0)) == [None, None, None, 'a']  # 3 slots before, 0 after
+
+
+def test_ngrams():  # ngrams returns every run of N consecutive tokens as a tuple, in input order.
+    tokens = ['a', 'b', 'w', 'c', 'd']
+    assert ngrams(tokens, N=1) == [('a',), ('b',), ('w',), ('c',), ('d',)]  # unigrams are 1-tuples
+    assert ngrams(tokens, N=2) == [('a', 'b'), ('b', 'w'), ('w', 'c'), ('c', 'd')]
+    assert ngrams(tokens, N=3) == [('a', 'b', 'w'), ('b', 'w', 'c'), ('w', 'c', 'd')]
+
+
+def test_unmatched_bracket():  # A sentence with '(' but no ')' is flagged; appending ')' clears the flag.
+    sentence = [
+        Token(text='De', pos_tag='DET', label='O', ner_tag=None),
+        Token(text='patient', pos_tag='NOUN', label='O', ner_tag=None),
+        Token(text='Ingmar', pos_tag='NOUN', label='O', ner_tag=None),
+        Token(text='Koopal', pos_tag='PROPN', label='O', ner_tag=None),
+        Token(text='(', pos_tag='PUNCT', label='O', ner_tag=None),  # opening bracket with no partner yet
+    ]
+
+    assert has_unmatched_bracket(sentence)
+    sentence.append(Token(text=')', pos_tag='PUNCT', label='O', ner_tag=None))  # close the bracket
+    assert not has_unmatched_bracket(sentence)
+
+
+def test_word_shape():  # Expected shapes here: upper -> 'A', lower -> 'a', digit -> '#', '-' kept unchanged.
+    assert word_shape('IngmAr-12a') == 'AaaaAa-##a'
+    assert word_shape('1234') == '####'
+    assert word_shape('ömar') == 'aaaa'  # the non-ASCII lowercase letter is also expected to map to 'a'
+
+
+def test_collapse_word_shape():  # Consecutive repeats of one shape character are expected to collapse to one.
+    assert collapse_word_shape('AaaaAa-##a') == 'AaAa-#a'
+    assert collapse_word_shape('####') == '#'
+
+
+def test_liu_feature_extractor():  # Pins the complete feature dict returned for one token of a sample sentence.
+    sentence = [
+        Token(text='De', pos_tag='DET', label='O', ner_tag=None),
+        Token(text='patient', pos_tag='NOUN', label='O', ner_tag=None),
+        Token(text='Ingmar', pos_tag='NOUN', label='O', ner_tag='PER'),
+        Token(text='Koopal', pos_tag='PROPN', label='O', ner_tag='PER'),
+        Token(text='(', pos_tag='PUNCT', label='O', ner_tag=None),
+        Token(text='t', pos_tag='NOUN', label='O', ner_tag=None),
+        Token(text=':', pos_tag='PUNCT', label='O', ner_tag=None),
+        Token(text='06', pos_tag='NUM', label='O', ner_tag=None),
+        Token(text='-', pos_tag='PUNCT', label='O', ner_tag=None),
+        Token(text='16769063', pos_tag='NUM', label='O', ner_tag=None),
+        Token(text=')', pos_tag='PUNCT', label='O', ner_tag=None),
+    ]
+
+    assert liu_feature_extractor(sentence, 2) == {  # index 2 is the token 'Ingmar'
+        'bow[-2:2].uni.0': 'de',  # lowercased token texts of sentence[0:5], i.e. two either side of index 2
+        'bow[-2:2].uni.1': 'patient',
+        'bow[-2:2].uni.2': 'ingmar',
+        'bow[-2:2].uni.3': 'koopal',
+        'bow[-2:2].uni.4': '(',
+        'bow[-2:2].bi.0': 'de|patient',  # adjacent pairs from the same window, joined with '|'
+        'bow[-2:2].bi.1':  'patient|ingmar',
+        'bow[-2:2].bi.2':  'ingmar|koopal',
+        'bow[-2:2].bi.3':  'koopal|(',
+        'bow[-2:2].tri.0': 'de|patient|ingmar',  # adjacent triples from the same window
+        'bow[-2:2].tri.1': 'patient|ingmar|koopal',
+        'bow[-2:2].tri.2': 'ingmar|koopal|(',
+        'pos[-2:2].uni.0': 'DET',  # POS tags of the same five tokens, case kept as given
+        'pos[-2:2].uni.1': 'NOUN',
+        'pos[-2:2].uni.2': 'NOUN',
+        'pos[-2:2].uni.3': 'PROPN',
+        'pos[-2:2].uni.4': 'PUNCT',
+        'pos[-2:2].bi.0':  'DET|NOUN',
+        'pos[-2:2].bi.1':  'NOUN|NOUN',
+        'pos[-2:2].bi.2':  'NOUN|PROPN',
+        'pos[-2:2].bi.3':  'PROPN|PUNCT',
+        'pos[-2:2].tri.0': 'DET|NOUN|NOUN',
+        'pos[-2:2].tri.1': 'NOUN|NOUN|PROPN',
+        'pos[-2:2].tri.2': 'NOUN|PROPN|PUNCT',
+        'bowpos.w0p-1': 'ingmar|NOUN',  # presumably w0 = current word, pN = POS tag at offset N; values agree
+        'bowpos.w0p-1p0': 'ingmar|NOUN|NOUN',
+        'bowpos.w0p-1p0p1': 'ingmar|NOUN|NOUN|PROPN',
+        'bowpos.w0p-1p1': 'ingmar|NOUN|PROPN',
+        'bowpos.w0p0': 'ingmar|NOUN',
+        'bowpos.w0p0p1': 'ingmar|NOUN|PROPN',
+        'bowpos.w0p1': 'ingmar|PROPN',
+        'sent.end_mark': False,
+        'sent.len(sent)': 11,  # the sample sentence above has 11 tokens
+        'sent.has_unmatched_bracket': False,  # the sample contains both '(' and ')'
+        'prefix[:1]': 'i',  # prefixes are expected in lowercase
+        'prefix[:2]': 'in',
+        'prefix[:3]': 'ing',
+        'prefix[:4]': 'ingm',
+        'prefix[:5]': 'ingma',
+        'suffix[-1:]': 'r',  # suffixes are expected in lowercase
+        'suffix[-2:]': 'ar',
+        'suffix[-3:]': 'mar',
+        'suffix[-4:]': 'gmar',
+        'suffix[-5:]': 'ngmar',
+
+        'word.contains_digit': False,
+        'word.has_digit_inside': False,
+        'word.has_punct_inside': False,
+        'word.has_upper_inside': False,
+        'word.is_ascii': True,
+        'word.isdigit()': False,
+        'word.istitle()': True,  # true of the original-case text 'Ingmar', not its lowercased form
+        'word.isupper()': False,
+        'word.ner_tag': 'PER',  # matches ner_tag of the Token at index 2
+        'word.pos_tag': 'NOUN',  # matches pos_tag of the Token at index 2
+
+        'shape.long': 'Aaaaaa',  # one shape character per character of 'Ingmar'
+        'shape.short': 'Aa',  # the long shape with repeated characters collapsed
+    }
+
+
+def test_crf_labeler_marginals():  # An ignored sentence adds no classes and gets marginal 1 for ignored_label.
+    sent1_features = [{'feat1': True, 'feat2': False}] * 4  # sentence will be ignored (see below)
+    sent1_labels = ['O', 'B-Name', 'O', 'O']
+    sent2_features = [{'feat1': False, 'feat2': True}] * 3
+    sent2_labels = ['B-Date', 'I-Date', 'O']
+
+    def ignore_sent(sent):  # predicate handed to the CRF: true for sentences whose first token has feat1 set
+        return sent[0]['feat1']
+
+    X = [sent1_features, sent2_features]
+    y = [sent1_labels, sent2_labels]
+    crf = SentenceFilterCRF(ignored_label='O', ignore_sentence=ignore_sent)
+    crf.fit(X, y)
+
+    assert set(crf.classes_) == {'O', 'B-Date', 'I-Date'}  # 'B-Name' occurs only in the ignored sentence
+
+    y_pred = crf.predict_marginals([sent1_features, sent2_features])
+    # Should have two sentences
+    assert len(y_pred) == 2
+    assert len(y_pred[0]) == len(sent1_features), "Number of marginals should match len tokens"
+    assert len(y_pred[1]) == len(sent2_features), "Number of marginals should match len tokens"
+
+    # all tokens should have marginals for all classes
+    for sent in y_pred:
+        for token in sent:
+            assert set(token.keys()) == set(crf.classes_)
+
+    # First sentence (ignored) should have marginal=1 for the ignored_label and 0 for every other class.
+    ignored_marginals = {'O': 1, 'B-Date': 0, 'I-Date': 0}
+    assert y_pred[0] == [ignored_marginals] * 4
+    # Second sentence should have non-zero marginals for the other classes
+    assert y_pred[1] != [ignored_marginals] * 3