--- a/keras_bert/tokenizer.py
+++ b/keras_bert/tokenizer.py
@@ -0,0 +1,149 @@
+import unicodedata
+from .bert import TOKEN_CLS, TOKEN_SEP, TOKEN_UNK
+
+
+class Tokenizer(object):
+
+    def __init__(self,
+                 token_dict,
+                 token_cls=TOKEN_CLS,
+                 token_sep=TOKEN_SEP,
+                 token_unk=TOKEN_UNK,
+                 pad_index=0,
+                 cased=False):
+        """Initialize tokenizer.
+
+        :param token_dict: A dict that maps tokens to indices.
+        :param token_cls: The token that represents classification.
+        :param token_sep: The token that represents the separator.
+        :param token_unk: The token that represents unknown tokens.
+        :param pad_index: The index used for padding.
+        :param cased: Whether to keep the case.
+        """
+        self._token_dict = token_dict
+        self._token_cls = token_cls
+        self._token_sep = token_sep
+        self._token_unk = token_unk
+        self._pad_index = pad_index
+        self._cased = cased
+
+    @staticmethod
+    def _truncate(first_tokens, second_tokens=None, max_len=None):
+        if max_len is None:
+            return
+
+        if second_tokens is not None:
+            while True:
+                total_len = len(first_tokens) + len(second_tokens)
+                if total_len <= max_len - 3:  # 3 for [CLS] .. tokens_a .. [SEP] .. tokens_b [SEP]
+                    break
+                if len(first_tokens) > len(second_tokens):
+                    first_tokens.pop()
+                else:
+                    second_tokens.pop()
+        else:
+            del first_tokens[max_len - 2:]  # 2 for [CLS] .. tokens .. [SEP]
+
+    def _pack(self, first_tokens, second_tokens=None):
+        first_packed_tokens = [self._token_cls] + first_tokens + [self._token_sep]
+        if second_tokens is not None:
+            second_packed_tokens = second_tokens + [self._token_sep]
+            return first_packed_tokens + second_packed_tokens, len(first_packed_tokens), len(second_packed_tokens)
+        else:
+            return first_packed_tokens, len(first_packed_tokens), 0
+
+    def _convert_tokens_to_ids(self, tokens):
+        unk_id = self._token_dict.get(self._token_unk)
+        return [self._token_dict.get(token, unk_id) for token in tokens]
+
+    def tokenize(self, first, second=None):
+        first_tokens = self._tokenize(first)
+        second_tokens = self._tokenize(second) if second is not None else None
+        tokens, _, _ = self._pack(first_tokens, second_tokens)
+        return tokens
+
+    def encode(self, first, second=None, max_len=None):
+        first_tokens = self._tokenize(first)
+        second_tokens = self._tokenize(second) if second is not None else None
+        self._truncate(first_tokens, second_tokens, max_len)
+        tokens, first_len, second_len = self._pack(first_tokens, second_tokens)
+
+        token_ids = self._convert_tokens_to_ids(tokens)
+        segment_ids = [0] * first_len + [1] * second_len
+
+        if max_len is not None:
+            pad_len = max_len - first_len - second_len
+            token_ids += [self._pad_index] * pad_len
+            segment_ids += [0] * pad_len
+
+        return token_ids, segment_ids
+
+    def _tokenize(self, text):
+        if not self._cased:
+            text = unicodedata.normalize('NFD', text)
+            text = ''.join([ch for ch in text if unicodedata.category(ch) != 'Mn'])
+            text = text.lower()
+        spaced = ''
+        for ch in text:
+            if self._is_punctuation(ch) or self._is_cjk_character(ch):
+                spaced += ' ' + ch + ' '
+            elif self._is_space(ch):
+                spaced += ' '
+            elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch):
+                continue
+            else:
+                spaced += ch
+        tokens = []
+        for word in spaced.strip().split():
+            tokens += self._word_piece_tokenize(word)
+        return tokens
+
+    def _word_piece_tokenize(self, word):
+        if word in self._token_dict:
+            return [word]
+        tokens = []
+        start, stop = 0, 0
+        while start < len(word):
+            stop = len(word)
+            while stop > start:
+                sub = word[start:stop]
+                if start > 0:
+                    sub = '##' + sub
+                if sub in self._token_dict:
+                    break
+                stop -= 1
+            if start == stop:
+                stop += 1
+            tokens.append(sub)
+            start = stop
+        return tokens
+
+    @staticmethod
+    def _is_punctuation(ch):
+        code = ord(ch)
+        return 33 <= code <= 47 or \
+            58 <= code <= 64 or \
+            91 <= code <= 96 or \
+            123 <= code <= 126 or \
+            unicodedata.category(ch).startswith('P')
+
+    @staticmethod
+    def _is_cjk_character(ch):
+        code = ord(ch)
+        return 0x4E00 <= code <= 0x9FFF or \
+            0x3400 <= code <= 0x4DBF or \
+            0x20000 <= code <= 0x2A6DF or \
+            0x2A700 <= code <= 0x2B73F or \
+            0x2B740 <= code <= 0x2B81F or \
+            0x2B820 <= code <= 0x2CEAF or \
+            0xF900 <= code <= 0xFAFF or \
+            0x2F800 <= code <= 0x2FA1F
+
+    @staticmethod
+    def _is_space(ch):
+        return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
+            unicodedata.category(ch) == 'Zs'
+
+    @staticmethod
+    def _is_control(ch):
+        return unicodedata.category(ch).startswith('C')
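
A minimal usage sketch (not part of the patch): the vocabulary below is a hypothetical hand-built dict, and the '[CLS]'/'[SEP]'/'[UNK]' strings are assumed to match the TOKEN_CLS/TOKEN_SEP/TOKEN_UNK defaults imported from .bert.

from keras_bert.tokenizer import Tokenizer

# Hypothetical toy vocabulary; index 0 doubles as the padding index (pad_index=0).
token_dict = {
    '[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3,
    'un': 4, '##aff': 5, '##able': 6, 'all': 7, 'works': 8,
}
tokenizer = Tokenizer(token_dict)

# Greedy longest-match WordPiece splitting, with '##' marking word continuations.
print(tokenizer.tokenize('unaffable'))
# ['[CLS]', 'un', '##aff', '##able', '[SEP]']

# encode() returns (token_ids, segment_ids); the second sentence gets segment id 1,
# and both lists are padded out to max_len with pad_index / 0.
token_ids, segment_ids = tokenizer.encode('unaffable', 'all works', max_len=12)
# token_ids   == [2, 4, 5, 6, 3, 7, 8, 3, 0, 0, 0, 0]
# segment_ids == [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]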