[51873b]: keras_bert/bert.py


import math

import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
import numpy as np

from .keras_pos_embd import PositionEmbedding
from .keras_layer_normalization import LayerNormalization
from .keras_transformer import get_encoders
from .keras_transformer import get_custom_objects as get_encoder_custom_objects
from .layers import (get_inputs, get_embedding,
                     TokenEmbedding, EmbeddingSimilarity, Masked, Extract)
from .keras_multi_head import MultiHeadAttention
from .keras_position_wise_feed_forward import FeedForward

__all__ = [
    'TOKEN_PAD', 'TOKEN_UNK', 'TOKEN_CLS', 'TOKEN_SEP', 'TOKEN_MASK',
    'gelu', 'get_model', 'get_custom_objects', 'get_base_dict', 'gen_batch_inputs',
]

TOKEN_PAD = ''         # Token for padding
TOKEN_UNK = '[UNK]'    # Token for unknown words
TOKEN_CLS = '[CLS]'    # Token for classification
TOKEN_SEP = '[SEP]'    # Token for separation
TOKEN_MASK = '[MASK]'  # Token for masking


def gelu(x):
    """Gaussian error linear unit.

    Uses the exact erf form on the TensorFlow backend, and the tanh
    approximation from Hendrycks & Gimpel (2016) otherwise.
    """
    if K.backend() == 'tensorflow':
        return 0.5 * x * (1.0 + tf.math.erf(x / tf.sqrt(2.0)))
    return 0.5 * x * (1.0 + K.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * K.pow(x, 3))))
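
# Illustrative sanity check (not executed): the two branches above agree closely
# for moderate inputs, e.g. both give roughly 0.841 at x = 1.0 and roughly -0.159
# at x = -1.0.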


class Bert(keras.Model):

    def __init__(
            self,
            token_num,
            pos_num=512,
            seq_len=512,
            embed_dim=768,
            transformer_num=12,
            head_num=12,
            feed_forward_dim=3072,
            dropout_rate=0.1,
            attention_activation=None,
            feed_forward_activation=gelu,
            custom_layers=None,
            training=True,
            trainable=None,
            lr=1e-4,
            name='Bert'):
        super().__init__(name=name)
        # Default trainability follows the training flag, matching get_model below.
        if trainable is None:
            trainable = training
        self.token_num = token_num
        self.pos_num = pos_num
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.transformer_num = transformer_num
        self.head_num = head_num
        self.feed_forward_dim = feed_forward_dim
        self.dropout_rate = dropout_rate
        self.attention_activation = attention_activation
        self.feed_forward_activation = feed_forward_activation
        self.custom_layers = custom_layers
        self.training = training
        self.trainable = trainable
        self.lr = lr

        # Embedding layers
        self.token_embedding_layer = TokenEmbedding(
            input_dim=token_num,
            output_dim=embed_dim,
            mask_zero=True,
            trainable=trainable,
            name='Embedding-Token',
        )
        self.segment_embedding_layer = keras.layers.Embedding(
            input_dim=2,
            output_dim=embed_dim,
            trainable=trainable,
            name='Embedding-Segment',
        )
        self.position_embedding_layer = PositionEmbedding(
            input_dim=pos_num,
            output_dim=embed_dim,
            mode=PositionEmbedding.MODE_ADD,
            trainable=trainable,
            name='Embedding-Position',
        )
        self.embedding_layer_norm = LayerNormalization(
            trainable=trainable,
            name='Embedding-Norm',
        )

        # Encoder layers: one attention block and one feed-forward block per transformer
        self.encoder_multihead_layers = []
        self.encoder_ffn_layers = []
        self.encoder_attention_norm = []
        self.encoder_ffn_norm = []
        for i in range(transformer_num):
            base_name = 'Encoder-%d' % (i + 1)
            attention_name = '%s-MultiHeadSelfAttention' % base_name
            feed_forward_name = '%s-FeedForward' % base_name
            self.encoder_multihead_layers.append(MultiHeadAttention(
                head_num=head_num,
                activation=attention_activation,
                history_only=False,
                trainable=trainable,
                name=attention_name,
            ))
            self.encoder_ffn_layers.append(FeedForward(
                units=feed_forward_dim,
                activation=feed_forward_activation,
                trainable=trainable,
                name=feed_forward_name,
            ))
            self.encoder_attention_norm.append(LayerNormalization(
                trainable=trainable,
                name='%s-Norm' % attention_name,
            ))
            self.encoder_ffn_norm.append(LayerNormalization(
                trainable=trainable,
                name='%s-Norm' % feed_forward_name,
            ))

    def call(self, inputs):
        # Token embedding returns both the embedded tokens and the embedding matrix.
        embeddings = [
            self.token_embedding_layer(inputs[0]),
            self.segment_embedding_layer(inputs[1]),
        ]
        embeddings[0], embed_weights = embeddings[0]
        embed_layer = keras.layers.Add(
            name='Embedding-Token-Segment')(embeddings)
        embed_layer = self.position_embedding_layer(embed_layer)
        if self.dropout_rate > 0.0:
            dropout_layer = keras.layers.Dropout(
                rate=self.dropout_rate,
                name='Embedding-Dropout',
            )(embed_layer)
        else:
            dropout_layer = embed_layer
        embedding_output = self.embedding_layer_norm(dropout_layer)

        def _wrap_layer(name, input_layer, build_func, norm_layer, dropout_rate=0.0, trainable=True):
            """Wrap a layer with dropout, a residual connection and normalization.

            :param name: Prefix of names for internal layers.
            :param input_layer: Input layer.
            :param build_func: A callable that takes the input tensor and generates the output tensor.
            :param norm_layer: The normalization layer applied after the residual connection.
            :param dropout_rate: Dropout rate.
            :param trainable: Whether the layers are trainable.
            :return: Output layer.
            """
            build_output = build_func(input_layer)
            if dropout_rate > 0.0:
                dropout_layer = keras.layers.Dropout(
                    rate=dropout_rate,
                    name='%s-Dropout' % name,
                )(build_output)
            else:
                dropout_layer = build_output
            if isinstance(input_layer, list):
                input_layer = input_layer[0]
            add_layer = keras.layers.Add(name='%s-Add' % name)([input_layer, dropout_layer])
            normal_layer = norm_layer(add_layer)
            return normal_layer

        last_layer = embedding_output
        output_tensor_list = [last_layer]
        # Each encoder block: self-attention then feed-forward, both wrapped with residual + norm
        for i in range(self.transformer_num):
            base_name = 'Encoder-%d' % (i + 1)
            attention_name = '%s-MultiHeadSelfAttention' % base_name
            feed_forward_name = '%s-FeedForward' % base_name
            self_attention_output = _wrap_layer(
                name=attention_name,
                input_layer=last_layer,
                build_func=self.encoder_multihead_layers[i],
                norm_layer=self.encoder_attention_norm[i],
                dropout_rate=self.dropout_rate,
                trainable=self.trainable)
            last_layer = _wrap_layer(
                name=feed_forward_name,
                input_layer=self_attention_output,
                build_func=self.encoder_ffn_layers[i],
                norm_layer=self.encoder_ffn_norm[i],
                dropout_rate=self.dropout_rate,
                trainable=self.trainable)
            output_tensor_list.append(last_layer)
        return output_tensor_list
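
# A minimal usage sketch for the subclassed ``Bert`` model above; the tiny
# hyper-parameters and dummy inputs are invented for illustration. ``call`` expects
# ``[token_ids, segment_ids]`` and returns the embedding output followed by one
# tensor per encoder block:
#
#     model = Bert(token_num=200, pos_num=64, seq_len=64, embed_dim=32,
#                  transformer_num=2, head_num=4, feed_forward_dim=64)
#     token_ids = np.zeros((2, 64), dtype='int32')
#     segment_ids = np.zeros((2, 64), dtype='int32')
#     features = model([token_ids, segment_ids])
#     # len(features) == 1 + transformer_num; features[-1] has shape (2, 64, 32)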


def get_model(token_num,
              pos_num=512,
              seq_len=512,
              embed_dim=768,
              transformer_num=12,
              head_num=12,
              feed_forward_dim=3072,
              dropout_rate=0.1,
              attention_activation=None,
              feed_forward_activation=gelu,
              custom_layers=None,
              training=True,
              trainable=None,
              lr=1e-4):
    """Get the BERT model.

    See: https://arxiv.org/pdf/1810.04805.pdf

    :param token_num: Number of tokens.
    :param pos_num: Maximum position.
    :param seq_len: Maximum length of the input sequence or None.
    :param embed_dim: Dimensions of embeddings.
    :param transformer_num: Number of transformer blocks.
    :param head_num: Number of heads in multi-head attention in each transformer.
    :param feed_forward_dim: Dimension of the feed forward layer in each transformer.
    :param dropout_rate: Dropout rate.
    :param attention_activation: Activation for attention layers.
    :param feed_forward_activation: Activation for feed-forward layers.
    :param custom_layers: A function that takes the embedding tensor and returns the tensor after feature extraction.
                          Arguments such as `transformer_num` and `head_num` will be ignored if `custom_layers` is
                          not `None`.
    :param training: If `True`, the compiled training model (with MLM and NSP heads) is returned; otherwise the
                     input layers and the last feature extraction layer are returned.
    :param trainable: Whether the model is trainable.
    :param lr: Learning rate.
    :return: The compiled model.
    """
    if trainable is None:
        trainable = training
    inputs = get_inputs(seq_len=seq_len)
    embed_layer, embed_weights = get_embedding(
        inputs,
        token_num=token_num,
        embed_dim=embed_dim,
        pos_num=pos_num,
        dropout_rate=dropout_rate,
        trainable=trainable,
    )
    transformed = embed_layer
    if custom_layers is not None:
        kwargs = {}
        if keras.utils.generic_utils.has_arg(custom_layers, 'trainable'):
            kwargs['trainable'] = trainable
        transformed = custom_layers(transformed, **kwargs)
    else:
        transformed = get_encoders(
            encoder_num=transformer_num,
            input_layer=transformed,
            head_num=head_num,
            hidden_dim=feed_forward_dim,
            attention_activation=attention_activation,
            feed_forward_activation=feed_forward_activation,
            dropout_rate=dropout_rate,
            trainable=trainable,
        )
    if not training:
        return inputs, transformed

    # Masked language model (MLM) head
    mlm_dense_layer = keras.layers.Dense(
        units=embed_dim,
        activation=feed_forward_activation,
        trainable=trainable,
        name='MLM-Dense',
    )(transformed)
    mlm_norm_layer = LayerNormalization(name='MLM-Norm')(mlm_dense_layer)
    mlm_pred_layer = EmbeddingSimilarity(
        name='MLM-Sim')([mlm_norm_layer, embed_weights])
    masked_layer = Masked(name='MLM')([mlm_pred_layer, inputs[-1]])

    # Next sentence prediction (NSP) head on the [CLS] token
    extract_layer = Extract(index=0, name='Extract')(transformed)
    nsp_dense_layer = keras.layers.Dense(
        units=embed_dim,
        activation='tanh',
        trainable=trainable,
        name='NSP-Dense',
    )(extract_layer)
    nsp_pred_layer = keras.layers.Dense(
        units=2,
        activation='softmax',
        trainable=trainable,
        name='NSP',
    )(nsp_dense_layer)

    model = keras.models.Model(inputs=inputs, outputs=[masked_layer, nsp_pred_layer])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss=keras.losses.sparse_categorical_crossentropy,
    )
    return model
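
# A minimal sketch of building the pre-training graph with ``get_model``; the
# vocabulary size and sequence length are invented for illustration. With
# ``training=True`` the compiled model has two outputs (MLM and NSP), while
# ``training=False`` returns the input layers and the final feature tensor:
#
#     model = get_model(token_num=30000, seq_len=128, embed_dim=256,
#                       transformer_num=4, head_num=4, feed_forward_dim=1024)
#     model.summary()
#
#     inputs, features = get_model(token_num=30000, seq_len=128, training=False)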


def get_custom_objects():
    """Get all custom objects for loading saved models."""
    custom_objects = get_encoder_custom_objects()
    custom_objects['PositionEmbedding'] = PositionEmbedding
    custom_objects['TokenEmbedding'] = TokenEmbedding
    custom_objects['EmbeddingSimilarity'] = EmbeddingSimilarity
    custom_objects['Masked'] = Masked
    custom_objects['Extract'] = Extract
    custom_objects['gelu'] = gelu
    return custom_objects
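
# A minimal sketch of reloading a saved pre-training model; the file name is
# hypothetical. The custom objects mapping is needed because the model uses
# layers (and ``gelu``) that Keras does not know about by default:
#
#     model.save('bert_pretrain.h5')
#     model = keras.models.load_model('bert_pretrain.h5',
#                                     custom_objects=get_custom_objects())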


def get_base_dict():
    """Get basic dictionary containing special tokens."""
    return {
        TOKEN_PAD: 0,
        TOKEN_UNK: 1,
        TOKEN_CLS: 2,
        TOKEN_SEP: 3,
        TOKEN_MASK: 4,
    }


def gen_batch_inputs(sentence_pairs,
                     token_dict,
                     token_list,
                     seq_len=512,
                     mask_rate=0.15,
                     mask_mask_rate=0.8,
                     mask_random_rate=0.1,
                     swap_sentence_rate=0.5,
                     force_mask=True):
    """Generate a batch of inputs and outputs for training.

    :param sentence_pairs: A list of pairs containing lists of tokens.
    :param token_dict: The dictionary containing the special tokens.
    :param token_list: A list containing all tokens.
    :param seq_len: Length of the sequence.
    :param mask_rate: The rate of choosing a token for prediction.
    :param mask_mask_rate: The rate of replacing a chosen token with `TOKEN_MASK`.
    :param mask_random_rate: The rate of replacing a chosen token with a random token.
    :param swap_sentence_rate: The rate of swapping the second sentences.
    :param force_mask: Whether to force at least one position to be masked.
    :return: All the inputs and outputs.
    """
    batch_size = len(sentence_pairs)
    base_dict = get_base_dict()
    unknown_index = token_dict[TOKEN_UNK]

    # Generate the sentence-swapping mapping for next sentence prediction (NSP)
    nsp_outputs = np.zeros((batch_size,))
    mapping = {}
    if swap_sentence_rate > 0.0:
        indices = [index for index in range(batch_size)
                   if np.random.random() < swap_sentence_rate]
        mapped = indices[:]
        np.random.shuffle(mapped)
        for i in range(len(mapped)):
            if indices[i] != mapped[i]:
                nsp_outputs[indices[i]] = 1.0
        mapping = {indices[i]: mapped[i] for i in range(len(indices))}

    # Generate inputs and targets for the masked language model (MLM)
    token_inputs, segment_inputs, masked_inputs = [], [], []
    mlm_outputs = []
    for i in range(batch_size):
        first, second = sentence_pairs[i][0], sentence_pairs[mapping.get(i, i)][1]
        segment_inputs.append(
            ([0] * (len(first) + 2) + [1] * (seq_len - (len(first) + 2)))[:seq_len])
        tokens = [TOKEN_CLS] + first + [TOKEN_SEP] + second + [TOKEN_SEP]
        tokens = tokens[:seq_len]
        tokens += [TOKEN_PAD] * (seq_len - len(tokens))
        token_input, masked_input, mlm_output = [], [], []
        has_mask = False
        for token in tokens:
            mlm_output.append(token_dict.get(token, unknown_index))
            if token not in base_dict and np.random.random() < mask_rate:
                has_mask = True
                masked_input.append(1)
                r = np.random.random()
                if r < mask_mask_rate:
                    token_input.append(token_dict[TOKEN_MASK])
                elif r < mask_mask_rate + mask_random_rate:
                    while True:
                        token = np.random.choice(token_list)
                        if token not in base_dict:
                            token_input.append(token_dict[token])
                            break
                else:
                    token_input.append(token_dict.get(token, unknown_index))
            else:
                masked_input.append(0)
                token_input.append(token_dict.get(token, unknown_index))
        if force_mask and not has_mask:
            masked_input[1] = 1
        token_inputs.append(token_input)
        masked_inputs.append(masked_input)
        mlm_outputs.append(mlm_output)
    inputs = [np.asarray(x)
              for x in [token_inputs, segment_inputs, masked_inputs]]
    outputs = [np.asarray(np.expand_dims(x, axis=-1))
               for x in [mlm_outputs, nsp_outputs]]
    return inputs, outputs
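

# The block below is an illustrative smoke-test sketch, not part of the library API:
# it wires get_base_dict, gen_batch_inputs and get_model together on a toy corpus.
# All sentence pairs and hyper-parameters are invented for demonstration.
if __name__ == '__main__':
    sentence_pairs = [
        [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'a', 'dull', 'boy']],
        [['from', 'the', 'day', 'forth'], ['my', 'arm', 'changed']],
    ]
    # Extend the base dictionary with every token seen in the toy corpus.
    token_dict = get_base_dict()
    for first, second in sentence_pairs:
        for token in first + second:
            if token not in token_dict:
                token_dict[token] = len(token_dict)
    token_list = list(token_dict.keys())

    # Build a deliberately tiny model and run a single training step.
    model = get_model(token_num=len(token_dict), pos_num=20, seq_len=20,
                      embed_dim=16, transformer_num=2, head_num=2,
                      feed_forward_dim=32, dropout_rate=0.1)
    inputs, outputs = gen_batch_inputs(sentence_pairs, token_dict, token_list, seq_len=20)
    model.fit(inputs, outputs, epochs=1, verbose=1)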