|
# docproduct/train_bertffn.py
|
|
1 |
import argparse |
|
|
2 |
import os |
|
|
3 |
|
|
|
4 |
import tensorflow as tf |
|
|
5 |
import tensorflow.keras.backend as K |
|
|
6 |
|
|
|
7 |
from docproduct.dataset import create_dataset_for_bert |
|
|
8 |
from docproduct.models import MedicalQAModelwithBert |
|
|
9 |
from docproduct.loss import qa_pair_loss, qa_pair_cross_entropy_loss |
|
|
10 |
from docproduct.tokenization import FullTokenizer |
|
|
11 |
from docproduct.metrics import qa_pair_batch_accuracy |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def train_bertffn(model_path='models/bertffn_crossentropy/bertffn',
                  data_path='data/mqa_csv',
                  num_epochs=20,
                  num_gpu=1,
                  batch_size=64,
                  learning_rate=2e-5,
                  validation_split=0.2,
                  loss='categorical_crossentropy',
                  pretrained_path='models/pubmed_pmc_470k/',
                  max_seq_len=256):
    """Train the BertFFN similarity embedding model on question/answer pairs.

    Input file format:
        question,answer
        my eyes hurts, go see a doctor

    For more information about training details:
    https://github.com/Santosh-Gupta/DocProduct/blob/master/README.md

    Keyword Arguments:
        model_path {str} -- Path to save embedding model weights, ends with prefix of model files (default: {'models/bertffn_crossentropy/bertffn'})
        data_path {str} -- CSV data path (default: {'data/mqa_csv'})
        num_epochs {int} -- Number of epochs to train (default: {20})
        num_gpu {int} -- Number of GPUs to use. NOTE: currently accepted but
            unused; training always runs on the default device (default: {1})
        batch_size {int} -- Batch size (default: {64})
        learning_rate {float} -- Learning rate for Adam (default: {2e-5})
        validation_split {float} -- Validation split. NOTE: currently accepted
            but unused; evaluation is done on the separate eval dataset built
            from the same CSV (default: {0.2})
        loss {str} -- Loss type, either MSE or crossentropy (default: {'categorical_crossentropy'})
        pretrained_path {str} -- Pretrained bioBert model path (default: {'models/pubmed_pmc_470k/'})
        max_seq_len {int} -- Max sequence length of model (no effect when
            dynamic padding is enabled) (default: {256})
    """
    # The model/training pipeline below was written against the TF1-style
    # graph-mode API, so eager execution must be disabled first.
    tf.compat.v1.disable_eager_execution()

    # Any value other than 'categorical_crossentropy' falls back to the
    # MSE-style pair loss.
    if loss == 'categorical_crossentropy':
        loss_fn = qa_pair_cross_entropy_loss
    else:
        loss_fn = qa_pair_loss

    K.set_floatx('float32')
    tokenizer = FullTokenizer(os.path.join(pretrained_path, 'vocab.txt'))

    # Training dataset: shuffled, dynamically padded batches.
    train_dataset = create_dataset_for_bert(
        data_path, tokenizer=tokenizer, batch_size=batch_size,
        shuffle_buffer=500000, dynamic_padding=True,
        max_seq_length=max_seq_len)
    # Eval dataset: same CSV in eval mode, bucketed by sequence length.
    eval_dataset = create_dataset_for_bert(
        data_path, tokenizer=tokenizer, batch_size=batch_size,
        mode='eval', dynamic_padding=True, max_seq_length=max_seq_len,
        bucket_batch_sizes=[64, 64, 64])

    medical_qa_model = MedicalQAModelwithBert(
        config_file=os.path.join(
            pretrained_path, 'bert_config.json'),
        checkpoint_file=os.path.join(pretrained_path, 'biobert_model.ckpt'))
    # FIX: the `lr` keyword is deprecated (and removed in newer TF releases);
    # the supported name is `learning_rate`.
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    medical_qa_model.compile(
        optimizer=optimizer, loss=loss_fn, metrics=[qa_pair_batch_accuracy])

    # Checkpoint the weights after every epoch. FIX: `period=1` is
    # deprecated; `save_freq='epoch'` is the supported equivalent.
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        model_path, verbose=1, save_weights_only=True, save_best_only=False,
        save_freq='epoch')

    medical_qa_model.fit(
        train_dataset, epochs=num_epochs, callbacks=[checkpoint_callback])
    # summary() is called after fit() so the model is fully built.
    medical_qa_model.summary()
    # Final weights are saved explicitly in addition to the per-epoch
    # checkpoints above.
    medical_qa_model.save_weights(model_path)
    medical_qa_model.evaluate(eval_dataset)
|
|
77 |
|
|
|
78 |
|
|
|
79 |
if __name__ == "__main__":
    # CLI entry point. `argparse` is imported at the top of the file but was
    # previously unused; expose the training hyper-parameters as flags whose
    # defaults mirror train_bertffn's defaults, so running the script with no
    # arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description='Train the BertFFN QA similarity embedding model.')
    parser.add_argument(
        '--model_path', default='models/bertffn_crossentropy/bertffn',
        help='Path prefix under which model weights are saved.')
    parser.add_argument(
        '--data_path', default='data/mqa_csv',
        help='CSV file with question,answer rows.')
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--num_gpu', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=2e-5)
    parser.add_argument('--validation_split', type=float, default=0.2)
    parser.add_argument(
        '--loss', default='categorical_crossentropy',
        help="Loss type; 'categorical_crossentropy' or anything else for MSE.")
    parser.add_argument('--pretrained_path', default='models/pubmed_pmc_470k/')
    parser.add_argument('--max_seq_len', type=int, default=256)
    args = parser.parse_args()
    train_bertffn(**vars(args))