# docproduct/models.py
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow import keras
from docproduct.bert import build_model_from_config
from keras_bert.loader import load_model_weights_from_checkpoint
class FFN(tf.keras.layers.Layer):
    """Simple position-wise feed-forward block: Dense -> ReLU -> Dropout,
    with an optional residual (additive skip) connection.

    Args:
        hidden_size: Output units of the Dense projection. When ``residual``
            is True, the input's last dimension must equal ``hidden_size``
            so the skip-add broadcasts correctly.
        dropout: Dropout rate applied after the ReLU; a value of 0 (or less)
            disables dropout entirely.
        residual: If True, adds the block input to the block output.
        name: Layer name.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``input_shape``).
    """

    def __init__(
            self,
            hidden_size=768,
            dropout=0.2,
            residual=True,
            name='FFN',
            **kwargs):
        super(FFN, self).__init__(name=name, **kwargs)
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.residual = residual
        self.ffn_layer = tf.keras.layers.Dense(
            units=hidden_size,
            use_bias=True
        )
        # Fix: sub-layers were previously instantiated inside call(), which
        # creates fresh layer objects on every forward pass. ReLU/Dropout hold
        # no weights, so moving them here preserves behavior while following
        # the Keras convention that layers are built once in __init__.
        self.relu_layer = tf.keras.layers.ReLU()
        if self.dropout > 0:
            self.dropout_layer = tf.keras.layers.Dropout(self.dropout)

    def call(self, inputs):
        """Apply Dense -> ReLU [-> Dropout] [+ inputs] and return the result."""
        ffn_embedding = self.ffn_layer(inputs)
        ffn_embedding = self.relu_layer(ffn_embedding)
        if self.dropout > 0:
            ffn_embedding = self.dropout_layer(ffn_embedding)
        if self.residual:
            # Skip connection; requires inputs' last dim == hidden_size.
            ffn_embedding += inputs
        return ffn_embedding
class MedicalQAModel(tf.keras.Model):
    """Projects pre-computed question/answer BERT embeddings through a pair
    of independent FFN heads.

    Expects ``inputs`` stacked along axis 1 as (question, answer) embeddings
    of width 768; returns the projected pair re-stacked along axis 1.
    """

    def __init__(self, name=''):
        super(MedicalQAModel, self).__init__(name=name)
        # Separate projection heads for questions and answers.
        self.q_ffn = FFN(name='q_ffn', input_shape=(768,))
        self.a_ffn = FFN(name='a_ffn', input_shape=(768,))

    def call(self, inputs):
        """Split the (q, a) pair, project each, and return them re-stacked."""
        question_emb, answer_emb = tf.unstack(inputs, axis=1)
        projected_q = self.q_ffn(question_emb)
        projected_a = self.a_ffn(answer_emb)
        return tf.stack([projected_q, projected_a], axis=1)
class MedicalQAModelwithBert(tf.keras.Model):
    """End-to-end QA embedding model: BioBERT encoder followed by separate
    question/answer FFN projection heads.

    Args:
        hidden_size: Width of the FFN projection heads.
        dropout: Dropout rate inside each FFN head.
        residual: Whether the FFN heads use a residual connection.
        config_file: Path to the BERT JSON config used to build the encoder.
        checkpoint_file: Optional BERT checkpoint; when given, the encoder is
            pre-built and its weights restored from the checkpoint.
        bert_trainable: Whether BERT encoder weights are trainable.
        layer_ind: Index of the encoder output layer to pool from
            (default -1, the last layer).
        name: Model name.
    """

    def __init__(
            self,
            hidden_size=768,
            dropout=0.2,
            residual=True,
            config_file=None,
            checkpoint_file=None,
            bert_trainable=True,
            layer_ind=-1,
            name=''):
        super(MedicalQAModelwithBert, self).__init__(name=name)
        # Bugfix: identity-style None comparison (was `checkpoint_file != None`).
        # Only pre-build the graph when weights will be restored below.
        build = checkpoint_file is not None
        self.biobert, config = build_model_from_config(
            config_file=config_file,
            training=False,
            trainable=bert_trainable,
            build=build)
        if checkpoint_file is not None:
            load_model_weights_from_checkpoint(
                model=self.biobert, config=config,
                checkpoint_file=checkpoint_file, training=False)
        self.q_ffn_layer = FFN(
            hidden_size=hidden_size,
            dropout=dropout,
            residual=residual,
            name='q_ffn')
        self.a_ffn_layer = FFN(
            hidden_size=hidden_size,
            dropout=dropout,
            residual=residual,
            name='a_ffn')
        self.layer_ind = layer_ind

    def _pooled_bert_embedding(self, input_ids, segment_ids, input_masks):
        """Run BioBERT and mean-pool layer ``self.layer_ind`` over tokens."""
        hidden = self.biobert((input_ids, segment_ids, input_masks))[self.layer_ind]
        # According to USE, the DAN network averages embeddings across tokens.
        return tf.reduce_mean(hidden, axis=1)

    def call(self, inputs):
        """Embed question and/or answer features found in ``inputs``.

        ``inputs`` is a dict; question features are keyed ``q_input_ids`` /
        ``q_segment_ids`` / ``q_input_masks`` and answer features ``a_*``.
        Returns a single embedding when only one side is present, or a
        (q, a) pair stacked along axis 1 when both are.

        Raises:
            ValueError: if neither question nor answer features are present
                (previously this surfaced as an UnboundLocalError on
                ``output``).
        """
        with_question = 'q_input_ids' in inputs
        with_answer = 'a_input_ids' in inputs
        if not (with_question or with_answer):
            raise ValueError(
                'inputs must contain q_input_ids and/or a_input_ids')
        output = None
        if with_question:
            q_embedding = self.q_ffn_layer(self._pooled_bert_embedding(
                inputs['q_input_ids'],
                inputs['q_segment_ids'],
                inputs['q_input_masks']))
            output = q_embedding
        if with_answer:
            a_embedding = self.a_ffn_layer(self._pooled_bert_embedding(
                inputs['a_input_ids'],
                inputs['a_segment_ids'],
                inputs['a_input_masks']))
            output = a_embedding
        if with_question and with_answer:
            output = tf.stack([q_embedding, a_embedding], axis=1)
        return output