Diff of /docproduct/predictor.py [000000] .. [51873b]

Switch to unified view

a b/docproduct/predictor.py
1
import json
2
import os
3
import re
4
from collections import defaultdict
5
from multiprocessing import Pool, cpu_count
6
from time import time
7
8
import faiss
9
import numpy as np
10
import pandas as pd
11
import tensorflow as tf
12
from tqdm import tqdm
13
14
import gpt2_estimator
15
from docproduct.dataset import convert_text_to_feature
16
from docproduct.models import MedicalQAModelwithBert
17
from docproduct.tokenization import FullTokenizer
18
from keras_bert.loader import checkpoint_loader
19
20
21
def load_weight(model, bert_ffn_weight_file=None, ffn_weight_file=None):
22
    if bert_ffn_weight_file:
23
        model.load_weights(bert_ffn_weight_file)
24
    elif ffn_weight_file:
25
        loader = checkpoint_loader(ffn_weight_file)
26
        model.get_layer('q_ffn').set_weights(
27
            [loader('q_ffn/ffn_layer/kernel/.ATTRIBUTES/VARIABLE_VALUE'),
28
             loader('q_ffn/ffn_layer/bias/.ATTRIBUTES/VARIABLE_VALUE')])
29
        model.get_layer('a_ffn').set_weights(
30
            [loader('a_ffn/ffn_layer/kernel/.ATTRIBUTES/VARIABLE_VALUE'),
31
             loader('a_ffn/ffn_layer/bias/.ATTRIBUTES/VARIABLE_VALUE')]
32
        )
33
34
35
class QAEmbed(object):
36
    def __init__(
37
            self,
38
            hidden_size=768,
39
            dropout=0.2,
40
            residual=True,
41
            pretrained_path=None,
42
            batch_size=128,
43
            max_seq_length=256,
44
            ffn_weight_file=None,
45
            bert_ffn_weight_file=None,
46
            load_pretrain=True,
47
            with_question=True,
48
            with_answer=True):
49
        super(QAEmbed, self).__init__()
50
51
        config_file = os.path.join(pretrained_path, 'bert_config.json')
52
        if load_pretrain:
53
            checkpoint_file = os.path.join(
54
                pretrained_path, 'biobert_model.ckpt')
55
        else:
56
            checkpoint_file = None
57
58
        # the ffn model takes 2nd to last layer
59
        if bert_ffn_weight_file is None:
60
            layer_ind = -2
61
        else:
62
            layer_ind = -1
63
64
        self.model = MedicalQAModelwithBert(
65
            hidden_size=768,
66
            dropout=0.2,
67
            residual=True,
68
            config_file=config_file,
69
            checkpoint_file=checkpoint_file,
70
            layer_ind=layer_ind)
71
        self.batch_size = batch_size
72
        self.tokenizer = FullTokenizer(
73
            os.path.join(pretrained_path, 'vocab.txt'))
74
        self.max_seq_length = max_seq_length
75
76
        # build mode in order to load
77
        question = 'fake' if with_question else None
78
        answer = 'fake' if with_answer else None
79
        self.predict(questions=question, answers=answer, dataset=False)
80
        load_weight(self.model, bert_ffn_weight_file, ffn_weight_file)
81
82
    def _type_check(self, inputs):
83
        if inputs is not None:
84
            if isinstance(inputs, str):
85
                inputs = [inputs]
86
            elif isinstance(inputs, list):
87
                pass
88
            else:
89
                raise TypeError(
90
                    'inputs are supposed to be str of list of str, got {0} instead.'.format(type(inputs)))
91
            return inputs
92
93
    def _make_inputs(self, questions=None, answers=None, dataset=True):
94
95
        if questions:
96
            data_size = len(questions)
97
            q_feature_dict = defaultdict(list)
98
            for q in questions:
99
                q_feature = convert_text_to_feature(
100
                    q, tokenizer=self.tokenizer, max_seq_length=self.max_seq_length)
101
                q_feature_dict['q_input_ids'].append(q_feature[0])
102
                q_feature_dict['q_input_masks'].append(q_feature[1])
103
                q_feature_dict['q_segment_ids'].append(q_feature[2])
104
105
        if answers:
106
            data_size = len(answers)
107
            a_feature_dict = defaultdict(list)
108
            for a in answers:
109
                a_feature = convert_text_to_feature(
110
                    a, tokenizer=self.tokenizer, max_seq_length=self.max_seq_length)
