--- a
+++ b/docproduct/models.py
@@ -0,0 +1,125 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import tensorflow as tf
+
+from docproduct.bert import build_model_from_config
+from keras_bert.loader import load_model_weights_from_checkpoint
+
+
+class FFN(tf.keras.layers.Layer):
+    def __init__(
+            self,
+            hidden_size=768,
+            dropout=0.2,
+            residual=True,
+            name='FFN',
+            **kwargs):
+        """Dense layer wrapped with ReLU, optional dropout, and an
+        optional residual connection back to the input.
+        """
+        super(FFN, self).__init__(name=name, **kwargs)
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.residual = residual
+        # Sub-layers are created here rather than inside call() so that
+        # Keras tracks them (and any weights) exactly once.
+        self.ffn_layer = tf.keras.layers.Dense(
+            units=hidden_size,
+            use_bias=True)
+        self.relu_layer = tf.keras.layers.ReLU()
+        self.dropout_layer = tf.keras.layers.Dropout(dropout)
+
+    def call(self, inputs, training=None):
+        ffn_embedding = self.ffn_layer(inputs)
+        ffn_embedding = self.relu_layer(ffn_embedding)
+        if self.dropout > 0:
+            # Dropout is only active when training=True.
+            ffn_embedding = self.dropout_layer(
+                ffn_embedding, training=training)
+        if self.residual:
+            # Residual add assumes hidden_size matches the input dimension.
+            ffn_embedding += inputs
+        return ffn_embedding
+
+
+class MedicalQAModel(tf.keras.Model):
+    def __init__(self, name=''):
+        super(MedicalQAModel, self).__init__(name=name)
+        self.q_ffn = FFN(name='q_ffn', input_shape=(768,))
+        self.a_ffn = FFN(name='a_ffn', input_shape=(768,))
+
+    def call(self, inputs):
+        # inputs: (batch, 2, 768) of pre-computed question/answer embeddings.
+        q_bert_embedding, a_bert_embedding = tf.unstack(inputs, axis=1)
+        q_embedding = self.q_ffn(q_bert_embedding)
+        a_embedding = self.a_ffn(a_bert_embedding)
+        return tf.stack([q_embedding, a_embedding], axis=1)
+
+
+class MedicalQAModelwithBert(tf.keras.Model):
+    def __init__(
+            self,
+            hidden_size=768,
+            dropout=0.2,
+            residual=True,
+            config_file=None,
+            checkpoint_file=None,
+            bert_trainable=True,
+            layer_ind=-1,
+            name=''):
+        super(MedicalQAModelwithBert, self).__init__(name=name)
+        build = checkpoint_file is not None
+        self.biobert, config = build_model_from_config(
+            config_file=config_file,
+            training=False,
+            trainable=bert_trainable,
+            build=build)
+        if checkpoint_file is not None:
+            load_model_weights_from_checkpoint(
+                model=self.biobert,
+                config=config,
+                checkpoint_file=checkpoint_file,
+                training=False)
+        self.q_ffn_layer = FFN(
+            hidden_size=hidden_size,
+            dropout=dropout,
+            residual=residual,
+            name='q_ffn')
+        self.a_ffn_layer = FFN(
+            hidden_size=hidden_size,
+            dropout=dropout,
+            residual=residual,
+            name='a_ffn')
+        self.layer_ind = layer_ind
+
+    def call(self, inputs):
+        with_question = 'q_input_ids' in inputs
+        with_answer = 'a_input_ids' in inputs
+        if not (with_question or with_answer):
+            raise ValueError(
+                "inputs must contain 'q_input_ids' and/or 'a_input_ids'")
+
+        # Following USE, the DAN encoder averages BERT embeddings across tokens.
+        if with_question:
+            q_bert_embedding = self.biobert(
+                (inputs['q_input_ids'],
+                 inputs['q_segment_ids'],
+                 inputs['q_input_masks']))[self.layer_ind]
+            q_bert_embedding = tf.reduce_mean(q_bert_embedding, axis=1)
+        if with_answer:
+            a_bert_embedding = self.biobert(
+                (inputs['a_input_ids'],
+                 inputs['a_segment_ids'],
+                 inputs['a_input_masks']))[self.layer_ind]
+            a_bert_embedding = tf.reduce_mean(a_bert_embedding, axis=1)
+
+        if with_question:
+            q_embedding = self.q_ffn_layer(q_bert_embedding)
+            output = q_embedding
+        if with_answer:
+            a_embedding = self.a_ffn_layer(a_bert_embedding)
+            output = a_embedding
+        if with_question and with_answer:
+            output = tf.stack([q_embedding, a_embedding], axis=1)
+
+        return output
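
For reviewers, a minimal usage sketch of MedicalQAModelwithBert as added in this diff. The config/checkpoint paths, sequence length, and token ids below are hypothetical placeholders, not part of the change:

    import numpy as np
    import tensorflow as tf
    from docproduct.models import MedicalQAModelwithBert

    # Hypothetical BioBERT config/checkpoint locations.
    model = MedicalQAModelwithBert(
        config_file='biobert/bert_config.json',
        checkpoint_file='biobert/model.ckpt')

    # Toy batch: one question, max_seq_len=32; ids/segments/masks are dummies.
    max_seq_len = 32
    inputs = {
        'q_input_ids': tf.constant(np.zeros((1, max_seq_len), dtype=np.int32)),
        'q_segment_ids': tf.constant(np.zeros((1, max_seq_len), dtype=np.int32)),
        'q_input_masks': tf.constant(np.ones((1, max_seq_len), dtype=np.int32)),
    }

    # Only question keys are present, so call() returns the question
    # embedding alone, shape (1, 768); adding a_* keys would yield a
    # stacked (1, 2, 768) tensor of question and answer embeddings.
    q_embedding = model(inputs)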