a b/docproduct/models.py
1
from __future__ import absolute_import, division, print_function, unicode_literals
2
3
import os
4
import pandas as pd
5
from sklearn.model_selection import train_test_split
6
import numpy as np
7
8
import tensorflow as tf
9
import tensorflow.keras.backend as K
10
from tensorflow import keras
11
12
from docproduct.bert import build_model_from_config
13
14
from keras_bert.loader import load_model_weights_from_checkpoint
15
16
17
class FFN(tf.keras.layers.Layer):
    """Dense projection followed by ReLU, optional dropout and residual add.

    The block computes ``Dropout(ReLU(Dense(inputs)))`` and, when
    ``residual`` is true, adds ``inputs`` back onto the result.

    Args:
        hidden_size: Number of units of the dense projection.
        dropout: Dropout rate applied after the activation. Dropout is
            skipped entirely when the rate is not greater than 0.
        residual: Whether to add the layer input to the output. The addition
            requires the input and the projected output to have compatible
            shapes (in practice, an input whose last dimension equals
            ``hidden_size``).
        name: Layer name forwarded to the Keras base class.
        **kwargs: Extra keyword arguments forwarded to the Keras base class.
    """

    def __init__(
            self,
            hidden_size=768,
            dropout=0.2,
            residual=True,
            name='FFN',
            **kwargs):
        super(FFN, self).__init__(name=name, **kwargs)
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.residual = residual
        self.ffn_layer = tf.keras.layers.Dense(
            units=hidden_size,
            use_bias=True
        )
        # Sub-layers are created once here rather than on every `call`, so
        # they are tracked by Keras and not re-instantiated per invocation.
        self.activation_layer = tf.keras.layers.ReLU()
        self.dropout_layer = (
            tf.keras.layers.Dropout(dropout) if dropout > 0 else None)

    def call(self, inputs, training=None):
        """Apply the feed-forward block to ``inputs``.

        Args:
            inputs: Input tensor fed to the dense projection.
            training: Optional flag forwarded to the dropout layer. When left
                as ``None`` Keras resolves the mode from the calling context.

        Returns:
            The transformed tensor, with ``inputs`` added when ``residual``
            is enabled.
        """
        ffn_embedding = self.ffn_layer(inputs)
        ffn_embedding = self.activation_layer(ffn_embedding)
        if self.dropout_layer is not None:
            ffn_embedding = self.dropout_layer(
                ffn_embedding, training=training)

        if self.residual:
            ffn_embedding += inputs
        return ffn_embedding

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(FFN, self).get_config()
        config.update({
            'hidden_size': self.hidden_size,
            'dropout': self.dropout,
            'residual': self.residual,
        })
        return config
47
48
49
class MedicalQAModel(tf.keras.Model):
    """Two independent FFN heads applied to a stacked question/answer pair.

    The model input is unstacked along axis 1 into exactly two tensors; the
    first goes through ``q_ffn`` and the second through ``a_ffn``. The two
    results are stacked back along axis 1.
    """

    def __init__(self, name=''):
        super(MedicalQAModel, self).__init__(name=name)
        # One head per side of the pair; both are declared with a 768-wide
        # input shape.
        self.q_ffn = FFN(name='q_ffn', input_shape=(768,))
        self.a_ffn = FFN(name='a_ffn', input_shape=(768,))

    def call(self, inputs):
        """Run both heads and return their outputs stacked on axis 1."""
        question_part, answer_part = tf.unstack(inputs, axis=1)
        question_out = self.q_ffn(question_part)
        answer_out = self.a_ffn(answer_part)
        return tf.stack([question_out, answer_out], axis=1)
60
61
62
class MedicalQAModelwithBert(tf.keras.Model):
    """BERT encoder followed by separate question and answer FFN heads.

    Args:
        hidden_size: Units of each FFN head.
        dropout: Dropout rate of each FFN head.
        residual: Whether each FFN head adds its input to its output.
        config_file: BERT configuration file passed to
            ``build_model_from_config``.
        checkpoint_file: Optional checkpoint. When given, the encoder is
            built immediately and its weights are loaded from this file.
        bert_trainable: Forwarded as ``trainable`` when building the encoder.
        layer_ind: Index applied to the encoder's return value to pick the
            output that is pooled (presumably one entry per encoder layer;
            confirm against ``build_model_from_config``).
        name: Model name forwarded to the Keras base class.
    """

    def __init__(
            self,
            hidden_size=768,
            dropout=0.2,
            residual=True,
            config_file=None,
            checkpoint_file=None,
            bert_trainable=True,
            layer_ind=-1,
            name=''):
        super(MedicalQAModelwithBert, self).__init__(name=name)
        # The encoder only needs to be built up-front when there are
        # checkpoint weights to load into it.
        build = checkpoint_file is not None
        self.biobert, config = build_model_from_config(
            config_file=config_file,
            training=False,
            trainable=bert_trainable,
            build=build)
        if checkpoint_file is not None:
            load_model_weights_from_checkpoint(
                model=self.biobert,
                config=config,
                checkpoint_file=checkpoint_file,
                training=False)
        self.q_ffn_layer = FFN(
            hidden_size=hidden_size,
            dropout=dropout,
            residual=residual,
            name='q_ffn')
        self.a_ffn_layer = FFN(
            hidden_size=hidden_size,
            dropout=dropout,
            residual=residual,
            name='a_ffn')
        self.layer_ind = layer_ind

    def _mean_token_embedding(self, input_ids, segment_ids, input_masks):
        """Encode one text input with BERT and average the result on axis 1."""
        bert_output = self.biobert(
            (input_ids, segment_ids, input_masks))[self.layer_ind]
        return tf.reduce_mean(bert_output, axis=1)

    def call(self, inputs):
        """Embed the question and/or answer found in ``inputs``.

        Args:
            inputs: Mapping that contains the ``q_input_ids``,
                ``q_segment_ids`` and ``q_input_masks`` entries, the
                matching ``a_*`` entries, or both groups.

        Returns:
            The question embedding or the answer embedding when only one
            side is present; both embeddings stacked along axis 1 (question
            first) when both are present.

        Raises:
            ValueError: If ``inputs`` holds neither ``q_input_ids`` nor
                ``a_input_ids``.
        """
        with_question = 'q_input_ids' in inputs
        with_answer = 'a_input_ids' in inputs
        if not (with_question or with_answer):
            # Fail with a clear message instead of an UnboundLocalError on
            # the return value further down.
            raise ValueError(
                "inputs must contain 'q_input_ids' and/or 'a_input_ids'")

        # according to USE, the DAN network averages embeddings across tokens
        if with_question:
            q_bert_embedding = self._mean_token_embedding(
                inputs['q_input_ids'],
                inputs['q_segment_ids'],
                inputs['q_input_masks'])
        if with_answer:
            a_bert_embedding = self._mean_token_embedding(
                inputs['a_input_ids'],
                inputs['a_segment_ids'],
                inputs['a_input_masks'])

        if with_question:
            q_embedding = self.q_ffn_layer(q_bert_embedding)
        if with_answer:
            a_embedding = self.a_ffn_layer(a_bert_embedding)

        if with_question and with_answer:
            return tf.stack([q_embedding, a_embedding], axis=1)
        if with_question:
            return q_embedding
        return a_embedding