Diff of /docproduct/bert.py [000000] .. [51873b]

import json

import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K

from keras_bert.keras_pos_embd import PositionEmbedding
from keras_bert.layers import get_inputs, get_embedding, TokenEmbedding, EmbeddingSimilarity, Masked, Extract
from keras_bert.keras_layer_normalization import LayerNormalization
from keras_bert.keras_multi_head import MultiHeadAttention
from keras_bert.keras_position_wise_feed_forward import FeedForward

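# Exact GELU activation: 0.5 * x * (1 + erf(x / sqrt(2))), i.e. x * Phi(x) for the
# standard normal CDF Phi.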
def gelu(x):
    return 0.5 * x * (1.0 + tf.math.erf(x / tf.sqrt(2.0)))

class Bert(keras.Model):
    """BERT encoder implemented as a keras.Model subclass, assembled from keras_bert layers."""

    def __init__(
            self,
            token_num,
            pos_num=512,
            seq_len=512,
            embed_dim=768,
            transformer_num=12,
            head_num=12,
            feed_forward_dim=3072,
            dropout_rate=0.1,
            attention_activation=None,
            feed_forward_activation=gelu,
            custom_layers=None,
            training=True,
            trainable=None,
            lr=1e-4,
            name='Bert'):
        super().__init__(name=name)
        self.token_num = token_num
        self.pos_num = pos_num
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.transformer_num = transformer_num
        self.head_num = head_num
        self.feed_forward_dim = feed_forward_dim
        self.dropout_rate = dropout_rate
        self.attention_activation = attention_activation
        self.feed_forward_activation = feed_forward_activation
        self.custom_layers = custom_layers
        self.training = training
        self.trainable = trainable
        self.lr = lr

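        # Sub-layers are created here in __init__ (rather than inside call) so that
        # their weights are tracked as part of this keras.Model.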
        # build layers
        # embedding
        self.token_embedding_layer = TokenEmbedding(
            input_dim=token_num,
            output_dim=embed_dim,
            mask_zero=True,
            trainable=trainable,
            name='Embedding-Token',
        )
        self.segment_embedding_layer = keras.layers.Embedding(
            input_dim=2,
            output_dim=embed_dim,
            trainable=trainable,
            name='Embedding-Segment',
        )
        self.position_embedding_layer = PositionEmbedding(
            input_dim=pos_num,
            output_dim=embed_dim,
            mode=PositionEmbedding.MODE_ADD,
            trainable=trainable,
            name='Embedding-Position',
        )
        self.embedding_layer_norm = LayerNormalization(
            trainable=trainable,
            name='Embedding-Norm',
        )

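        # Containers for the per-block sub-layers: self-attention, feed-forward,
        # and a LayerNormalization for each of the two residual connections.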
        self.encoder_multihead_layers = []
        self.encoder_ffn_layers = []
        self.encoder_attention_norm = []
        self.encoder_ffn_norm = []
        # attention layers
        for i in range(transformer_num):
            base_name = 'Encoder-%d' % (i + 1)
            attention_name = '%s-MultiHeadSelfAttention' % base_name
            feed_forward_name = '%s-FeedForward' % base_name
            self.encoder_multihead_layers.append(MultiHeadAttention(
                head_num=head_num,
                activation=attention_activation,
                history_only=False,
                trainable=trainable,
                name=attention_name,
            ))
            self.encoder_ffn_layers.append(FeedForward(
                units=feed_forward_dim,
                activation=feed_forward_activation,
                trainable=trainable,
                name=feed_forward_name,
            ))
            self.encoder_attention_norm.append(LayerNormalization(
                trainable=trainable,
                name='%s-Norm' % attention_name,
            ))
            self.encoder_ffn_norm.append(LayerNormalization(
                trainable=trainable,
                name='%s-Norm' % feed_forward_name,
            ))

    def call(self, inputs):
        """Encode `inputs`, a list whose first two elements are the token ids and
        segment ids, each of shape (batch_size, seq_len)."""

        embeddings = [
            self.token_embedding_layer(inputs[0]),
            self.segment_embedding_layer(inputs[1])
        ]
        # TokenEmbedding returns both the looked-up embeddings and the embedding
        # matrix; only the embeddings are used in this model.
        embeddings[0], embed_weights = embeddings[0]
        embed_layer = keras.layers.Add(
            name='Embedding-Token-Segment')(embeddings)
        embed_layer = self.position_embedding_layer(embed_layer)

        if self.dropout_rate > 0.0:
            dropout_layer = keras.layers.Dropout(
                rate=self.dropout_rate,
                name='Embedding-Dropout',
            )(embed_layer)
        else:
            dropout_layer = embed_layer

        embedding_output = self.embedding_layer_norm(dropout_layer)

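        # Each sub-layer below is wrapped as LayerNorm(x + Dropout(sublayer(x))),
        # the post-norm residual scheme used by BERT.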
        def _wrap_layer(name, input_layer, build_func, norm_layer, dropout_rate=0.0, trainable=True):
            """Wrap a sub-layer with residual connection, dropout and layer normalization.

            :param name: Prefix of names for internal layers.
            :param input_layer: Input tensor.
            :param build_func: A callable that takes the input tensor and generates the output tensor.
            :param norm_layer: The LayerNormalization instance applied after the residual addition.
            :param dropout_rate: Dropout rate.
            :param trainable: Whether the layers are trainable.
            :return: Output tensor.
            """
            build_output = build_func(input_layer)
            if dropout_rate > 0.0:
                dropout_layer = keras.layers.Dropout(
                    rate=dropout_rate,
                    name='%s-Dropout' % name,
                )(build_output)
            else:
                dropout_layer = build_output
            if isinstance(input_layer, list):
                input_layer = input_layer[0]
            add_layer = keras.layers.Add(
                name='%s-Add' % name)([input_layer, dropout_layer])
            normal_layer = norm_layer(add_layer)
            return normal_layer

        last_layer = embedding_output
        output_tensor_list = [last_layer]
        # transformer encoder blocks: self-attention followed by feed-forward
        for i in range(self.transformer_num):
            base_name = 'Encoder-%d' % (i + 1)
            attention_name = '%s-MultiHeadSelfAttention' % base_name
            feed_forward_name = '%s-FeedForward' % base_name
            self_attention_output = _wrap_layer(
                name=attention_name,
                input_layer=last_layer,
                build_func=self.encoder_multihead_layers[i],
                norm_layer=self.encoder_attention_norm[i],
                dropout_rate=self.dropout_rate,
                trainable=self.trainable)
            last_layer = _wrap_layer(
                name=feed_forward_name,
                input_layer=self_attention_output,
                build_func=self.encoder_ffn_layers[i],
                norm_layer=self.encoder_ffn_norm[i],
                dropout_rate=self.dropout_rate,
                trainable=self.trainable)
            output_tensor_list.append(last_layer)

        # the embedding output plus one hidden state per encoder block
        return output_tensor_list


def build_model_from_config(config_file,
                            training=False,
                            trainable=None,
                            seq_len=None,
                            build=True):
    """Build the model from a config file.

    :param config_file: The path to the JSON configuration file.
    :param training: Whether the model is intended for training; passed through to Bert.
    :param trainable: Whether the model is trainable.
    :param seq_len: If it is not None and it is shorter than the value in the config file,
                    max_position_embeddings is reduced to this length before the model is built.
    :param build: Whether to build the model with placeholder input shapes before returning.
    :return: model and config
    """
    with open(config_file, 'r') as reader:
        config = json.loads(reader.read())
    if seq_len is not None:
        config['max_position_embeddings'] = min(
            seq_len, config['max_position_embeddings'])
    if trainable is None:
        trainable = training
    model = Bert(
        token_num=config['vocab_size'],
        pos_num=config['max_position_embeddings'],
        seq_len=config['max_position_embeddings'],
        embed_dim=config['hidden_size'],
        transformer_num=config['num_hidden_layers'],
        head_num=config['num_attention_heads'],
        feed_forward_dim=config['intermediate_size'],
        training=training,
        trainable=trainable,
    )
    if build:
        # placeholder shapes for the three expected inputs; call() only uses the first two
        model.build(input_shape=[(None, None), (None, None), (None, None)])
    return model, config
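
A minimal usage sketch, assuming a standard BERT config JSON such as the
bert_config.json shipped with a pretrained checkpoint (the path and dummy
inputs below are illustrative only):

    import numpy as np
    from docproduct.bert import build_model_from_config

    model, config = build_model_from_config('bert_config.json', training=False)
    token_ids = np.ones((1, 64), dtype='int32')      # dummy token ids
    segment_ids = np.zeros((1, 64), dtype='int32')   # single-segment input
    outputs = model([token_ids, segment_ids])
    # outputs is a list of transformer_num + 1 tensors: the embedding output
    # followed by one hidden state per encoder block.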