[51873b]: / keras_bert / layers / inputs.py

Download this file

16 lines (11 with data), 344 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
from tensorflow import keras
def get_inputs(seq_len):
    """Get the input layers.

    See: https://arxiv.org/pdf/1810.04805.pdf

    Three ``keras.layers.Input`` layers are created, named
    ``Input-Token``, ``Input-Segment`` and ``Input-Masked`` (in that order).

    :param seq_len: Length of the sequence, or None for variable-length input.
    :return: A list of the three input layers, each declared with
             ``shape=(seq_len,)`` (Keras excludes the batch dimension
             from ``shape``).
    """
    names = ['Token', 'Segment', 'Masked']
    # Use the requested sequence length; a hard-coded ``(None,)`` here would
    # silently ignore a fixed ``seq_len`` passed by the caller.
    return [keras.layers.Input(
        shape=(seq_len,),
        name='Input-%s' % name,
    ) for name in names]