import json

import numpy as np
import tensorflow as tf

from .bert import get_model

__all__ = [
    'build_model_from_config',
    'load_model_weights_from_checkpoint',
    'load_trained_model_from_checkpoint',
]


def checkpoint_loader(checkpoint_file):
    """Return a function that loads a named variable from the given checkpoint file."""
    def _loader(name):
        return tf.train.load_variable(checkpoint_file, name)
    return _loader
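
# Illustrative usage (not executed): read raw tensors straight from a BERT
# checkpoint. The checkpoint path below is a placeholder for Google's
# pre-trained "uncased_L-12_H-768_A-12" release; adjust it to your download.
#
#     loader = checkpoint_loader('uncased_L-12_H-768_A-12/bert_model.ckpt')
#     word_embeddings = loader('bert/embeddings/word_embeddings')
#     print(word_embeddings.shape)
#
#     # To discover the variable names stored in a checkpoint:
#     for name, shape in tf.train.list_variables('uncased_L-12_H-768_A-12/bert_model.ckpt'):
#         print(name, shape)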


def build_model_from_config(config_file,
                            training=False,
                            trainable=None,
                            seq_len=None):
    """Build the model from the config file.

    :param config_file: The path to the JSON configuration file.
    :param training: If training, the whole model will be returned.
                     Otherwise, the MLM and NSP parts will be ignored.
    :param trainable: Whether the model is trainable. The default value is the same as `training`.
    :param seq_len: If it is not None and it is shorter than the value in the config file, the weights of the
                    position embeddings will be sliced to fit the new length.
    :return: The built model and the loaded configuration.
    """
    with open(config_file, 'r') as reader:
        config = json.loads(reader.read())
    if seq_len is not None:
        config['max_position_embeddings'] = min(
            seq_len, config['max_position_embeddings'])
    if trainable is None:
        trainable = training
    model = get_model(
        token_num=config['vocab_size'],
        pos_num=config['max_position_embeddings'],
        seq_len=config['max_position_embeddings'],
        embed_dim=config['hidden_size'],
        transformer_num=config['num_hidden_layers'],
        head_num=config['num_attention_heads'],
        feed_forward_dim=config['intermediate_size'],
        training=training,
        trainable=trainable,
    )
    model.build(input_shape=[(None, None), (None, None), (None, None)])
    return model, config
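
# Illustrative usage (not executed): build an untrained model from a config
# file. The path is a placeholder for the "bert_config.json" shipped with the
# official pre-trained checkpoints.
#
#     model, config = build_model_from_config(
#         'uncased_L-12_H-768_A-12/bert_config.json',
#         training=False,
#         seq_len=128,
#     )
#     model.summary()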


def load_model_weights_from_checkpoint(model,
                                       config,
                                       checkpoint_file,
                                       training=False):
    """Load the weights of the trained official model from the checkpoint into the built model.

    :param model: The built Keras model.
    :param config: The loaded configuration dict.
    :param checkpoint_file: The path to the checkpoint files; it should end with '.ckpt'.
    :param training: If training, the MLM and NSP weights will be loaded as well.
                     Otherwise, the MLM and NSP parts will be ignored.
    """
    loader = checkpoint_loader(checkpoint_file)
    # Token, position and segment embeddings, followed by the embedding layer normalization.
    model.get_layer(name='Embedding-Token').set_weights([
        loader('bert/embeddings/word_embeddings'),
    ])
    model.get_layer(name='Embedding-Position').set_weights([
        loader('bert/embeddings/position_embeddings')[:config['max_position_embeddings'], :],
    ])
    model.get_layer(name='Embedding-Segment').set_weights([
        loader('bert/embeddings/token_type_embeddings'),
    ])
    model.get_layer(name='Embedding-Norm').set_weights([
        loader('bert/embeddings/LayerNorm/gamma'),
        loader('bert/embeddings/LayerNorm/beta'),
    ])
    # Transformer encoder layers: attention, attention norm, feed-forward and feed-forward norm.
    for i in range(config['num_hidden_layers']):
        model.get_layer(name='Encoder-%d-MultiHeadSelfAttention' % (i + 1)).set_weights([
            loader('bert/encoder/layer_%d/attention/self/query/kernel' % i),
            loader('bert/encoder/layer_%d/attention/self/query/bias' % i),
            loader('bert/encoder/layer_%d/attention/self/key/kernel' % i),
            loader('bert/encoder/layer_%d/attention/self/key/bias' % i),
            loader('bert/encoder/layer_%d/attention/self/value/kernel' % i),
            loader('bert/encoder/layer_%d/attention/self/value/bias' % i),
            loader('bert/encoder/layer_%d/attention/output/dense/kernel' % i),
            loader('bert/encoder/layer_%d/attention/output/dense/bias' % i),
        ])
        model.get_layer(name='Encoder-%d-MultiHeadSelfAttention-Norm' % (i + 1)).set_weights([
            loader('bert/encoder/layer_%d/attention/output/LayerNorm/gamma' % i),
            loader('bert/encoder/layer_%d/attention/output/LayerNorm/beta' % i),
        ])
        model.get_layer(name='Encoder-%d-FeedForward' % (i + 1)).set_weights([
            loader('bert/encoder/layer_%d/intermediate/dense/kernel' % i),
            loader('bert/encoder/layer_%d/intermediate/dense/bias' % i),
            loader('bert/encoder/layer_%d/output/dense/kernel' % i),
            loader('bert/encoder/layer_%d/output/dense/bias' % i),
        ])
        model.get_layer(name='Encoder-%d-FeedForward-Norm' % (i + 1)).set_weights([
            loader('bert/encoder/layer_%d/output/LayerNorm/gamma' % i),
            loader('bert/encoder/layer_%d/output/LayerNorm/beta' % i),
        ])
    if training:
        # The MLM and NSP heads only exist when the full training model was built.
        model.get_layer(name='MLM-Dense').set_weights([
            loader('cls/predictions/transform/dense/kernel'),
            loader('cls/predictions/transform/dense/bias'),
        ])
        model.get_layer(name='MLM-Norm').set_weights([
            loader('cls/predictions/transform/LayerNorm/gamma'),
            loader('cls/predictions/transform/LayerNorm/beta'),
        ])
        model.get_layer(name='MLM-Sim').set_weights([
            loader('cls/predictions/output_bias'),
        ])
        model.get_layer(name='NSP-Dense').set_weights([
            loader('bert/pooler/dense/kernel'),
            loader('bert/pooler/dense/bias'),
        ])
        model.get_layer(name='NSP').set_weights([
            np.transpose(loader('cls/seq_relationship/output_weights')),
            loader('cls/seq_relationship/output_bias'),
        ])


def load_trained_model_from_checkpoint(config_file,
                                       checkpoint_file,
                                       training=False,
                                       trainable=None,
                                       seq_len=None):
    """Load the trained official model from the checkpoint.

    :param config_file: The path to the JSON configuration file.
    :param checkpoint_file: The path to the checkpoint files; it should end with '.ckpt'.
    :param training: If training, the whole model will be returned.
                     Otherwise, the MLM and NSP parts will be ignored.
    :param trainable: Whether the model is trainable. The default value is the same as `training`.
    :param seq_len: If it is not None and it is shorter than the value in the config file, the weights of the
                    position embeddings will be sliced to fit the new length.
    :return: The built model with the pre-trained weights loaded.
    """
    model, config = build_model_from_config(
        config_file, training=training, trainable=trainable, seq_len=seq_len)
    load_model_weights_from_checkpoint(
        model, config, checkpoint_file, training=training)
    return model
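
# Illustrative usage (not executed): build and load the pre-trained model in a
# single call. The paths are placeholders for the "uncased_L-12_H-768_A-12"
# release; whether the returned model expects two or three inputs depends on
# the `training` flag and on how `get_model` wires its inputs.
#
#     model = load_trained_model_from_checkpoint(
#         'uncased_L-12_H-768_A-12/bert_config.json',
#         'uncased_L-12_H-768_A-12/bert_model.ckpt',
#         training=False,
#         seq_len=128,
#     )
#     model.summary()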