import argparse

import tensorflow as tf
import tensorflow.keras.backend as K

from docproduct.dataset import create_dataset_for_ffn
from docproduct.loss import qa_pair_loss, qa_pair_cross_entropy_loss
from docproduct.metrics import qa_pair_batch_accuracy
from docproduct.models import MedicalQAModel

# GPU devices available for mirrored (multi-GPU) training; the first
# `num_gpu` entries are handed to tf.distribute.MirroredStrategy.
DEVICE = ["/gpu:0", "/gpu:1"]
def multi_gpu_train(batch_size, num_gpu, data_path, num_epochs, model_path,
                    loss=qa_pair_loss, learning_rate=0.0001):
    """Train the FFN QA model across multiple GPUs with MirroredStrategy.

    Args:
        batch_size: Per-GPU batch size; the effective batch is
            ``batch_size * num_gpu``.
        num_gpu: Number of GPUs to mirror across (must be <= len(DEVICE)).
        data_path: Path passed to ``create_dataset_for_ffn``.
        num_epochs: Number of training epochs.
        model_path: Where to save the trained weights.
        loss: Loss function; defaults to ``qa_pair_loss``.
        learning_rate: Base Adam learning rate before GPU scaling.
            New keyword with a default, so existing callers are unaffected.

    Returns:
        The trained Keras model.
    """
    mirrored_strategy = tf.distribute.MirroredStrategy(
        devices=DEVICE[:num_gpu])
    # Scale batch size linearly and learning rate geometrically with GPU
    # count (the original file's intent; previously `learning_rate` was
    # read before assignment, raising UnboundLocalError).
    global_batch_size = batch_size * num_gpu
    scaled_lr = learning_rate * 1.5 ** num_gpu

    with mirrored_strategy.scope():
        d = create_dataset_for_ffn(
            data_path, batch_size=global_batch_size, shuffle_buffer=100000)

        medical_qa_model = tf.keras.Sequential()
        # Input is a (question, answer) pair of 768-dim embeddings.
        medical_qa_model.add(tf.keras.layers.Input((2, 768)))
        medical_qa_model.add(MedicalQAModel())
        optimizer = tf.keras.optimizers.Adam(learning_rate=scaled_lr)
        # Metrics belong in compile(); Model.fit() has no `metrics` kwarg.
        medical_qa_model.compile(
            optimizer=optimizer, loss=loss,
            metrics=[qa_pair_batch_accuracy])

    # Under a distribution strategy, fit() accepts a tf.data.Dataset
    # directly; make_dataset_iterator() was removed from TF2.
    medical_qa_model.fit(d, epochs=num_epochs)
    medical_qa_model.save_weights(model_path)
    return medical_qa_model
def single_gpu_train(batch_size, num_gpu, data_path, num_epochs, model_path,
                     loss=qa_pair_loss, learning_rate=0.0001):
    """Train the FFN QA model on a single GPU (or CPU).

    Args:
        batch_size: Per-device batch size; the effective batch is
            ``batch_size * num_gpu`` (num_gpu is 1 for the single-GPU path).
        num_gpu: GPU count used only to scale the batch size.
        data_path: Path passed to ``create_dataset_for_ffn``.
        num_epochs: Number of training epochs.
        model_path: Where to save the trained weights.
        loss: Loss function; defaults to ``qa_pair_loss``.
        learning_rate: Adam learning rate. New keyword with a default
            (previously `learning_rate = learning_rate` read an undefined
            local, raising UnboundLocalError), so callers are unaffected.

    Returns:
        The trained ``MedicalQAModel``.
    """
    global_batch_size = batch_size * num_gpu
    d = create_dataset_for_ffn(
        data_path, batch_size=global_batch_size, shuffle_buffer=500000)
    eval_d = create_dataset_for_ffn(
        data_path, batch_size=batch_size, mode='eval')

    medical_qa_model = MedicalQAModel()
    # `lr=` is deprecated in tf.keras optimizers; use `learning_rate=`.
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    medical_qa_model.compile(
        optimizer=optimizer, loss=loss,
        metrics=[qa_pair_batch_accuracy])

    medical_qa_model.fit(d, epochs=num_epochs, validation_data=eval_d)
    medical_qa_model.save_weights(model_path)
    return medical_qa_model
def train_ffn(model_path='models/ffn_crossentropy/ffn',
              data_path='data/mqa_csv',
              num_epochs=300,
              num_gpu=1,
              batch_size=64,
              learning_rate=0.0001,
              validation_split=0.2,
              loss='categorical_crossentropy'):
    """Entry point: train the FFN QA model on one or more GPUs.

    Args:
        model_path: Where to save the trained weights.
        data_path: Path passed to ``create_dataset_for_ffn``.
        num_epochs: Number of training epochs.
        num_gpu: >1 selects mirrored multi-GPU training.
        batch_size: Per-GPU batch size.
        learning_rate: Currently NOT forwarded to the training routines,
            which use their own defaults — kept for interface stability.
        validation_split: Currently unused — kept for interface stability.
        loss: ``'categorical_crossentropy'`` selects
            ``qa_pair_cross_entropy_loss``; anything else selects
            ``qa_pair_loss``.

    Returns:
        The trained Keras model.
    """
    if loss == 'categorical_crossentropy':
        loss_fn = qa_pair_cross_entropy_loss
    else:
        loss_fn = qa_pair_loss

    # NOTE: previously an eval dataset was built here and then discarded;
    # single_gpu_train constructs its own eval split.
    if num_gpu > 1:
        medical_qa_model = multi_gpu_train(
            batch_size, num_gpu, data_path, num_epochs, model_path, loss_fn)
    else:
        medical_qa_model = single_gpu_train(
            batch_size, num_gpu, data_path, num_epochs, model_path, loss_fn)

    medical_qa_model.summary()
    medical_qa_model.save_weights(model_path, overwrite=True)
    return medical_qa_model
if __name__ == "__main__":
    # CLI wrapper; every default matches train_ffn(), so running the
    # script with no arguments behaves exactly as before. This also puts
    # the previously-unused `argparse` import to work.
    parser = argparse.ArgumentParser(
        description="Train the FFN medical-QA model.")
    parser.add_argument('--model_path', default='models/ffn_crossentropy/ffn')
    parser.add_argument('--data_path', default='data/mqa_csv')
    parser.add_argument('--num_epochs', type=int, default=300)
    parser.add_argument('--num_gpu', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=0.0001)
    parser.add_argument('--validation_split', type=float, default=0.2)
    parser.add_argument('--loss', default='categorical_crossentropy')
    args = parser.parse_args()
    train_ffn(**vars(args))