[51873b]: / keras_bert / layers / pooling.py

Download this file

22 lines (15 with data), 638 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
from tensorflow import keras
import tensorflow.keras.backend as K
class MaskedGlobalMaxPool1D(keras.layers.Layer):
    """Global max pooling over the second-to-last axis that skips masked steps.

    Positions whose mask entry is 0 are pushed down by a constant 1e6 before
    the max is taken. They therefore cannot win the max as long as the real
    activations are well below 1e6 in magnitude. The layer consumes the
    incoming mask, so its output carries no mask.
    """

    def __init__(self, **kwargs):
        super(MaskedGlobalMaxPool1D, self).__init__(**kwargs)
        # Ask Keras to hand the upstream mask to ``call``.
        self.supports_masking = True

    def compute_mask(self, inputs, mask=None):
        """Return None: the pooled output is not masked."""
        return None

    def compute_output_shape(self, input_shape):
        """Drop the second-to-last (pooled) axis from ``input_shape``.

        Slicing on both sides keeps the container type. Tuples, lists and
        TensorShape objects all work and come back as the same kind.
        """
        return input_shape[:-2] + input_shape[-1:]

    def call(self, inputs, mask=None):
        """Max-pool ``inputs`` over axis -2, ignoring masked positions.

        Args:
            inputs: Tensor reduced over its second-to-last axis. It is
                presumably (batch, steps, features); TODO confirm with callers.
            mask: Optional mask with one entry per pooled position. It is
                expanded on the last axis and broadcast against ``inputs``.
                Entries equal to 1 leave the value untouched. Entries equal
                to 0 subtract 1e6 from it.

        Returns:
            ``inputs`` reduced with ``K.max`` over axis -2.
        """
        if mask is not None:
            # The penalty is built in float32, where 0 and 1e6 are exact, and
            # only then cast to the inputs' dtype.
            # - Casting to K.floatx() instead breaks when the inputs use
            #   another dtype (float16 under mixed precision, or float64).
            # - Multiplying directly in float16 would overflow 1e6 to inf and
            #   turn 0 * inf into NaN at valid positions.
            # After this cast, masked positions become -inf in float16 and
            # valid ones are unchanged.
            penalty = (1.0 - K.cast(mask, 'float32')) * 1e6
            penalty = K.cast(K.expand_dims(penalty, axis=-1), K.dtype(inputs))
            inputs = inputs - penalty
        return K.max(inputs, axis=-2)