Switch to unified view

a b/docproduct/train_bertffn.py
1
import argparse
2
import os
3
4
import tensorflow as tf
5
import tensorflow.keras.backend as K
6
7
from docproduct.dataset import create_dataset_for_bert
8
from docproduct.models import MedicalQAModelwithBert
9
from docproduct.loss import qa_pair_loss, qa_pair_cross_entropy_loss
10
from docproduct.tokenization import FullTokenizer
11
from docproduct.metrics import qa_pair_batch_accuracy
12
13
14
def train_bertffn(model_path='models/bertffn_crossentropy/bertffn',
                  data_path='data/mqa_csv',
                  num_epochs=20,
                  num_gpu=1,
                  batch_size=64,
                  learning_rate=2e-5,
                  validation_split=0.2,
                  loss='categorical_crossentropy',
                  pretrained_path='models/pubmed_pmc_470k/',
                  max_seq_len=256):
    """Train the BertFFN similarity embedding model.

    Input file format:
        question,answer
        my eyes hurts, go see a doctor

    For more information about training details:
    https://github.com/Santosh-Gupta/DocProduct/blob/master/README.md

    Keyword Arguments:
        model_path {str} -- Path to save embedding model weights, ends with prefix of model files (default: {'models/bertffn_crossentropy/bertffn'})
        data_path {str} -- CSV data path (default: {'data/mqa_csv'})
        num_epochs {int} -- Number of epochs to train (default: {20})
        num_gpu {int} -- Number of GPUs; NOTE: currently unused in this function body (only single GPU is supported) (default: {1})
        batch_size {int} -- Batch size (default: {64})
        learning_rate {float} -- Learning rate (default: {2e-5})
        validation_split {float} -- Validation split; NOTE: currently unused in this function body (default: {0.2})
        loss {str} -- Loss type, either MSE or crossentropy (default: {'categorical_crossentropy'})
        pretrained_path {str} -- Pretrained BioBERT model path (default: {'models/pubmed_pmc_470k/'})
        max_seq_len {int} -- Max sequence length of the model (no effect when dynamic padding is enabled) (default: {256})
    """
    # Graph-mode execution: the underlying BERT implementation is run
    # with eager execution disabled.
    tf.compat.v1.disable_eager_execution()

    # Exact-match dispatch on the loss name; any other value falls back
    # to the MSE-style pair loss.
    if loss == 'categorical_crossentropy':
        loss_fn = qa_pair_cross_entropy_loss
    else:
        loss_fn = qa_pair_loss
    K.set_floatx('float32')

    tokenizer = FullTokenizer(os.path.join(pretrained_path, 'vocab.txt'))

    # Training dataset: shuffled, dynamically padded batches.
    train_dataset = create_dataset_for_bert(
        data_path, tokenizer=tokenizer, batch_size=batch_size,
        shuffle_buffer=500000, dynamic_padding=True,
        max_seq_length=max_seq_len)
    # Evaluation dataset: same source in eval mode, bucketed batching.
    eval_dataset = create_dataset_for_bert(
        data_path, tokenizer=tokenizer, batch_size=batch_size,
        mode='eval', dynamic_padding=True, max_seq_length=max_seq_len,
        bucket_batch_sizes=[64, 64, 64])

    medical_qa_model = MedicalQAModelwithBert(
        config_file=os.path.join(pretrained_path, 'bert_config.json'),
        checkpoint_file=os.path.join(pretrained_path, 'biobert_model.ckpt'))
    # Fix: use the canonical 'learning_rate' keyword; the 'lr' alias is
    # deprecated in tf.keras optimizers.
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    medical_qa_model.compile(
        optimizer=optimizer, loss=loss_fn, metrics=[qa_pair_batch_accuracy])

    # Checkpoint weights after every epoch. Fix: 'save_freq' replaces the
    # deprecated 'period' argument of ModelCheckpoint.
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        model_path, verbose=1, save_weights_only=True,
        save_best_only=False, save_freq='epoch')

    medical_qa_model.fit(
        train_dataset, epochs=num_epochs, callbacks=[checkpoint_callback])
    medical_qa_model.summary()
    medical_qa_model.save_weights(model_path)
    medical_qa_model.evaluate(eval_dataset)
77
78
79
if __name__ == "__main__":
    # Script entry point: train with all default hyperparameters.
    train_bertffn()