# keras_bert/loader.py
import json
from tensorflow import keras
import numpy as np
import tensorflow as tf
from .bert import get_model


__all__ = [
    'build_model_from_config',
    'load_model_weights_from_checkpoint',
    'load_trained_model_from_checkpoint',
]


def checkpoint_loader(checkpoint_file):
    def _loader(name):
        return tf.train.load_variable(checkpoint_file, name)
    return _loader
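
# A minimal usage sketch (the checkpoint path below is hypothetical and should
# point to a downloaded pre-trained BERT release):
#     loader = checkpoint_loader('uncased_L-12_H-768_A-12/bert_model.ckpt')
#     word_embeddings = loader('bert/embeddings/word_embeddings')  # numpy array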


def build_model_from_config(config_file,
                            training=False,
                            trainable=None,
                            seq_len=None):
    """Build the model from config file.

    :param config_file: The path to the JSON configuration file.
    :param training: If training, the whole model will be returned.
    :param trainable: Whether the model is trainable. The default value is the same as `training`.
    :param seq_len: If it is not None and it is shorter than the value in the config file, the weights in
                    position embeddings will be sliced to fit the new length.
    :return: model and config
    """
    with open(config_file, 'r') as reader:
        config = json.loads(reader.read())
    if seq_len is not None:
        config['max_position_embeddings'] = min(
            seq_len, config['max_position_embeddings'])
    if trainable is None:
        trainable = training

    model = get_model(
        token_num=config['vocab_size'],
        pos_num=config['max_position_embeddings'],
        seq_len=config['max_position_embeddings'],
        embed_dim=config['hidden_size'],
        transformer_num=config['num_hidden_layers'],
        head_num=config['num_attention_heads'],
        feed_forward_dim=config['intermediate_size'],
        training=training,
        trainable=trainable,
    )
    model.build(input_shape=[(None, None), (None, None), (None, None)])
    return model, config
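
# Example (a sketch; the config path is hypothetical and refers to a file from
# a published BERT release such as the uncased base model):
#     model, config = build_model_from_config(
#         'uncased_L-12_H-768_A-12/bert_config.json', seq_len=128)
#     model.summary()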


def load_model_weights_from_checkpoint(model,
                                       config,
                                       checkpoint_file,
                                       training=False):
    """Load trained official model from checkpoint.

    :param model: Built keras model.
    :param config: Loaded configuration file.
    :param checkpoint_file: The path to the checkpoint files; it should end with '.ckpt'.
    :param training: If training, the whole model will be returned.
                     Otherwise, the MLM and NSP parts will be ignored.
    """
    loader = checkpoint_loader(checkpoint_file)

    model.get_layer(name='Embedding-Token').set_weights([
        loader('bert/embeddings/word_embeddings'),
    ])
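    # The position-embedding table is sliced so it matches the (possibly
    # reduced) `max_position_embeddings` chosen in `build_model_from_config`.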
    model.get_layer(name='Embedding-Position').set_weights([
        loader(
            'bert/embeddings/position_embeddings')[:config['max_position_embeddings'], :],
    ])
    model.get_layer(name='Embedding-Segment').set_weights([
        loader('bert/embeddings/token_type_embeddings'),
    ])
    model.get_layer(name='Embedding-Norm').set_weights([
        loader('bert/embeddings/LayerNorm/gamma'),
        loader('bert/embeddings/LayerNorm/beta'),
    ])
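
    # Per-layer mapping: checkpoint variables under `bert/encoder/layer_%d/...`
    # (0-based) are copied into the layers named `Encoder-%d-...` (1-based).
    # Each set_weights call lists the kernel before its bias, matching the
    # weight order of the corresponding Keras layers.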
    for i in range(config['num_hidden_layers']):
        model.get_layer(name='Encoder-%d-MultiHeadSelfAttention' % (i + 1)).set_weights([
            loader('bert/encoder/layer_%d/attention/self/query/kernel' % i),
            loader('bert/encoder/layer_%d/attention/self/query/bias' % i),
            loader('bert/encoder/layer_%d/attention/self/key/kernel' % i),
            loader('bert/encoder/layer_%d/attention/self/key/bias' % i),
            loader('bert/encoder/layer_%d/attention/self/value/kernel' % i),
            loader('bert/encoder/layer_%d/attention/self/value/bias' % i),
            loader('bert/encoder/layer_%d/attention/output/dense/kernel' % i),
            loader('bert/encoder/layer_%d/attention/output/dense/bias' % i),
        ])
        model.get_layer(name='Encoder-%d-MultiHeadSelfAttention-Norm' % (i + 1)).set_weights([
            loader('bert/encoder/layer_%d/attention/output/LayerNorm/gamma' % i),
            loader('bert/encoder/layer_%d/attention/output/LayerNorm/beta' % i),
        ])
        model.get_layer(name='Encoder-%d-FeedForward' % (i + 1)).set_weights([
            loader('bert/encoder/layer_%d/intermediate/dense/kernel' % i),
            loader('bert/encoder/layer_%d/intermediate/dense/bias' % i),
            loader('bert/encoder/layer_%d/output/dense/kernel' % i),
            loader('bert/encoder/layer_%d/output/dense/bias' % i),
        ])
        model.get_layer(name='Encoder-%d-FeedForward-Norm' % (i + 1)).set_weights([
            loader('bert/encoder/layer_%d/output/LayerNorm/gamma' % i),
            loader('bert/encoder/layer_%d/output/LayerNorm/beta' % i),
        ])
    if training:
        model.get_layer(name='MLM-Dense').set_weights([
            loader('cls/predictions/transform/dense/kernel'),
            loader('cls/predictions/transform/dense/bias'),
        ])
        model.get_layer(name='MLM-Norm').set_weights([
            loader('cls/predictions/transform/LayerNorm/gamma'),
            loader('cls/predictions/transform/LayerNorm/beta'),
        ])
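        # The MLM output layer scores against the shared token-embedding
        # matrix, so only the output bias has to be loaded here.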
        model.get_layer(name='MLM-Sim').set_weights([
            loader('cls/predictions/output_bias'),
        ])
        model.get_layer(name='NSP-Dense').set_weights([
            loader('bert/pooler/dense/kernel'),
            loader('bert/pooler/dense/bias'),
        ])
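        # The checkpoint stores `output_weights` with shape (2, hidden_size);
        # it is transposed to the (hidden_size, 2) kernel layout that a Keras
        # Dense layer expects.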
        model.get_layer(name='NSP').set_weights([
            np.transpose(loader('cls/seq_relationship/output_weights')),
            loader('cls/seq_relationship/output_bias'),
        ])


def load_trained_model_from_checkpoint(config_file,
                                       checkpoint_file,
                                       training=False,
                                       trainable=None,
                                       seq_len=None):
    """Load trained official model from checkpoint.

    :param config_file: The path to the JSON configuration file.
    :param checkpoint_file: The path to the checkpoint files; it should end with '.ckpt'.
    :param training: If training, the whole model will be returned.
                     Otherwise, the MLM and NSP parts will be ignored.
    :param trainable: Whether the model is trainable. The default value is the same as `training`.
    :param seq_len: If it is not None and it is shorter than the value in the config file, the weights in
                    position embeddings will be sliced to fit the new length.
    :return: model
    """
    model, config = build_model_from_config(
        config_file, training=training, trainable=trainable, seq_len=seq_len)
    load_model_weights_from_checkpoint(
        model, config, checkpoint_file, training=training)
    return model
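
# End-to-end usage (a sketch; the paths are hypothetical and refer to files in
# a downloaded pre-trained BERT release such as `uncased_L-12_H-768_A-12`):
#     model = load_trained_model_from_checkpoint(
#         'uncased_L-12_H-768_A-12/bert_config.json',
#         'uncased_L-12_H-768_A-12/bert_model.ckpt',
#         training=False,
#         seq_len=128,
#     )
#     model.summary()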