--- a
+++ b/keras_bert/bert.py
@@ -0,0 +1,399 @@
+import math
+import tensorflow as tf
+from tensorflow import keras
+import tensorflow.keras.backend as K
+import numpy as np
+from .keras_pos_embd import PositionEmbedding
+from .keras_layer_normalization import LayerNormalization
+from .keras_transformer import get_encoders
+from .keras_transformer import get_custom_objects as get_encoder_custom_objects
+from .layers import (get_inputs, get_embedding,
+                     TokenEmbedding, EmbeddingSimilarity, Masked, Extract)
+from .keras_multi_head import MultiHeadAttention
+from .keras_position_wise_feed_forward import FeedForward
+
+
+__all__ = [
+    'TOKEN_PAD', 'TOKEN_UNK', 'TOKEN_CLS', 'TOKEN_SEP', 'TOKEN_MASK',
+    'gelu', 'get_model', 'get_custom_objects', 'get_base_dict', 'gen_batch_inputs',
+]
+
+
+TOKEN_PAD = ''  # Token for padding
+TOKEN_UNK = '[UNK]'  # Token for unknown words
+TOKEN_CLS = '[CLS]'  # Token for classification
+TOKEN_SEP = '[SEP]'  # Token for separation
+TOKEN_MASK = '[MASK]'  # Token for masking
+
+
+def gelu(x):
+    if K.backend() == 'tensorflow':
+        # Exact GELU when TensorFlow ops are available
+        return 0.5 * x * (1.0 + tf.math.erf(x / tf.sqrt(2.0)))
+    # Tanh approximation for other backends
+    return 0.5 * x * (1.0 + K.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * K.pow(x, 3))))
+
+
+class Bert(keras.Model):
+    def __init__(
+            self,
+            token_num,
+            pos_num=512,
+            seq_len=512,
+            embed_dim=768,
+            transformer_num=12,
+            head_num=12,
+            feed_forward_dim=3072,
+            dropout_rate=0.1,
+            attention_activation=None,
+            feed_forward_activation=gelu,
+            custom_layers=None,
+            training=True,
+            trainable=None,
+            lr=1e-4,
+            name='Bert'):
+        super().__init__(name=name)
+        if trainable is None:
+            trainable = training
+        self.token_num = token_num
+        self.pos_num = pos_num
+        self.seq_len = seq_len
+        self.embed_dim = embed_dim
+        self.transformer_num = transformer_num
+        self.head_num = head_num
+        self.feed_forward_dim = feed_forward_dim
+        self.dropout_rate = dropout_rate
+        self.attention_activation = attention_activation
+        self.feed_forward_activation = feed_forward_activation
+        self.custom_layers = custom_layers
+        self.training = training
+        self.trainable = trainable
+        self.lr = lr
+
+        # Build layers
+        # Embeddings
+        self.token_embedding_layer = TokenEmbedding(
+            input_dim=token_num,
+            output_dim=embed_dim,
+            mask_zero=True,
+            trainable=trainable,
+            name='Embedding-Token',
+        )
+        self.segment_embedding_layer = keras.layers.Embedding(
+            input_dim=2,
+            output_dim=embed_dim,
+            trainable=trainable,
+            name='Embedding-Segment',
+        )
+        self.position_embedding_layer = PositionEmbedding(
+            input_dim=pos_num,
+            output_dim=embed_dim,
+            mode=PositionEmbedding.MODE_ADD,
+            trainable=trainable,
+            name='Embedding-Position',
+        )
+        self.embedding_layer_norm = LayerNormalization(
+            trainable=trainable,
+            name='Embedding-Norm',
+        )
+
+        self.encoder_multihead_layers = []
+        self.encoder_ffn_layers = []
+        self.encoder_attention_norm = []
+        self.encoder_ffn_norm = []
+        # Attention and feed-forward layers for each encoder block
+        for i in range(transformer_num):
+            base_name = 'Encoder-%d' % (i + 1)
+            attention_name = '%s-MultiHeadSelfAttention' % base_name
+            feed_forward_name = '%s-FeedForward' % base_name
+            self.encoder_multihead_layers.append(MultiHeadAttention(
+                head_num=head_num,
+                activation=attention_activation,
+                history_only=False,
+                trainable=trainable,
+                name=attention_name,
+            ))
+            self.encoder_ffn_layers.append(FeedForward(
+                units=feed_forward_dim,
+                activation=feed_forward_activation,
+                trainable=trainable,
+                name=feed_forward_name,
+            ))
+            self.encoder_attention_norm.append(LayerNormalization(
+                trainable=trainable,
+                name='%s-Norm' % attention_name,
+            ))
+            self.encoder_ffn_norm.append(LayerNormalization(
+                trainable=trainable,
+                name='%s-Norm' % feed_forward_name,
+            ))
+
+    def call(self, inputs):
+        embeddings = [
+            self.token_embedding_layer(inputs[0]),
+            self.segment_embedding_layer(inputs[1]),
+        ]
+        # TokenEmbedding returns both the embedded tokens and the embedding weights
+        embeddings[0], embed_weights = embeddings[0]
+        embed_layer = keras.layers.Add(
+            name='Embedding-Token-Segment')(embeddings)
+        embed_layer = self.position_embedding_layer(embed_layer)
+
+        if self.dropout_rate > 0.0:
+            dropout_layer = keras.layers.Dropout(
+                rate=self.dropout_rate,
+                name='Embedding-Dropout',
+            )(embed_layer)
+        else:
+            dropout_layer = embed_layer
+
+        embedding_output = self.embedding_layer_norm(dropout_layer)
+
+        def _wrap_layer(name, input_layer, build_func, norm_layer, dropout_rate=0.0, trainable=True):
+            """Wrap layers with residual, normalization and dropout.
+
+            :param name: Prefix of names for internal layers.
+            :param input_layer: Input layer.
+            :param build_func: A callable that takes the input tensor and generates the output tensor.
+            :param norm_layer: The normalization layer applied after the residual connection.
+            :param dropout_rate: Dropout rate.
+            :param trainable: Whether the layers are trainable.
+            :return: Output layer.
+            """
+            build_output = build_func(input_layer)
+            if dropout_rate > 0.0:
+                dropout_layer = keras.layers.Dropout(
+                    rate=dropout_rate,
+                    name='%s-Dropout' % name,
+                )(build_output)
+            else:
+                dropout_layer = build_output
+            if isinstance(input_layer, list):
+                input_layer = input_layer[0]
+            add_layer = keras.layers.Add(
+                name='%s-Add' % name)([input_layer, dropout_layer])
+            normal_layer = norm_layer(add_layer)
+            return normal_layer
+
+        last_layer = embedding_output
+        output_tensor_list = [last_layer]
+        # Self-attention followed by feed-forward in each encoder block
+        for i in range(self.transformer_num):
+            base_name = 'Encoder-%d' % (i + 1)
+            attention_name = '%s-MultiHeadSelfAttention' % base_name
+            feed_forward_name = '%s-FeedForward' % base_name
+            self_attention_output = _wrap_layer(
+                name=attention_name,
+                input_layer=last_layer,
+                build_func=self.encoder_multihead_layers[i],
+                norm_layer=self.encoder_attention_norm[i],
+                dropout_rate=self.dropout_rate,
+                trainable=self.trainable)
+            last_layer = _wrap_layer(
+                name=feed_forward_name,
+                input_layer=self_attention_output,
+                build_func=self.encoder_ffn_layers[i],
+                norm_layer=self.encoder_ffn_norm[i],
+                dropout_rate=self.dropout_rate,
+                trainable=self.trainable)
+            output_tensor_list.append(last_layer)
+
+        return output_tensor_list
+
+
+def get_model(token_num,
+              pos_num=512,
+              seq_len=512,
+              embed_dim=768,
+              transformer_num=12,
+              head_num=12,
+              feed_forward_dim=3072,
+              dropout_rate=0.1,
+              attention_activation=None,
+              feed_forward_activation=gelu,
+              custom_layers=None,
+              training=True,
+              trainable=None,
+              lr=1e-4):
+    """Get BERT model.
+
+    See: https://arxiv.org/pdf/1810.04805.pdf
+
+    :param token_num: Number of tokens.
+    :param pos_num: Maximum position.
+    :param seq_len: Maximum length of the input sequence or None.
+    :param embed_dim: Dimensions of embeddings.
+    :param transformer_num: Number of transformers.
+    :param head_num: Number of heads in multi-head attention in each transformer.
+    :param feed_forward_dim: Dimension of the feed forward layer in each transformer.
+    :param dropout_rate: Dropout rate.
+    :param attention_activation: Activation for attention layers.
+    :param feed_forward_activation: Activation for feed-forward layers.
+    :param custom_layers: A function that takes the embedding tensor and returns the tensor after feature extraction.
+                          Arguments such as `transformer_num` and `head_num` will be ignored if `custom_layers` is not
+                          `None`.
+    :param training: The built model will be returned if it is `True`, otherwise the input layers and the last feature
+                     extraction layer will be returned.
+    :param trainable: Whether the model is trainable.
+    :param lr: Learning rate.
+    :return: The compiled model if `training` is `True`, otherwise the input layers and the last feature extraction
+             layer.
+    """
+    if trainable is None:
+        trainable = training
+    inputs = get_inputs(seq_len=seq_len)
+    embed_layer, embed_weights = get_embedding(
+        inputs,
+        token_num=token_num,
+        embed_dim=embed_dim,
+        pos_num=pos_num,
+        dropout_rate=dropout_rate,
+        trainable=trainable,
+    )
+    transformed = embed_layer
+    if custom_layers is not None:
+        kwargs = {}
+        if keras.utils.generic_utils.has_arg(custom_layers, 'trainable'):
+            kwargs['trainable'] = trainable
+        transformed = custom_layers(transformed, **kwargs)
+    else:
+        transformed = get_encoders(
+            encoder_num=transformer_num,
+            input_layer=transformed,
+            head_num=head_num,
+            hidden_dim=feed_forward_dim,
+            attention_activation=attention_activation,
+            feed_forward_activation=feed_forward_activation,
+            dropout_rate=dropout_rate,
+            trainable=trainable,
+        )
+    if not training:
+        return inputs, transformed
+    # Masked language model head
+    mlm_dense_layer = keras.layers.Dense(
+        units=embed_dim,
+        activation=feed_forward_activation,
+        trainable=trainable,
+        name='MLM-Dense',
+    )(transformed)
+    mlm_norm_layer = LayerNormalization(name='MLM-Norm')(mlm_dense_layer)
+    mlm_pred_layer = EmbeddingSimilarity(
+        name='MLM-Sim')([mlm_norm_layer, embed_weights])
+    masked_layer = Masked(name='MLM')([mlm_pred_layer, inputs[-1]])
+    # Next sentence prediction head
+    extract_layer = Extract(index=0, name='Extract')(transformed)
+    nsp_dense_layer = keras.layers.Dense(
+        units=embed_dim,
+        activation='tanh',
+        trainable=trainable,
+        name='NSP-Dense',
+    )(extract_layer)
+    nsp_pred_layer = keras.layers.Dense(
+        units=2,
+        activation='softmax',
+        trainable=trainable,
+        name='NSP',
+    )(nsp_dense_layer)
+    model = keras.models.Model(
+        inputs=inputs,
+        outputs=[masked_layer, nsp_pred_layer])
+    model.compile(
+        optimizer=keras.optimizers.Adam(lr=lr),
+        loss=keras.losses.sparse_categorical_crossentropy,
+    )
+    return model
+
+
+def get_custom_objects():
+    """Get all custom objects for loading saved models."""
+    custom_objects = get_encoder_custom_objects()
+    custom_objects['PositionEmbedding'] = PositionEmbedding
+    custom_objects['TokenEmbedding'] = TokenEmbedding
+    custom_objects['EmbeddingSimilarity'] = EmbeddingSimilarity
+    custom_objects['Masked'] = Masked
+    custom_objects['Extract'] = Extract
+    custom_objects['gelu'] = gelu
+    return custom_objects
+
+
+def get_base_dict():
+    """Get the basic dictionary containing special tokens."""
+    return {
+        TOKEN_PAD: 0,
+        TOKEN_UNK: 1,
+        TOKEN_CLS: 2,
+        TOKEN_SEP: 3,
+        TOKEN_MASK: 4,
+    }
+
+
+def gen_batch_inputs(sentence_pairs,
+                     token_dict,
+                     token_list,
+                     seq_len=512,
+                     mask_rate=0.15,
+                     mask_mask_rate=0.8,
+                     mask_random_rate=0.1,
+                     swap_sentence_rate=0.5,
+                     force_mask=True):
+    """Generate a batch of inputs and outputs for training.
+
+    :param sentence_pairs: A list of pairs containing lists of tokens.
+    :param token_dict: The dictionary mapping tokens to indices; it must contain the special tokens.
+    :param token_list: A list containing all tokens.
+    :param seq_len: Length of the sequence.
+    :param mask_rate: The rate of choosing a token for prediction.
+    :param mask_mask_rate: The rate of replacing a chosen token with `TOKEN_MASK`.
+    :param mask_random_rate: The rate of replacing a chosen token with a random token.
+    :param swap_sentence_rate: The rate of swapping the second sentences.
+    :param force_mask: Whether at least one position should be masked.
+    :return: All the inputs and outputs.
+    """
+    batch_size = len(sentence_pairs)
+    base_dict = get_base_dict()
+    unknown_index = token_dict[TOKEN_UNK]
+    # Generate sentence swapping mapping for the NSP task
+    nsp_outputs = np.zeros((batch_size,))
+    mapping = {}
+    if swap_sentence_rate > 0.0:
+        indices = [index for index in range(batch_size)
+                   if np.random.random() < swap_sentence_rate]
+        mapped = indices[:]
+        np.random.shuffle(mapped)
+        for i in range(len(mapped)):
+            if indices[i] != mapped[i]:
+                nsp_outputs[indices[i]] = 1.0
+        mapping = {indices[i]: mapped[i] for i in range(len(indices))}
+    # Generate MLM inputs and outputs
+    token_inputs, segment_inputs, masked_inputs = [], [], []
+    mlm_outputs = []
+    for i in range(batch_size):
+        first = sentence_pairs[i][0]
+        second = sentence_pairs[mapping.get(i, i)][1]
+        segment_inputs.append(
+            ([0] * (len(first) + 2) + [1] * (seq_len - (len(first) + 2)))[:seq_len])
+        tokens = [TOKEN_CLS] + first + [TOKEN_SEP] + second + [TOKEN_SEP]
+        tokens = tokens[:seq_len]
+        tokens += [TOKEN_PAD] * (seq_len - len(tokens))
+        token_input, masked_input, mlm_output = [], [], []
+        has_mask = False
+        for token in tokens:
+            mlm_output.append(token_dict.get(token, unknown_index))
+            if token not in base_dict and np.random.random() < mask_rate:
+                has_mask = True
+                masked_input.append(1)
+                r = np.random.random()
+                if r < mask_mask_rate:
+                    token_input.append(token_dict[TOKEN_MASK])
+                elif r < mask_mask_rate + mask_random_rate:
+                    while True:
+                        token = np.random.choice(token_list)
+                        if token not in base_dict:
+                            token_input.append(token_dict[token])
+                            break
+                else:
+                    token_input.append(token_dict.get(token, unknown_index))
+            else:
+                masked_input.append(0)
+                token_input.append(token_dict.get(token, unknown_index))
+        if force_mask and not has_mask:
+            masked_input[1] = 1
+        token_inputs.append(token_input)
+        masked_inputs.append(masked_input)
+        mlm_outputs.append(mlm_output)
+    inputs = [np.asarray(x)
+              for x in [token_inputs, segment_inputs, masked_inputs]]
+    outputs = [np.asarray(np.expand_dims(x, axis=-1))
+               for x in [mlm_outputs, nsp_outputs]]
+    return inputs, outputs
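
A minimal pretraining sketch using the functions added above (not part of the diff). The toy sentence pairs, the tiny hyper-parameters, the `bert_pretrain.h5` file name, and the `keras_bert.bert` import path are assumptions for illustration:

from tensorflow import keras
from keras_bert.bert import (get_model, get_custom_objects,
                             get_base_dict, gen_batch_inputs)

# Hypothetical toy corpus: each entry is a pair of token lists.
sentence_pairs = [
    [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'a', 'dull', 'boy']],
    [['from', 'the', 'day', 'forth'], ['my', 'arm', 'changed']],
]

# Build the token dictionary on top of the special tokens.
token_dict = get_base_dict()
for pair in sentence_pairs:
    for sentence in pair:
        for token in sentence:
            if token not in token_dict:
                token_dict[token] = len(token_dict)
token_list = list(token_dict.keys())

# Deliberately small configuration so the sketch builds quickly.
model = get_model(
    token_num=len(token_dict),
    pos_num=32,
    seq_len=32,
    embed_dim=24,
    transformer_num=2,
    head_num=4,
    feed_forward_dim=64,
)
model.summary()

# One training step on a generated batch (MLM and NSP targets).
inputs, outputs = gen_batch_inputs(
    sentence_pairs, token_dict, token_list, seq_len=32)
model.train_on_batch(inputs, outputs)

# Custom layers must be registered when reloading a saved model.
model.save('bert_pretrain.h5')
model = keras.models.load_model(
    'bert_pretrain.h5', custom_objects=get_custom_objects())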