111
                a_feature_dict['a_input_ids'].append(a_feature[0])
112
                a_feature_dict['a_input_masks'].append(a_feature[1])
113
                a_feature_dict['a_segment_ids'].append(a_feature[2])
114
115
        if questions and answers:
116
            q_feature_dict.update(a_feature_dict)
117
            model_inputs = q_feature_dict
118
        elif questions:
119
            model_inputs = q_feature_dict
120
        elif answers:
121
            model_inputs = a_feature_dict
122
123
        model_inputs = {k: tf.convert_to_tensor(
124
            np.stack(v, axis=0)) for k, v in model_inputs.items()}
125
        if dataset:
126
            model_inputs = tf.data.Dataset.from_tensor_slices(model_inputs)
127
            model_inputs = model_inputs.batch(self.batch_size)
128
129
        return model_inputs
130
131
    def predict(self, questions=None, answers=None, dataset=True):
132
133
        # type check
134
        questions = self._type_check(questions)
135
        answers = self._type_check(answers)
136
137
        if questions is not None and answers is not None:
138
            assert len(questions) == len(answers)
139
140
        model_inputs = self._make_inputs(questions, answers, dataset)
141
        model_outputs = []
142
143
        if dataset:
144
            for batch in tqdm(iter(model_inputs), total=int(len(questions) / self.batch_size)):
145
                model_outputs.append(self.model(batch))
146
            model_outputs = np.concatenate(model_outputs, axis=0)
147
        else:
148
            model_outputs = self.model(model_inputs)
149
        return model_outputs
150
151
152
class FaissTopK(object):
153
    def __init__(self, embedding_file):
154
        super(FaissTopK, self).__init__()
155
        self.embedding_file = embedding_file
156
        _, ext = os.path.splitext(self.embedding_file)
157
        if ext == '.pkl':
158
            self.df = pd.read_pickle(self.embedding_file)
159
        else:
160
            self.df = pd.read_parquet(self.embedding_file)
161
        self._get_faiss_index()
162
        # self.df.drop(columns=["Q_FFNN_embeds", "A_FFNN_embeds"], inplace=True)
163
164
    def _get_faiss_index(self):
165
        # with Pool(cpu_count()) as p:
166
        #     question_bert = p.map(eval, self.df["Q_FFNN_embeds"].tolist())
167
        #     answer_bert = p.map(eval, self.df["A_FFNN_embeds"].tolist())
168
        question_bert = self.df["Q_FFNN_embeds"].tolist()
169
        self.df.drop(columns=["Q_FFNN_embeds"], inplace=True)
170
        answer_bert = self.df["A_FFNN_embeds"].tolist()
171
        self.df.drop(columns=["A_FFNN_embeds"], inplace=True)
172
        question_bert = np.array(question_bert, dtype='float32')
173
        answer_bert = np.array(answer_bert, dtype='float32')
174
175
        self.answer_index = faiss.IndexFlatIP(answer_bert.shape[-1])
176
177
        self.question_index = faiss.IndexFlatIP(question_bert.shape[-1])
178
179
        self.answer_index.add(answer_bert)
180
        self.question_index.add(question_bert)
181
182
        del answer_bert, question_bert
183
184
    def predict(self, q_embedding, search_by='answer', topk=5, answer_only=True):
185
        if search_by == 'answer':
186
            _, index = self.answer_index.search(
187
                q_embedding.astype('float32'), topk)
188
        else:
189
            _, index = self.question_index.search(
190
                q_embedding.astype('float32'), topk)
191
192
        output_df = self.df.iloc[index[0], :]
193
        if answer_only:
194
            return output_df.answer.tolist()
195
        else:
196
            return (output_df.question.tolist(), output_df.answer.tolist())
197
198
199
class RetreiveQADoc(object):
200
    def __init__(self,
201
                 pretrained_path=None,
202
                 ffn_weight_file=None,
203
                 bert_ffn_weight_file='models/bertffn_crossentropy/bertffn',
204
                 embedding_file='qa_embeddings/bertffn_crossentropy.zip'
205
                 ):
206
        super(RetreiveQADoc, self).__init__()
207
        self.qa_embed = QAEmbed(
208
            pretrained_path=pretrained_path,
209
            ffn_weight_file=ffn_weight_file,
210
            bert_ffn_weight_file=bert_ffn_weight_file
211
        )
212
        self.faiss_topk = FaissTopK(embedding_file)
213
214
    def predict(self, questions, search_by='answer', topk=5, answer_only=True):
