|
a |
|
b/docproduct/train_ffn.py |
|
|
1 |
import argparse |
|
|
2 |
|
|
|
3 |
import tensorflow as tf |
|
|
4 |
import tensorflow.keras.backend as K |
|
|
5 |
|
|
|
6 |
from docproduct.dataset import create_dataset_for_ffn |
|
|
7 |
from docproduct.models import MedicalQAModel |
|
|
8 |
from docproduct.loss import qa_pair_loss, qa_pair_cross_entropy_loss |
|
|
9 |
from docproduct.metrics import qa_pair_batch_accuracy |
|
|
10 |
|
|
|
11 |
# Device strings offered to tf.distribute.MirroredStrategy; multi_gpu_train
# uses the first `num_gpu` entries of this list.
DEVICE = [
    "/gpu:0",
    "/gpu:1",
]
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def multi_gpu_train(batch_size, num_gpu, data_path, num_epochs, model_path,
                    loss=qa_pair_loss, learning_rate=0.0001):
    """Train the FFN on several GPUs with ``tf.distribute.MirroredStrategy``.

    Args:
        batch_size: Per-replica batch size; multiplied by the number of
            replicas to form the batch size requested from the dataset.
        num_gpu: How many leading entries of ``DEVICE`` to train on.
        data_path: Passed through to ``create_dataset_for_ffn``.
        num_epochs: Number of epochs passed to ``fit``.
        model_path: Destination passed to ``save_weights``.
        loss: Loss callable handed to ``compile``.
        learning_rate: Base Adam learning rate; multiplied by
            ``1.5 ** replicas`` before use.

    Returns:
        The trained Keras model.
    """
    mirrored_strategy = tf.distribute.MirroredStrategy(
        devices=DEVICE[:num_gpu])
    # Scale by the replicas the strategy really drives: DEVICE[:num_gpu]
    # silently truncates when num_gpu exceeds len(DEVICE).
    num_replicas = mirrored_strategy.num_replicas_in_sync
    global_batch_size = batch_size * num_replicas
    scaled_learning_rate = learning_rate * 1.5 ** num_replicas

    with mirrored_strategy.scope():
        # Assumes create_dataset_for_ffn returns a tf.data.Dataset, which
        # Keras distributes across the replicas on its own -- TODO confirm.
        d = create_dataset_for_ffn(
            data_path, batch_size=global_batch_size, shuffle_buffer=100000)

        medical_qa_model = tf.keras.Sequential()
        medical_qa_model.add(tf.keras.layers.Input((2, 768)))
        medical_qa_model.add(MedicalQAModel())
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=scaled_learning_rate)
        # Metrics are declared in compile(); fit() has no `metrics` argument.
        medical_qa_model.compile(
            optimizer=optimizer, loss=loss, metrics=[qa_pair_batch_accuracy])

        medical_qa_model.fit(d, epochs=num_epochs)
        medical_qa_model.save_weights(model_path)
    return medical_qa_model


def single_gpu_train(batch_size, num_gpu, data_path, num_epochs, model_path,
                     loss=qa_pair_loss, learning_rate=0.0001):
    """Train the FFN without a distribution strategy.

    Args:
        batch_size: Batch size for the eval dataset; the training dataset
            uses ``batch_size * num_gpu``.
        num_gpu: Multiplier applied to the training batch size (values
            below 1 are treated as 1).
        data_path: Passed through to ``create_dataset_for_ffn``.
        num_epochs: Number of epochs passed to ``fit``.
        model_path: Destination passed to ``save_weights``.
        loss: Loss callable handed to ``compile``.
        learning_rate: Adam learning rate.

    Returns:
        The trained ``MedicalQAModel``.
    """
    # num_gpu == 0 (CPU-only run) must not collapse the batch size to zero.
    global_batch_size = batch_size * max(num_gpu, 1)
    d = create_dataset_for_ffn(
        data_path, batch_size=global_batch_size, shuffle_buffer=500000)
    eval_d = create_dataset_for_ffn(
        data_path, batch_size=batch_size, mode='eval')

    medical_qa_model = MedicalQAModel()
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    medical_qa_model.compile(
        optimizer=optimizer, loss=loss, metrics=[qa_pair_batch_accuracy])

    medical_qa_model.fit(d, epochs=num_epochs, validation_data=eval_d)
    medical_qa_model.save_weights(model_path)
    return medical_qa_model


def train_ffn(model_path='models/ffn_crossentropy/ffn',
              data_path='data/mqa_csv',
              num_epochs=300,
              num_gpu=1,
              batch_size=64,
              learning_rate=0.0001,
              validation_split=0.2,
              loss='categorical_crossentropy'):
    """Train the FFN model and print its summary.

    Dispatches to ``multi_gpu_train`` when ``num_gpu > 1`` and to
    ``single_gpu_train`` otherwise; the selected trainer saves the weights
    to ``model_path``.

    Args:
        model_path: Where the trainer saves the model weights.
        data_path: Passed through to the dataset builder.
        num_epochs: Number of training epochs.
        num_gpu: Number of GPUs; values above 1 select the mirrored path.
        batch_size: Per-GPU batch size.
        learning_rate: Base Adam learning rate forwarded to the trainer.
        validation_split: Accepted for interface compatibility; currently
            unused.
        loss: ``'categorical_crossentropy'`` selects
            ``qa_pair_cross_entropy_loss``; any other value selects
            ``qa_pair_loss``.
    """
    if loss == 'categorical_crossentropy':
        loss_fn = qa_pair_cross_entropy_loss
    else:
        loss_fn = qa_pair_loss

    trainer = multi_gpu_train if num_gpu > 1 else single_gpu_train
    medical_qa_model = trainer(
        batch_size, num_gpu, data_path, num_epochs, model_path, loss_fn,
        learning_rate=learning_rate)

    medical_qa_model.summary()
|
|
103 |
|
|
|
104 |
|
|
|
105 |
def parse_cli_args(argv=None):
    """Parse command-line overrides for ``train_ffn``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        A dict holding only the options that were given explicitly, so that
        ``train_ffn``'s own defaults stay the single source of truth.
    """
    parser = argparse.ArgumentParser(description='Train the FFN model.')
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--data_path', type=str)
    parser.add_argument('--num_epochs', type=int)
    parser.add_argument('--num_gpu', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--learning_rate', type=float)
    parser.add_argument('--validation_split', type=float)
    parser.add_argument('--loss', type=str)
    parsed = vars(parser.parse_args(argv))
    return {name: value for name, value in parsed.items()
            if value is not None}


if __name__ == "__main__":
    # With no arguments this is exactly train_ffn() with its defaults.
    train_ffn(**parse_cli_args())