[51873b]: / keras_bert / layers / extract.py

Download this file

30 lines (21 with data), 751 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from tensorflow import keras
class Extract(keras.layers.Layer):
    """Pick out a single position along axis 1 of the input tensor.

    ``call`` returns ``x[:, index]``. ``compute_output_shape`` reports the
    input shape with axis 1 removed, which matches ``call`` when ``index`` is
    a single integer. That is presumably the intended usage (e.g. taking one
    token's vector from a BERT-style sequence output); TODO confirm with
    callers before passing a slice.

    The layer accepts masked inputs (``supports_masking = True``) but never
    propagates a mask downstream (``compute_mask`` always returns ``None``).

    See: https://arxiv.org/pdf/1810.04805.pdf

    Args:
        index: Position along axis 1 to select.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, index, **kwargs):
        super(Extract, self).__init__(**kwargs)
        # Stored so that ``call`` can use it and ``get_config`` can serialize it.
        self.index = index
        # Let Keras feed a masked input into this layer without raising.
        self.supports_masking = True

    def get_config(self):
        """Return the base layer config extended with ``index``.

        The ``index`` key is added after the base entries and overrides any
        same-named base entry.
        """
        merged = dict(super(Extract, self).get_config())
        merged.update({'index': self.index})
        return merged

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with axis 1 removed."""
        leading = input_shape[:1]    # batch dimension
        trailing = input_shape[2:]   # everything after the selected axis
        return leading + trailing

    def compute_mask(self, inputs, mask=None):
        """Return ``None``: any incoming mask is dropped.

        Axis 1 is collapsed by ``call`` (for an integer ``index``), so no
        per-position mask is forwarded to later layers.
        """
        return None

    def call(self, x, mask=None):
        """Return the slice of ``x`` at ``self.index`` along axis 1.

        ``mask`` is accepted for Keras API compatibility and is not used.
        """
        selected = x[:, self.index]
        return selected