[51873b]: / keras_bert / layers / masked.py

Download this file

45 lines (34 with data), 1.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from tensorflow import keras
import tensorflow.keras.backend as K
class Masked(keras.layers.Layer):
    """Generate the output mask from the incoming mask and the masked positions.

    The layer takes two inputs: the original input tensor and a tensor that
    marks the masked locations (entries not equal to 0 count as masked).
    The first input is passed through unchanged; only the mask attached to
    the output is modified.

    See: https://arxiv.org/pdf/1810.04805.pdf
    """

    def __init__(self,
                 return_masked=False,
                 **kwargs):
        """Initialize the layer.

        :param return_masked: Whether to return the merged mask.
        :param kwargs: Arguments for parent class.
        """
        super(Masked, self).__init__(**kwargs)
        self.supports_masking = True
        self.return_masked = return_masked

    def get_config(self):
        """Return the layer configuration, including `return_masked`.

        :return: The parent configuration extended with this layer's options.
        """
        config = {
            'return_masked': self.return_masked,
        }
        base_config = super(Masked, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """Compute the output shape(s) from the list of input shapes.

        :param input_shape: Shapes of the inputs; only the first one is used.
        :return: The shape of the first input, or a list holding that shape
                 and the same shape without its last axis (for the merged
                 mask) when `return_masked` is set.
        """
        if self.return_masked:
            return [input_shape[0], input_shape[0][:-1]]
        return input_shape[0]

    def compute_mask(self, inputs, mask=None):
        """Merge the mask of the first input with the masked locations.

        :param inputs: List whose second element marks the masked locations.
        :param mask: List of incoming masks, or None; only the first one is used.
        :return: Boolean tensor that is true where the location is marked and
                 the first input's mask (if there is one) is also true.
        """
        token_mask = K.not_equal(inputs[1], 0)
        # `mask` defaults to None and the first input may carry no mask at
        # all. An absent mask means "nothing is hidden", so the merged result
        # is just the token mask; indexing or stacking None would fail here.
        input_mask = None if mask is None else mask[0]
        if input_mask is None:
            return token_mask
        # Stacking and reducing with `all` is an element-wise logical AND.
        # NOTE(review): assumes both masks share the same shape - confirm.
        return K.all(K.stack([token_mask, input_mask], axis=0), axis=0)

    def call(self, inputs, mask=None, **kwargs):
        """Pass the first input through, optionally with the merged mask.

        :param inputs: List holding the original input and the masked locations.
        :param mask: List of incoming masks, or None.
        :param kwargs: Unused additional arguments.
        :return: The first input, or a list of the first input and the merged
                 mask cast to the backend float type when `return_masked` is set.
        """
        if self.return_masked:
            return [inputs[0], K.cast(self.compute_mask(inputs, mask), K.floatx())]
        return inputs[0]