Diff of /keras_bert/tokenizer.py [000000] .. [51873b]

--- a
+++ b/keras_bert/tokenizer.py
@@ -0,0 +1,157 @@
+import unicodedata
+from .bert import TOKEN_CLS, TOKEN_SEP, TOKEN_UNK
+
+
+class Tokenizer(object):
+
+    def __init__(self,
+                 token_dict,
+                 token_cls=TOKEN_CLS,
+                 token_sep=TOKEN_SEP,
+                 token_unk=TOKEN_UNK,
+                 pad_index=0,
+                 cased=False):
+        """Initialize tokenizer.
+
+        :param token_dict: A dict that maps tokens to indices.
+        :param token_cls: The token that represents classification ([CLS]).
+        :param token_sep: The token that represents a separator ([SEP]).
+        :param token_unk: The token that represents an unknown token ([UNK]).
+        :param pad_index: The index used for padding.
+        :param cased: Whether to keep the original letter case.
+        """
+        self._token_dict = token_dict
+        self._token_cls = token_cls
+        self._token_sep = token_sep
+        self._token_unk = token_unk
+        self._pad_index = pad_index
+        self._cased = cased
+
+    @staticmethod
+    def _truncate(first_tokens, second_tokens=None, max_len=None):
+        if max_len is None:
+            return
+
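+        # Pop tokens from the longer sequence until the pair plus special tokens fits.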
+        if second_tokens is not None:
+            while True:
+                total_len = len(first_tokens) + len(second_tokens)
+                if total_len <= max_len - 3:  # 3 for [CLS] .. tokens_a .. [SEP] .. tokens_b .. [SEP]
+                    break
+                if len(first_tokens) > len(second_tokens):
+                    first_tokens.pop()
+                else:
+                    second_tokens.pop()
+        else:
+            del first_tokens[max_len - 2:]  # 2 for [CLS] .. tokens .. [SEP]
+
+    def _pack(self, first_tokens, second_tokens=None):
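+        # Add the special tokens around the sequences and return each segment's length.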
+        first_packed_tokens = [self._token_cls] + first_tokens + [self._token_sep]
+        if second_tokens is not None:
+            second_packed_tokens = second_tokens + [self._token_sep]
+            return first_packed_tokens + second_packed_tokens, len(first_packed_tokens), len(second_packed_tokens)
+        else:
+            return first_packed_tokens, len(first_packed_tokens), 0
+
+    def _convert_tokens_to_ids(self, tokens):
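+        # Tokens missing from the vocabulary map to the unknown token's index.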
+        unk_id = self._token_dict.get(self._token_unk)
+        return [self._token_dict.get(token, unk_id) for token in tokens]
+
+    def tokenize(self, first, second=None):
+        first_tokens = self._tokenize(first)
+        second_tokens = self._tokenize(second) if second is not None else None
+        tokens, _, _ = self._pack(first_tokens, second_tokens)
+        return tokens
+
+    def encode(self, first, second=None, max_len=None):
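+        # Tokenize, truncate, add special tokens, then pad up to max_len if given.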
+        first_tokens = self._tokenize(first)
+        second_tokens = self._tokenize(second) if second is not None else None
+        self._truncate(first_tokens, second_tokens, max_len)
+        tokens, first_len, second_len = self._pack(first_tokens, second_tokens)
+
+        token_ids = self._convert_tokens_to_ids(tokens)
+        segment_ids = [0] * first_len + [1] * second_len
+
+        if max_len is not None:
+            pad_len = max_len - first_len - second_len
+            token_ids += [self._pad_index] * pad_len
+            segment_ids += [0] * pad_len
+
+        return token_ids, segment_ids
+
+    def _tokenize(self, text):
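+        # For uncased vocabularies, strip accents via NFD normalization and lowercase,
+        # then split on punctuation, CJK characters and whitespace before word-piece.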
+        if not self._cased:
+            text = unicodedata.normalize('NFD', text)
+            text = ''.join([ch for ch in text if unicodedata.category(ch) != 'Mn'])
+            text = text.lower()
+        spaced = ''
+        for ch in text:
+            if self._is_punctuation(ch) or self._is_cjk_character(ch):
+                spaced += ' ' + ch + ' '
+            elif self._is_space(ch):
+                spaced += ' '
+            elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch):
+                continue
+            else:
+                spaced += ch
+        tokens = []
+        for word in spaced.strip().split():
+            tokens += self._word_piece_tokenize(word)
+        return tokens
+
+    def _word_piece_tokenize(self, word):
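+        # Greedy longest-match-first: take the longest vocabulary entry at each
+        # position, prefixing sub-words after the first with '##'.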
+        if word in self._token_dict:
+            return [word]
+        tokens = []
+        start, stop = 0, 0
+        while start < len(word):
+            stop = len(word)
+            while stop > start:
+                sub = word[start:stop]
+                if start > 0:
+                    sub = '##' + sub
+                if sub in self._token_dict:
+                    break
+                stop -= 1
+            if start == stop:
+                stop += 1
+            tokens.append(sub)
+            start = stop
+        return tokens
+
+    @staticmethod
+    def _is_punctuation(ch):
+        code = ord(ch)
+        return 33 <= code <= 47 or \
+            58 <= code <= 64 or \
+            91 <= code <= 96 or \
+            123 <= code <= 126 or \
+            unicodedata.category(ch).startswith('P')
+
+    @staticmethod
+    def _is_cjk_character(ch):
+        code = ord(ch)
+        return 0x4E00 <= code <= 0x9FFF or \
+            0x3400 <= code <= 0x4DBF or \
+            0x20000 <= code <= 0x2A6DF or \
+            0x2A700 <= code <= 0x2B73F or \
+            0x2B740 <= code <= 0x2B81F or \
+            0x2B820 <= code <= 0x2CEAF or \
+            0xF900 <= code <= 0xFAFF or \
+            0x2F800 <= code <= 0x2FA1F
+
+    @staticmethod
+    def _is_space(ch):
+        return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
+            unicodedata.category(ch) == 'Zs'
+
+    @staticmethod
+    def _is_control(ch):
+        return unicodedata.category(ch).startswith('C')
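
A minimal usage sketch of the Tokenizer introduced above, assuming a hypothetical toy vocabulary (a real BERT checkpoint ships a vocab.txt whose entries should populate token_dict):

    from keras_bert.tokenizer import Tokenizer

    # Hypothetical toy vocabulary; real models load this mapping from vocab.txt.
    token_dict = {
        '[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3,
        'un': 4, '##aff': 5, '##able': 6, 'day': 7,
    }

    tokenizer = Tokenizer(token_dict)

    print(tokenizer.tokenize('unaffable day'))
    # ['[CLS]', 'un', '##aff', '##able', 'day', '[SEP]']

    token_ids, segment_ids = tokenizer.encode('unaffable day', max_len=10)
    # token_ids   -> [2, 4, 5, 6, 7, 3, 0, 0, 0, 0]
    # segment_ids -> [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]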