[51873b]: / keras_bert / layers / conv.py

Download this file

19 lines (13 with data), 520 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
from tensorflow import keras
import tensorflow.keras.backend as K
class MaskedConv1D(keras.layers.Conv1D):
    """1D convolution that zeroes masked timesteps before convolving.

    Behaves like ``keras.layers.Conv1D`` except that, when a mask is
    supplied, input positions whose mask entry is false are multiplied by
    zero so they cannot leak into neighbouring outputs. The incoming mask
    is passed through unchanged to downstream layers.

    NOTE(review): passing the mask through unchanged is only consistent
    when the output length equals the input length (e.g. ``strides=1``
    with ``padding='same'`` or ``'causal'``). With ``padding='valid'`` or
    ``strides > 1`` the returned mask will not match the output length --
    confirm how callers configure this layer.
    """

    def __init__(self, *args, **kwargs):
        # Accept positional args (``filters``, ``kernel_size``) like the
        # base Conv1D does; keyword-only construction keeps working.
        super(MaskedConv1D, self).__init__(*args, **kwargs)
        self.supports_masking = True

    def compute_mask(self, inputs, mask=None):
        """Return the input mask unchanged (see the class NOTE on lengths)."""
        return mask

    def call(self, inputs, mask=None):
        """Zero out masked timesteps, then apply the convolution.

        Args:
            inputs: input tensor; assumed channels-last
                (batch, steps, channels), since the mask is broadcast over
                the last axis -- TODO confirm no caller uses
                ``data_format='channels_first'``.
            mask: optional (batch, steps) mask; false entries are zeroed.

        Returns:
            The result of ``Conv1D.call`` on the masked inputs.
        """
        if mask is not None:
            # Cast to the inputs' dtype rather than K.floatx() so the
            # multiply also works under mixed precision (float16 compute
            # dtype) or other non-default float dtypes.
            mask = K.cast(mask, inputs.dtype)
            inputs = inputs * K.expand_dims(mask, axis=-1)
        return super(MaskedConv1D, self).call(inputs)