--- /dev/null
+++ b/deidentify/methods/crf/crf_labeler.py
@@ -0,0 +1,376 @@
+"""CRF training utilities.
+
+Two feature sets are provided:
+
+   1. The default feature set from the sklearn_crfsuite tutorial.
+   2. The Liu et al. (2015) features, as used in the de-identification shared task.
+
+Features are encoded in python-crfsuite format:
+https://python-crfsuite.readthedocs.io/en/latest/pycrfsuite.html#pycrfsuite.ItemSequence
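+
+For illustration, each token is encoded as a flat dict that maps feature names to values, e.g.
+{'bias': 1.0, 'word.lower()': 'patient'} (the value 'patient' is a made-up example).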
+"""
+
+import re
+import string
+from typing import Callable, Dict, List, Tuple
+
+import sklearn_crfsuite
+from tqdm import tqdm
+from unidecode import unidecode
+
+from deidentify.methods.tagging_utils import Token
+
+NEWLINE_REGEX = re.compile(r'\n')
+SPACE_REGEX = re.compile(r'\s')
+
+
+def sent2features(sent: List[Token],
+                  feature_extractor: Callable[[List[Token], int], Dict]) -> List[Dict]:
+    """Convert a sentence to features in python-crfsuite format.
+
+    python-crfsuite cannot handle feature values that contain whitespace or newline characters.
+    These characters are replaced with the special tokens #SPACE and #NEWLINE, respectively.
+
+    See issues:
+    https://github.com/scrapinghub/python-crfsuite/issues/14
+    https://github.com/scrapinghub/python-crfsuite/issues/71
+
+    Parameters
+    ----------
+    sent : List[Token]
+        A sentence represented as a list of tokens.
+    feature_extractor : Callable[[List[Token], int], Dict]
+        Callable that represents a token at position `i: int` as a feature dict.
+
+    Returns
+    -------
+    sent_features : List[Dict]
+        List of feature dicts per token. `len(sent_features) == len(sent)`
+
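+    Examples
+    --------
+    A hypothetical sketch of the whitespace replacement (the token and extractor below are made
+    up for illustration):
+
+    >>> tokens = [Token(text='New York', pos_tag='PROPN', label='O', ner_tag=None)]
+    >>> extractor = lambda sent, i: {'word.lower()': sent[i].text.lower()}
+    >>> sent2features(tokens, extractor)
+    [{'word.lower()': 'new#SPACEyork'}]
+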
+    """
+    sent_features = []
+
+    for i in range(len(sent)):
+        token_features = feature_extractor(sent, i)
+
+        for feature_name, value in token_features.items():
+            if not isinstance(value, str):
+                continue
+
+            value = NEWLINE_REGEX.sub('#NEWLINE', value)
+            value = SPACE_REGEX.sub('#SPACE', value)
+            token_features[feature_name] = value
+
+        sent_features.append(token_features)
+
+    return sent_features
+
+
+def sent2labels(sent):
+    return [token.label for token in sent]
+
+
+def sents_to_features_and_labels(sents, feature_extractor):
+    X = [sent2features(s, feature_extractor) for s in sents]
+    y = [sent2labels(s) for s in sents]
+    return X, y
+
+
+class SentenceFilterCRF(sklearn_crfsuite.CRF):
+    """Custom CRF implementation that allows to ignore entire sentences during training/prediction
+    time. A default label will be assigned to all tokens within that sentence.
+
+    The confidence of predictions within ignored sentences is set to 1.
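+
+    A minimal usage sketch (the predicate, label, and hyperparameters below are illustrative,
+    not this project's actual settings)::
+
+        crf = SentenceFilterCRF(
+            ignore_sentence=lambda xseq: len(xseq) == 0,
+            ignored_label='O',
+            algorithm='lbfgs', c1=0.1, c2=0.1, max_iterations=100)
+        crf.fit(X_train, y_train)
+        y_pred = crf.predict(X_test)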
+    """
+
+    # pylint: disable=too-many-arguments
+    # pylint: disable=too-many-locals
+    # scikit-learn estimators have to list all keyword arguments explicitly (no *args/**kwargs).
+    def __init__(self,
+                 ignore_sentence,
+                 ignored_label,
+                 algorithm=None,
+                 min_freq=None,
+                 all_possible_states=None,
+                 all_possible_transitions=None,
+                 c1=None,
+                 c2=None,
+                 max_iterations=None,
+                 num_memories=None,
+                 epsilon=None,
+                 period=None,
+                 delta=None,
+                 linesearch=None,
+                 max_linesearch=None,
+                 calibration_eta=None,
+                 calibration_rate=None,
+                 calibration_samples=None,
+                 calibration_candidates=None,
+                 calibration_max_trials=None,
+                 pa_type=None,
+                 c=None,
+                 error_sensitive=None,
+                 averaging=None,
+                 variance=None,
+                 gamma=None,
+                 verbose=False,
+                 model_filename=None,
+                 keep_tempfiles=False,
+                 trainer_cls=None):
+        self.ignore_sentence = ignore_sentence
+        self.ignored_label = ignored_label
+
+        super().__init__(
+            algorithm=algorithm,
+            min_freq=min_freq,
+            all_possible_states=all_possible_states,
+            all_possible_transitions=all_possible_transitions,
+            c1=c1,
+            c2=c2,
+            max_iterations=max_iterations,
+            num_memories=num_memories,
+            epsilon=epsilon,
+            period=period,
+            delta=delta,
+            linesearch=linesearch,
+            max_linesearch=max_linesearch,
+            calibration_eta=calibration_eta,
+            calibration_rate=calibration_rate,
+            calibration_samples=calibration_samples,
+            calibration_candidates=calibration_candidates,
+            calibration_max_trials=calibration_max_trials,
+            pa_type=pa_type,
+            c=c,
+            error_sensitive=error_sensitive,
+            averaging=averaging,
+            variance=variance,
+            gamma=gamma,
+            verbose=verbose,
+            model_filename=model_filename,
+            keep_tempfiles=keep_tempfiles,
+            trainer_cls=trainer_cls)
+
+    def _filter_xy(self, X, y):
+        return zip(*filter(lambda xy: not self.ignore_sentence(xy[0]), zip(X, y)))
+
+    def fit(self, X, y, X_dev=None, y_dev=None):
+        X, y = self._filter_xy(X, y)
+
+        if X_dev is not None:
+            X_dev, y_dev = self._filter_xy(X_dev, y_dev)
+
+        return super().fit(X, y, X_dev=X_dev, y_dev=y_dev)
+
+    def predict(self, X, verbose=False):
+        X = tqdm(X, disable=not verbose, desc='Tag sentences')
+        return super().predict(X)
+
+    def predict_single(self, xseq):
+        if self.ignore_sentence(xseq):
+            return [self.ignored_label] * len(xseq)
+
+        return super().predict_single(xseq)
+
+    def predict_marginals_single(self, xseq):
+        if self.ignore_sentence(xseq):
+            ignored_marginals = {c: 0 for c in self.classes_}
+            ignored_marginals[self.ignored_label] = 1
+            return [ignored_marginals] * len(xseq)
+
+        return super().predict_marginals_single(xseq)
+
+
+def sklearn_crfsuite_feature_extractor(sent, i):
+    """
+    Taken from:
+    https://sklearn-crfsuite.readthedocs.io/en/latest/tutorial.html
+    """
+    word = sent[i].text
+    pos_tag = sent[i].pos_tag
+
+    features = {
+        'bias': 1.0,
+        'word.lower()': word.lower(),
+        'word[-3:]': word[-3:],
+        'word[-2:]': word[-2:],
+        'word.isupper()': word.isupper(),
+        'word.istitle()': word.istitle(),
+        'word.isdigit()': word.isdigit(),
+        'pos_tag': pos_tag,
+        'pos_tag[:2]': pos_tag[:2],
+    }
+    if i > 0:
+        word1 = sent[i - 1].text
+        pos_tag1 = sent[i - 1].pos_tag
+        features.update({
+            '-1:word.lower()': word1.lower(),
+            '-1:word.istitle()': word1.istitle(),
+            '-1:word.isupper()': word1.isupper(),
+            '-1:pos_tag': pos_tag1,
+            '-1:pos_tag[:2]': pos_tag1[:2],
+        })
+    else:
+        features['BOS'] = True
+
+    if i < len(sent) - 1:
+        word1 = sent[i + 1].text
+        pos_tag1 = sent[i + 1].pos_tag
+        features.update({
+            '+1:word.lower()': word1.lower(),
+            '+1:word.istitle()': word1.istitle(),
+            '+1:word.isupper()': word1.isupper(),
+            '+1:pos_tag': pos_tag1,
+            '+1:pos_tag[:2]': pos_tag1[:2],
+        })
+    else:
+        features['EOS'] = True
+
+    return features
+
+
+def liu_feature_extractor(sent, i):
+    """Reproduces the features used by Liu et al. (2015).
+
+    Does not include the word representation features (word2vec, Brown clusters) or the
+    gazetteer features.
+
+    Reference:
+    Liu, Z., et al. (2015). Automatic de-identification of electronic medical records using
+    token-level and character-level conditional random fields. Journal of Biomedical Informatics,
+    58, S47–S52. https://doi.org/10.1016/J.JBI.2015.06.009
+    """
+    token = sent[i]
+
+    null_token = Token(text='<PAD>', pos_tag='<PAD>', label='', ner_tag=None)
+    sent_window = list_window(sent, center=i, window=(2, 2), oob_item=null_token)
+    token_window = [t.text.lower() for t in sent_window]
+    pos_window = [t.pos_tag for t in sent_window]
+    text_lower = token.text.lower()
+
+    features = {}
+    features.update(_ngram_feature_group(token_window, N=1, group_name='bow[-2:2].uni'))
+    features.update(_ngram_feature_group(token_window, N=2, group_name='bow[-2:2].bi'))
+    features.update(_ngram_feature_group(token_window, N=3, group_name='bow[-2:2].tri'))
+
+    features.update(_ngram_feature_group(pos_window, N=1, group_name='pos[-2:2].uni'))
+    features.update(_ngram_feature_group(pos_window, N=2, group_name='pos[-2:2].bi'))
+    features.update(_ngram_feature_group(pos_window, N=3, group_name='pos[-2:2].tri'))
+
+    sent_window = list_window(sent, center=i, window=(1, 1), oob_item=null_token)
+    pos_window = [t.pos_tag for t in sent_window]
+    sep = join_features
+    features['bowpos.w0p-1'] = sep((text_lower, pos_window[0]))
+    features['bowpos.w0p0'] = sep((text_lower, pos_window[1]))
+    features['bowpos.w0p1'] = sep((text_lower, pos_window[2]))
+    features['bowpos.w0p-1p0'] = sep((text_lower, pos_window[0], pos_window[1]))
+    features['bowpos.w0p0p1'] = sep((text_lower, pos_window[1], pos_window[2]))
+    features['bowpos.w0p-1p1'] = sep((text_lower, pos_window[0], pos_window[2]))
+    features['bowpos.w0p-1p0p1'] = sep((text_lower, pos_window[0], pos_window[1], pos_window[2]))
+
+    features['sent.len(sent)'] = len(sent)
+    features['sent.end_mark'] = sent[-1].text.strip() in ['!', '?', '.']
+    features['sent.has_unmatched_bracket'] = has_unmatched_bracket(sent)
+
+    for j in range(1, 6):
+        features['suffix[-{}:]'.format(j)] = text_lower[-j:]
+        features['prefix[:{}]'.format(j)] = text_lower[:j]
+
+    features['word.isupper()'] = token.text.isupper()
+    features['word.istitle()'] = token.text.istitle()
+    features['word.isdigit()'] = token.text.isdigit()
+    features['word.contains_digit'] = any(c.isdigit() for c in token.text)
+    features['word.has_upper_inside'] = any(c.isupper() for c in token.text[1:])
+    features['word.has_punct_inside'] = any(c in string.punctuation for c in token.text[1:])
+    features['word.has_digit_inside'] = any(c.isdigit() for c in token.text[1:])
+    features['word.is_ascii'] = all(ord(c) < 128 for c in token.text)
+    features['word.ner_tag'] = token.ner_tag
+    features['word.pos_tag'] = token.pos_tag
+
+    shape = word_shape(token.text)
+    features['shape.long'] = shape
+    features['shape.short'] = collapse_word_shape(shape)
+
+    return features
+
+
+def join_features(feature_list):
+    return '|'.join(feature_list)
+
+
+def ngrams(tokens, N):
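+    """Return the contiguous N-grams of `tokens` as a list of tuples, e.g.
+    ngrams(['a', 'b', 'c'], N=2) == [('a', 'b'), ('b', 'c')]."""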
+    return [tuple(tokens[i:i + N]) for i in range(len(tokens) - N + 1)]
+
+
+def list_window(sent: List, center: int, window: Tuple[int, int], oob_item=None) -> List:
+    """Get a window of tokens within a sentence.
+
+    Parameters
+    ----------
+    sent : List
+        A list of tokens.
+    center : int
+        The index acting as center of the window.
+    window : Tuple[int, int]
+        The window width: `window[0]` is the number of elements before the center, `window[1]`
+        the number of elements after it. The interval is closed, i.e., the center is included.
+    oob_item : object, optional
+        The item used for window indices that fall outside the bounds of `sent`.
+
+    Returns
+    -------
+    tokens : List
+        The tokens within the given window.
+
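+    Examples
+    --------
+    Out-of-bounds positions are filled with `oob_item`:
+
+    >>> list_window([1, 2, 3, 4], center=0, window=(1, 2), oob_item=0)
+    [0, 1, 2, 3]
+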
+    """
+    tokens = []
+    for i in range(center - window[0], center + window[1] + 1):
+        if i < 0 or i >= len(sent):
+            tokens.append(oob_item)
+        else:
+            tokens.append(sent[i])
+    return tokens
+
+
+def _ngram_feature_group(tokens, N, group_name, sep=join_features):
+    features = {}
+    token_ngrams = ngrams(tokens, N)
+    for j, item in enumerate(token_ngrams):
+        features['{}.{}'.format(group_name, j)] = sep(item)
+    return features
+
+
+def has_unmatched_bracket(sent):
+    n_open = 0
+
+    for token in sent:
+        if token.text == '(':
+            n_open += 1
+        elif token.text == ')':
+            n_open -= 1
+
+    return n_open > 0
+
+
+def word_shape(token):
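+    """Map each character of `token` to a shape class: 'a' (lowercase), 'A' (uppercase),
+    '#' (digit), '-' (anything else). Non-ASCII characters are transliterated with unidecode
+    first. For example, word_shape('Ab3-') returns 'Aa#-'.
+    """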
+    shape = ''
+    for c in unidecode(token):
+        if c in string.ascii_lowercase:
+            shape += 'a'
+        elif c in string.ascii_uppercase:
+            shape += 'A'
+        elif c in string.digits:
+            shape += '#'
+        else:
+            shape += '-'
+    return shape
+
+
+def collapse_word_shape(shape):
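+    """Collapse consecutive repeats of the same shape character, e.g. 'AAaa##' -> 'Aa#'."""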
+    collapsed = ''
+    current = None
+    for c in shape:
+        if c == current:
+            continue
+        collapsed += c
+        current = c
+    return collapsed