# BERT

A Keras implementation of BERT, modified for compatibility with TensorFlow 2.0.

Used for extracting latent embeddings from medical question/answer data.

# Acknowledgement

Based on [CyberZHG's Keras BERT implementation](https://github.com/CyberZHG/keras-bert)

# Usage

### Tokenizer

Splits text and generates indices:

```python
from keras_bert import Tokenizer

token_dict = {
    '[CLS]': 0,
    '[SEP]': 1,
    'un': 2,
    '##aff': 3,
    '##able': 4,
    '[UNK]': 5,
}
tokenizer = Tokenizer(token_dict)
print(tokenizer.tokenize('unaffable'))  # The result should be `['[CLS]', 'un', '##aff', '##able', '[SEP]']`
indices, segments = tokenizer.encode('unaffable')
print(indices)   # Should be `[0, 2, 3, 4, 1]`
print(segments)  # Should be `[0, 0, 0, 0, 0]`

print(tokenizer.tokenize(first='unaffable', second='钢'))
# The result should be `['[CLS]', 'un', '##aff', '##able', '[SEP]', '钢', '[SEP]']`
indices, segments = tokenizer.encode(first='unaffable', second='钢', max_len=10)
print(indices)   # Should be `[0, 2, 3, 4, 1, 5, 1, 0, 0, 0]`
print(segments)  # Should be `[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]`
```
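
When a whole batch needs to be fed to a model, the per-text `encode` results can be stacked into arrays. A minimal sketch using NumPy; the `texts`, `token_input`, and `segment_input` names are illustrative, not part of the library:

```python
import numpy as np

texts = ['unaffable', 'unaffable']  # hypothetical inputs
pairs = [tokenizer.encode(text, max_len=10) for text in texts]

token_input = np.array([indices for indices, _ in pairs])      # shape (batch, max_len)
segment_input = np.array([segments for _, segments in pairs])  # shape (batch, max_len)
```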

### Training

```python
from tensorflow import keras
from keras_bert import get_base_dict, get_model, compile_model, gen_batch_inputs


# A toy input example
sentence_pairs = [
    [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'a', 'dull', 'boy']],
    [['from', 'the', 'day', 'forth'], ['my', 'arm', 'changed']],
    [['and', 'a', 'voice', 'echoed'], ['power', 'give', 'me', 'more', 'power']],
]


# Build the token dictionary
token_dict = get_base_dict()  # A dict that contains some special tokens
for pairs in sentence_pairs:
    for token in pairs[0] + pairs[1]:
        if token not in token_dict:
            token_dict[token] = len(token_dict)
token_list = list(token_dict.keys())  # Used for selecting a random word


# Build & train the model
model = get_model(
    token_num=len(token_dict),
    head_num=5,
    transformer_num=12,
    embed_dim=25,
    feed_forward_dim=100,
    seq_len=20,
    pos_num=20,
    dropout_rate=0.05,
)
compile_model(model)  # Attach the default pre-training losses before fitting
model.summary()


def _generator():
    while True:
        yield gen_batch_inputs(
            sentence_pairs,
            token_dict,
            token_list,
            seq_len=20,
            mask_rate=0.3,
            swap_sentence_rate=1.0,
        )


model.fit_generator(
    generator=_generator(),
    steps_per_epoch=1000,
    epochs=100,
    validation_data=_generator(),
    validation_steps=100,
    callbacks=[
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
    ],
)


# Use the trained model
inputs, output_layer = get_model(  # `output_layer` is the last feature extraction layer (the last transformer)
    token_num=len(token_dict),
    head_num=5,
    transformer_num=12,
    embed_dim=25,
    feed_forward_dim=100,
    seq_len=20,
    pos_num=20,
    dropout_rate=0.05,
    training=False,   # The input layers and output layer will be returned if `training` is `False`
    trainable=False,  # Whether the model is trainable; the default value is the same as `training`
)
```
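
Since the point of this fork is pulling out latent embeddings, the `inputs` and `output_layer` returned above can be wrapped in an ordinary Keras model and run over encoded token/segment arrays. A minimal sketch, assuming the trained weights have been transferred to this extraction model and that index 0 is the padding token (as in `get_base_dict`); `feature_model` and the hand-built arrays are illustrative:

```python
import numpy as np

# Wrap the feature-extraction graph in a standalone model
feature_model = keras.models.Model(inputs=inputs, outputs=output_layer)

# Hand-encode one padded sequence of token indices (pad index assumed to be 0)
tokens = ['[CLS]', 'all', 'work', '[SEP]']
token_input = np.array([[token_dict[t] for t in tokens] + [0] * (20 - len(tokens))])
segment_input = np.zeros((1, 20), dtype='int32')

embeddings = feature_model.predict([token_input, segment_input])
print(embeddings.shape)  # (1, 20, 25): one 25-dim latent vector per position
```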

### Custom Feature Extraction

The default transformer feature-extraction layers can be swapped out by passing a callable as `custom_layers`:

```python
from tensorflow import keras
from keras_bert import get_model


def _custom_layers(x, trainable=True):
    return keras.layers.LSTM(
        units=768,
        trainable=trainable,
        return_sequences=True,
        name='LSTM',
    )(x)


model = get_model(
    token_num=200,
    embed_dim=768,
    custom_layers=_custom_layers,
)
```
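
To confirm the swap took effect, printing the summary is enough; the layer stack should list the custom `LSTM` layer in place of the usual transformer blocks:

```python
model.summary()
```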