Diff of /keras_bert/loader.py [000000] .. [51873b]

--- a
+++ b/keras_bert/loader.py
@@ -0,0 +1,156 @@
+import json
+from tensorflow import keras
+import numpy as np
+import tensorflow as tf
+from .bert import get_model
+
+
+__all__ = [
+    'build_model_from_config',
+    'load_model_weights_from_checkpoint',
+    'load_trained_model_from_checkpoint',
+]
+
+
+def checkpoint_loader(checkpoint_file):
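+    """Return a closure that loads a named variable from the given TF checkpoint."""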
+    def _loader(name):
+        return tf.train.load_variable(checkpoint_file, name)
+    return _loader
+
+
+def build_model_from_config(config_file,
+                            training=False,
+                            trainable=None,
+                            seq_len=None):
+    """Build the model from config file.
+
+    :param config_file: The path to the JSON configuration file.
+    :param training: If training, the whole model will be returned.
+                     Otherwise, the MLM and NSP parts will be ignored.
+    :param trainable: Whether the model is trainable. The default value is the same as `training`.
+    :param seq_len: If it is not None and shorter than the value in the config file, the position
+                    embedding weights will be sliced to fit the new length.
+    :return: model and config
+    """
+    with open(config_file, 'r') as reader:
+        config = json.loads(reader.read())
+    if seq_len is not None:
+        config['max_position_embeddings'] = min(
+            seq_len, config['max_position_embeddings'])
+    if trainable is None:
+        trainable = training
+
+    model = get_model(
+        token_num=config['vocab_size'],
+        pos_num=config['max_position_embeddings'],
+        seq_len=config['max_position_embeddings'],
+        embed_dim=config['hidden_size'],
+        transformer_num=config['num_hidden_layers'],
+        head_num=config['num_attention_heads'],
+        feed_forward_dim=config['intermediate_size'],
+        training=training,
+        trainable=trainable,
+    )
+    model.build(input_shape=[(None, None), (None, None), (None, None)])
+    return model, config
+
+
+def load_model_weights_from_checkpoint(model,
+                                       config,
+                                       checkpoint_file,
+                                       training=False):
+    """Load trained official model from checkpoint.
+
+    :param model: Built keras model.
+    :param config: Loaded configuration file.
+    :param checkpoint_file: The path prefix of the checkpoint files; it should end with '.ckpt'.
+    :param training: If training, the whole model will be returned.
+                     Otherwise, the MLM and NSP parts will be ignored.
+    """
+    loader = checkpoint_loader(checkpoint_file)
+
+    model.get_layer(name='Embedding-Token').set_weights([
+        loader('bert/embeddings/word_embeddings'),
+    ])
+    model.get_layer(name='Embedding-Position').set_weights([
+        loader('bert/embeddings/position_embeddings')[:config['max_position_embeddings'], :],
+    ])
+    model.get_layer(name='Embedding-Segment').set_weights([
+        loader('bert/embeddings/token_type_embeddings'),
+    ])
+    model.get_layer(name='Embedding-Norm').set_weights([
+        loader('bert/embeddings/LayerNorm/gamma'),
+        loader('bert/embeddings/LayerNorm/beta'),
+    ])
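+    # Copy each transformer block's attention, feed-forward, and layer-normalization weights,
+    # mapping checkpoint layer index i to the Keras layers named 'Encoder-(i+1)-...'.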
+    for i in range(config['num_hidden_layers']):
+        model.get_layer(name='Encoder-%d-MultiHeadSelfAttention' % (i + 1)).set_weights([
+            loader('bert/encoder/layer_%d/attention/self/query/kernel' % i),
+            loader('bert/encoder/layer_%d/attention/self/query/bias' % i),
+            loader('bert/encoder/layer_%d/attention/self/key/kernel' % i),
+            loader('bert/encoder/layer_%d/attention/self/key/bias' % i),
+            loader('bert/encoder/layer_%d/attention/self/value/kernel' % i),
+            loader('bert/encoder/layer_%d/attention/self/value/bias' % i),
+            loader('bert/encoder/layer_%d/attention/output/dense/kernel' % i),
+            loader('bert/encoder/layer_%d/attention/output/dense/bias' % i),
+        ])
+        model.get_layer(name='Encoder-%d-MultiHeadSelfAttention-Norm' % (i + 1)).set_weights([
+            loader('bert/encoder/layer_%d/attention/output/LayerNorm/gamma' % i),
+            loader('bert/encoder/layer_%d/attention/output/LayerNorm/beta' % i),
+        ])
+        model.get_layer(name='Encoder-%d-FeedForward' % (i + 1)).set_weights([
+            loader('bert/encoder/layer_%d/intermediate/dense/kernel' % i),
+            loader('bert/encoder/layer_%d/intermediate/dense/bias' % i),
+            loader('bert/encoder/layer_%d/output/dense/kernel' % i),
+            loader('bert/encoder/layer_%d/output/dense/bias' % i),
+        ])
+        model.get_layer(name='Encoder-%d-FeedForward-Norm' % (i + 1)).set_weights([
+            loader('bert/encoder/layer_%d/output/LayerNorm/gamma' % i),
+            loader('bert/encoder/layer_%d/output/LayerNorm/beta' % i),
+        ])
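+    # The MLM and NSP heads only exist when the model was built with training=True.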
+    if training:
+        model.get_layer(name='MLM-Dense').set_weights([
+            loader('cls/predictions/transform/dense/kernel'),
+            loader('cls/predictions/transform/dense/bias'),
+        ])
+        model.get_layer(name='MLM-Norm').set_weights([
+            loader('cls/predictions/transform/LayerNorm/gamma'),
+            loader('cls/predictions/transform/LayerNorm/beta'),
+        ])
+        model.get_layer(name='MLM-Sim').set_weights([
+            loader('cls/predictions/output_bias'),
+        ])
+        model.get_layer(name='NSP-Dense').set_weights([
+            loader('bert/pooler/dense/kernel'),
+            loader('bert/pooler/dense/bias'),
+        ])
+        model.get_layer(name='NSP').set_weights([
+            np.transpose(loader('cls/seq_relationship/output_weights')),
+            loader('cls/seq_relationship/output_bias'),
+        ])
+
+
+def load_trained_model_from_checkpoint(config_file,
+                                       checkpoint_file,
+                                       training=False,
+                                       trainable=None,
+                                       seq_len=None):
+    """Load trained official model from checkpoint.
+
+    :param config_file: The path to the JSON configuration file.
+    :param checkpoint_file: The path prefix of the checkpoint files; it should end with '.ckpt'.
+    :param training: If training, the whole model will be returned.
+                     Otherwise, the MLM and NSP parts will be ignored.
+    :param trainable: Whether the model is trainable. The default value is the same as `training`.
+    :param seq_len: If it is not None and shorter than the value in the config file, the position
+                    embedding weights will be sliced to fit the new length.
+    :return: model
+    """
+    model, config = build_model_from_config(
+        config_file, training=training, trainable=trainable, seq_len=seq_len)
+    load_model_weights_from_checkpoint(
+        model, config, checkpoint_file, training=training)
+    return model
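
A minimal usage sketch for the new loader (the `uncased_L-12_H-768_A-12` paths are placeholders for wherever a pretrained BERT release has been unpacked, not part of this change):

    from keras_bert.loader import load_trained_model_from_checkpoint

    # Hypothetical paths: point these at an unpacked pretrained BERT checkpoint.
    config_path = 'uncased_L-12_H-768_A-12/bert_config.json'
    checkpoint_path = 'uncased_L-12_H-768_A-12/bert_model.ckpt'

    model = load_trained_model_from_checkpoint(
        config_path,
        checkpoint_path,
        training=False,  # skip the MLM and NSP heads
        seq_len=128,     # slice position embeddings down to 128 positions
    )
    model.summary()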