215
        embedding = self.qa_embed.predict(questions=questions)
216
        return self.faiss_topk.predict(embedding, search_by, topk, answer_only)
217
218
    def getEmbedding(self, questions, search_by='answer', topk=5, answer_only=True):
219
        embedding = self.qa_embed.predict(questions=questions)
220
        return embedding
221
222
223
class GenerateQADoc(object):
224
    def __init__(self,
225
                 pretrained_path='models/pubmed_pmc_470k/',
226
                 ffn_weight_file=None,
227
                 bert_ffn_weight_file='models/bertffn_crossentropy/bertffn',
228
                 gpt2_weight_file='models/gpt2',
229
                 embedding_file='qa_embeddings/bertffn_crossentropy.zip'
230
                 ):
231
        super(GenerateQADoc, self).__init__()
232
        tf.compat.v1.disable_eager_execution()
233
        session_config = tf.compat.v1.ConfigProto(
234
            allow_soft_placement=True)
235
        session_config.gpu_options.allow_growth = False
236
        config = tf.estimator.RunConfig(
237
            session_config=session_config)
238
        self.batch_size = 1
239
        self.gpt2_weight_file = gpt2_weight_file
240
        gpt2_model_fn = gpt2_estimator.get_gpt2_model_fn(
241
            accumulate_gradients=5,
242
            learning_rate=0.1,
243
            length=512,
244
            batch_size=self.batch_size,
245
            temperature=0.7,
246
            top_k=0
247
        )
248
        hparams = gpt2_estimator.default_hparams()
249
        with open(os.path.join(gpt2_weight_file, 'hparams.json')) as f:
250
            hparams.override_from_dict(json.load(f))
251
        self.estimator = tf.estimator.Estimator(
252
            gpt2_model_fn,
253
            model_dir=gpt2_weight_file,
254
            params=hparams,
255
            config=config)
256
        self.encoder = gpt2_estimator.encoder.get_encoder(gpt2_weight_file)
257
258
        config = tf.compat.v1.ConfigProto()
259
        config.gpu_options.allow_growth = True
260
        self.embed_sess = tf.compat.v1.Session(config=config)
261
        with self.embed_sess.as_default():
262
            self.qa_embed = QAEmbed(
263
                pretrained_path=pretrained_path,
264
                ffn_weight_file=ffn_weight_file,
265
                bert_ffn_weight_file=bert_ffn_weight_file,
266
                with_answer=False,
267
                load_pretrain=False
268
            )
269
270
        self.faiss_topk = FaissTopK(embedding_file)
271
272
    def _get_gpt2_inputs(self, question, questions, answers):
273
        assert len(questions) == len(answers)
274
        line = '`QUESTION: %s `ANSWER: ' % question
275
        for q, a in zip(questions, answers):
276
            line = '`QUESTION: %s `ANSWER: %s ' % (q, a) + line
277
        return line
278
279
    def predict(self, questions, search_by='answer', topk=5, answer_only=False):
280
        embedding = self.qa_embed.predict(
281
            questions=questions, dataset=False).eval(session=self.embed_sess)
282
        if answer_only:
283
            topk_answer = self.faiss_topk.predict(
284
                embedding, search_by, topk, answer_only)
285
        else:
286
            topk_question, topk_answer = self.faiss_topk.predict(
287
                embedding, search_by, topk, answer_only)
288
289
        gpt2_input = self._get_gpt2_inputs(
290
            questions[0], topk_question, topk_answer)
291
        gpt2_pred = self.estimator.predict(
292
            lambda: gpt2_estimator.predict_input_fn(inputs=gpt2_input, batch_size=self.batch_size, checkpoint_path=self.gpt2_weight_file))
293
        raw_output = gpt2_estimator.predictions_parsing(
294
            gpt2_pred, self.encoder)
295
        # result_list = [re.search('`ANSWER:(.*)`QUESTION:', s)
296
        #                for s in raw_output]
297
        # result_list = [s for s in result_list if s]
298
        # try:
299
        #     r = result_list[0].group(1)
300
        # except (AttributeError, IndexError):
301
        #     r = topk_answer[0]
302
        refine1 = re.sub('`QUESTION:.*?`ANSWER:','' , str(raw_output[0]) , flags=re.DOTALL)
303
        refine2 = refine1.split('`QUESTION: ')[0]
304
        return refine2
305
306
307
if __name__ == "__main__":
308
    gen = GenerateQADoc()
309
    print(gen.predict('my eyes hurt'))