b/docproduct/tokenization.py
|
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import unicodedata
import six
import tensorflow as tf


def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
    """Checks whether the casing config is consistent with the checkpoint name."""

    # The casing has to be passed in by the user and there is no explicit check
    # as to whether it matches the checkpoint. The casing information probably
    # should have been stored in the bert_config.json file, but it's not, so
    # we have to heuristically detect it to validate.

    if not init_checkpoint:
        return

    m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
    if m is None:
        return

    model_name = m.group(1)

    lower_models = [
        "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
        "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
    ]

    cased_models = [
        "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
        "multi_cased_L-12_H-768_A-12"
    ]

    is_bad_config = False
    if model_name in lower_models and not do_lower_case:
        is_bad_config = True
        actual_flag = "False"
        case_name = "lowercased"
        opposite_flag = "True"

    if model_name in cased_models and do_lower_case:
        is_bad_config = True
        actual_flag = "True"
        case_name = "cased"
        opposite_flag = "False"

    if is_bad_config:
        raise ValueError(
            "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
            "However, `%s` seems to be a %s model, so you "
            "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
            "how the model was pre-trained. If this error is wrong, please "
            "just comment out this check." % (actual_flag, init_checkpoint,
                                              model_name, case_name, opposite_flag))
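
# Illustrative example (not from the original file): for an `init_checkpoint`
# such as "uncased_L-12_H-768_A-12/bert_model.ckpt", the regex above extracts
# the model name "uncased_L-12_H-768_A-12", which appears in `lower_models`,
# so calling this check with `do_lower_case=False` would raise a ValueError.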
|
|
def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python 2 or Python 3?")


def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`."""

    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text
        elif isinstance(text, unicode):
            return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python 2 or Python 3?")


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r", encoding='utf-8') as reader:
        while True:
            token = convert_to_unicode(reader.readline())
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab
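
# Note on the expected vocab format (a sketch added for illustration):
# `load_vocab` assumes one token per line, with each token's id given by its
# zero-based line number. For example, a three-line file containing
# "[PAD]", "[UNK]", "the" yields OrderedDict([("[PAD]", 0), ("[UNK]", 1), ("the", 2)]).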
|
|
def convert_by_vocab(vocab, items):
    """Converts a sequence of [tokens|ids] using the vocab."""
    output = []
    for item in items:
        output.append(vocab[item])
    return output


def convert_tokens_to_ids(vocab, tokens):
    return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
    return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class FullTokenizer(object):
    """Runs end-to-end tokenization."""

    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)

        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)
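
# Minimal usage sketch (the vocab path and sample text are illustrative
# assumptions, not part of the original module):
#
#     tokenizer = FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)
#     tokens = tokenizer.tokenize("How are migraines treated?")
#     ids = tokenizer.convert_tokens_to_ids(tokens)
#     assert tokenizer.convert_ids_to_tokens(ids) == tokens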
|
|
class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)

        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
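
# Illustrative behavior (a sketch, not from the original file): with
# do_lower_case=True,
#
#     BasicTokenizer().tokenize(u"Héllo, World!")
#
# lower-cases the text, strips the accent via NFD normalization, and splits
# punctuation into standalone tokens, yielding ["hello", ",", "world", "!"].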
|
|
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
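
# Behavior note (a sketch, not part of the original file): matching is greedy
# longest-match-first, and if any position in a word fails to match the vocab,
# the entire word collapses to `unk_token`. For instance, assuming a vocab
# with no piece covering "q",
#
#     WordpieceTokenizer(vocab).tokenize("qux")  ->  ["[UNK]"]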
|
|
def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
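

# Minimal smoke-test sketch added for illustration; the default vocab path is
# an assumption, so point it at a real BERT WordPiece vocab file before running.
if __name__ == "__main__":
    import sys

    vocab_path = sys.argv[1] if len(sys.argv) > 1 else "vocab.txt"
    tokenizer = FullTokenizer(vocab_file=vocab_path, do_lower_case=True)
    sample = "What are the common side effects of ibuprofen?"
    tokens = tokenizer.tokenize(sample)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    print(tokens)
    print(ids)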