|
a |
|
b/docproduct/predictor.py |
|
|
1 |
import json |
|
|
2 |
import os |
|
|
3 |
import re |
|
|
4 |
from collections import defaultdict |
|
|
5 |
from multiprocessing import Pool, cpu_count |
|
|
6 |
from time import time |
|
|
7 |
|
|
|
8 |
import faiss |
|
|
9 |
import numpy as np |
|
|
10 |
import pandas as pd |
|
|
11 |
import tensorflow as tf |
|
|
12 |
from tqdm import tqdm |
|
|
13 |
|
|
|
14 |
import gpt2_estimator |
|
|
15 |
from docproduct.dataset import convert_text_to_feature |
|
|
16 |
from docproduct.models import MedicalQAModelwithBert |
|
|
17 |
from docproduct.tokenization import FullTokenizer |
|
|
18 |
from keras_bert.loader import checkpoint_loader |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
def load_weight(model, bert_ffn_weight_file=None, ffn_weight_file=None):
    """Restore weights into ``model`` from one of two checkpoint sources.

    Args:
        model: model exposing ``load_weights`` and ``get_layer``.
        bert_ffn_weight_file: checkpoint handed to ``model.load_weights``.
            Takes precedence over ``ffn_weight_file`` when truthy.
        ffn_weight_file: checkpoint read through ``checkpoint_loader``; only
            the kernel and bias of the ``q_ffn`` and ``a_ffn`` layers are set.

    Nothing is loaded when both file arguments are falsy.
    """
    if bert_ffn_weight_file:
        model.load_weights(bert_ffn_weight_file)
        return
    if not ffn_weight_file:
        return

    read_variable = checkpoint_loader(ffn_weight_file)
    variable_template = '%s/ffn_layer/%s/.ATTRIBUTES/VARIABLE_VALUE'
    for layer_name in ('q_ffn', 'a_ffn'):
        # Resolve the layer before reading the checkpoint, as the
        # original call order does.
        ffn_layer = model.get_layer(layer_name)
        kernel = read_variable(variable_template % (layer_name, 'kernel'))
        bias = read_variable(variable_template % (layer_name, 'bias'))
        ffn_layer.set_weights([kernel, bias])
|
|
33 |
|
|
|
34 |
|
|
|
35 |
class QAEmbed(object):
    """Embed questions and/or answers with ``MedicalQAModelwithBert``.

    The constructor builds the model and tokenizer, runs one dummy forward
    pass so the model is built, and then restores weights via ``load_weight``.
    """

    def __init__(
            self,
            hidden_size=768,
            dropout=0.2,
            residual=True,
            pretrained_path=None,
            batch_size=128,
            max_seq_length=256,
            ffn_weight_file=None,
            bert_ffn_weight_file=None,
            load_pretrain=True,
            with_question=True,
            with_answer=True):
        """Build the model, the tokenizer and load the weights.

        Args:
            hidden_size: forwarded to ``MedicalQAModelwithBert``.
            dropout: forwarded to ``MedicalQAModelwithBert``.
            residual: forwarded to ``MedicalQAModelwithBert``.
            pretrained_path: directory containing ``bert_config.json``,
                ``vocab.txt`` and, when ``load_pretrain`` is true,
                ``biobert_model.ckpt``. Required.
            batch_size: batch size used by ``predict`` when ``dataset=True``.
            max_seq_length: forwarded to ``convert_text_to_feature``.
            ffn_weight_file: checkpoint holding only the FFN weights; used by
                ``load_weight`` when ``bert_ffn_weight_file`` is falsy.
            bert_ffn_weight_file: checkpoint passed to ``model.load_weights``.
            load_pretrain: whether the BERT checkpoint path is handed to the
                model constructor.
            with_question: feed a dummy question in the build pass.
            with_answer: feed a dummy answer in the build pass.

        Raises:
            TypeError: if ``pretrained_path`` is None.
        """
        super(QAEmbed, self).__init__()

        if pretrained_path is None:
            # os.path.join(None, ...) would raise a TypeError anyway; keep
            # the exception type but say what is actually missing.
            raise TypeError(
                'pretrained_path is required, got None instead.')

        config_file = os.path.join(pretrained_path, 'bert_config.json')
        if load_pretrain:
            checkpoint_file = os.path.join(
                pretrained_path, 'biobert_model.ckpt')
        else:
            checkpoint_file = None

        # the ffn model takes 2nd to last layer
        layer_ind = -2 if bert_ffn_weight_file is None else -1

        # Forward the constructor arguments instead of hard-coded values so
        # that non-default hidden_size / dropout / residual take effect.
        self.model = MedicalQAModelwithBert(
            hidden_size=hidden_size,
            dropout=dropout,
            residual=residual,
            config_file=config_file,
            checkpoint_file=checkpoint_file,
            layer_ind=layer_ind)
        self.batch_size = batch_size
        self.tokenizer = FullTokenizer(
            os.path.join(pretrained_path, 'vocab.txt'))
        self.max_seq_length = max_seq_length

        # Run the model once on dummy text so it is built before the
        # weights are loaded.
        question = 'fake' if with_question else None
        answer = 'fake' if with_answer else None
        self.predict(questions=question, answers=answer, dataset=False)
        load_weight(self.model, bert_ffn_weight_file, ffn_weight_file)

    def _type_check(self, inputs):
        """Normalize ``inputs`` to a list (or None).

        A single ``str`` is wrapped in a list; a list and None are returned
        unchanged.

        Raises:
            TypeError: for any other input type.
        """
        if inputs is None or isinstance(inputs, list):
            return inputs
        if isinstance(inputs, str):
            return [inputs]
        raise TypeError(
            'inputs are supposed to be str or list of str, got {0} instead.'.format(type(inputs)))

    def _featurize(self, texts, prefix):
        """Convert ``texts`` to feature lists keyed by ``<prefix>_...``."""
        features = defaultdict(list)
        for text in texts:
            feature = convert_text_to_feature(
                text, tokenizer=self.tokenizer, max_seq_length=self.max_seq_length)
            features['%s_input_ids' % prefix].append(feature[0])
            features['%s_input_masks' % prefix].append(feature[1])
            features['%s_segment_ids' % prefix].append(feature[2])
        return features

    def _make_inputs(self, questions=None, answers=None, dataset=True):
        """Turn question/answer text into model inputs.

        Args:
            questions: list of question strings, or None.
            answers: list of answer strings, or None.
            dataset: when true, wrap the tensors in a ``tf.data.Dataset``
                batched by ``self.batch_size``.

        Returns:
            A dict of stacked tensors, or a batched dataset of such dicts.

        Raises:
            ValueError: if both ``questions`` and ``answers`` are empty/None.
        """
        if not questions and not answers:
            raise ValueError(
                'at least one non-empty questions or answers input is required.')

        features = {}
        if questions:
            features.update(self._featurize(questions, 'q'))
        if answers:
            features.update(self._featurize(answers, 'a'))

        model_inputs = {
            key: tf.convert_to_tensor(np.stack(values, axis=0))
            for key, values in features.items()}
        if dataset:
            model_inputs = tf.data.Dataset.from_tensor_slices(model_inputs)
            model_inputs = model_inputs.batch(self.batch_size)

        return model_inputs

    def predict(self, questions=None, answers=None, dataset=True):
        """Run the model on questions and/or answers.

        Args:
            questions: a string, a list of strings, or None.
            answers: a string, a list of strings, or None.
            dataset: when true, inputs are processed batch by batch and the
                per-batch outputs are concatenated along axis 0; otherwise
                the model is called once on all inputs.

        Returns:
            The concatenated array (``dataset=True``) or the raw model
            output (``dataset=False``).
        """
        questions = self._type_check(questions)
        answers = self._type_check(answers)

        if questions is not None and answers is not None:
            assert len(questions) == len(answers)

        model_inputs = self._make_inputs(questions, answers, dataset)

        if not dataset:
            return self.model(model_inputs)

        # Size the progress bar from whichever input is present, and round
        # up so the final partial batch is counted.
        n_items = len(questions) if questions else len(answers)
        n_batches = (n_items + self.batch_size - 1) // self.batch_size
        model_outputs = []
        for batch in tqdm(iter(model_inputs), total=n_batches):
            model_outputs.append(self.model(batch))
        return np.concatenate(model_outputs, axis=0)
|
|
150 |
|
|
|
151 |
|
|
|
152 |
class FaissTopK(object):
    """Top-k lookup over stored question/answer embeddings using faiss.

    The embedding table is read into ``self.df``. Its ``Q_FFNN_embeds`` and
    ``A_FFNN_embeds`` columns are moved into two ``faiss.IndexFlatIP``
    indexes and dropped from the frame.
    """

    def __init__(self, embedding_file):
        """Load ``embedding_file`` and build the two faiss indexes.

        Args:
            embedding_file: path to the embedding table. A ``.pkl`` extension
                is read with ``pd.read_pickle``; anything else with
                ``pd.read_parquet``.
        """
        super(FaissTopK, self).__init__()
        self.embedding_file = embedding_file
        _, ext = os.path.splitext(self.embedding_file)
        if ext == '.pkl':
            # NOTE(review): unpickling executes arbitrary code; only load
            # embedding files from a trusted source.
            self.df = pd.read_pickle(self.embedding_file)
        else:
            self.df = pd.read_parquet(self.embedding_file)
        self._get_faiss_index()

    def _get_faiss_index(self):
        """Build ``self.question_index`` and ``self.answer_index``.

        Each embedding column is removed from ``self.df`` as soon as it has
        been copied out.
        """
        question_bert = self.df["Q_FFNN_embeds"].tolist()
        self.df.drop(columns=["Q_FFNN_embeds"], inplace=True)
        answer_bert = self.df["A_FFNN_embeds"].tolist()
        self.df.drop(columns=["A_FFNN_embeds"], inplace=True)
        question_bert = np.array(question_bert, dtype='float32')
        answer_bert = np.array(answer_bert, dtype='float32')

        self.answer_index = faiss.IndexFlatIP(answer_bert.shape[-1])
        self.question_index = faiss.IndexFlatIP(question_bert.shape[-1])

        self.answer_index.add(answer_bert)
        self.question_index.add(question_bert)

    def predict(self, q_embedding, search_by='answer', topk=5, answer_only=True):
        """Return the stored entries nearest to the first query embedding.

        Args:
            q_embedding: array of query embeddings; only the results for the
                first row are returned.
            search_by: ``'answer'`` searches the answer index; any other
                value searches the question index.
            topk: number of neighbours requested from faiss.
            answer_only: when true return only the answers, otherwise a
                ``(questions, answers)`` tuple.

        Returns:
            A list of answers, or a tuple of two lists. Fewer than ``topk``
            entries are returned when the index holds fewer vectors.
        """
        if search_by == 'answer':
            searcher = self.answer_index
        else:
            searcher = self.question_index
        _, index = searcher.search(q_embedding.astype('float32'), topk)

        # faiss fills missing neighbours with label -1; without this filter
        # iloc[-1] would silently return the last row of the table.
        first_query = np.asarray(index)[0]
        hits = first_query[first_query >= 0]

        output_df = self.df.iloc[hits, :]
        if answer_only:
            return output_df.answer.tolist()
        return (output_df.question.tolist(), output_df.answer.tolist())
|
|
197 |
|
|
|
198 |
|
|
|
199 |
class RetreiveQADoc(object):
    """Retrieve stored Q/A entries for a query by embedding it and
    searching the faiss indexes."""

    def __init__(self,
                 pretrained_path=None,
                 ffn_weight_file=None,
                 bert_ffn_weight_file='models/bertffn_crossentropy/bertffn',
                 embedding_file='qa_embeddings/bertffn_crossentropy.zip'
                 ):
        """Create the embedder (``QAEmbed``) and the index (``FaissTopK``).

        Args:
            pretrained_path: forwarded to ``QAEmbed``.
            ffn_weight_file: forwarded to ``QAEmbed``.
            bert_ffn_weight_file: forwarded to ``QAEmbed``.
            embedding_file: forwarded to ``FaissTopK``.
        """
        super().__init__()
        embedder_options = {
            'pretrained_path': pretrained_path,
            'ffn_weight_file': ffn_weight_file,
            'bert_ffn_weight_file': bert_ffn_weight_file,
        }
        self.qa_embed = QAEmbed(**embedder_options)
        self.faiss_topk = FaissTopK(embedding_file)

    def predict(self, questions, search_by='answer', topk=5, answer_only=True):
        """Embed ``questions`` and return the top-k lookup result.

        The remaining arguments are passed unchanged to
        ``self.faiss_topk.predict``.
        """
        query_vectors = self.qa_embed.predict(questions=questions)
        return self.faiss_topk.predict(
            query_vectors, search_by, topk, answer_only)

    def getEmbedding(self, questions, search_by='answer', topk=5, answer_only=True):
        """Return the embedding of ``questions``.

        ``search_by``, ``topk`` and ``answer_only`` are accepted but unused.
        """
        return self.qa_embed.predict(questions=questions)
|
|
221 |
|
|
|
222 |
|
|
|
223 |
class GenerateQADoc(object):
    """Generate an answer with GPT-2, prompted with retrieved Q/A pairs.

    A question is embedded with ``QAEmbed``, similar stored entries are
    looked up with ``FaissTopK``, and the retrieved pairs plus the question
    are fed to a GPT-2 estimator.
    """

    def __init__(self,
                 pretrained_path='models/pubmed_pmc_470k/',
                 ffn_weight_file=None,
                 bert_ffn_weight_file='models/bertffn_crossentropy/bertffn',
                 gpt2_weight_file='models/gpt2',
                 embedding_file='qa_embeddings/bertffn_crossentropy.zip'
                 ):
        """Set up the GPT-2 estimator, the embedder and the faiss index.

        Args:
            pretrained_path: forwarded to ``QAEmbed``.
            ffn_weight_file: forwarded to ``QAEmbed``.
            bert_ffn_weight_file: forwarded to ``QAEmbed``.
            gpt2_weight_file: directory containing ``hparams.json``; also
                used as the estimator ``model_dir`` and the encoder location.
            embedding_file: forwarded to ``FaissTopK``.

        Note:
            Calls ``tf.compat.v1.disable_eager_execution()``, which affects
            the whole process.
        """
        super(GenerateQADoc, self).__init__()
        tf.compat.v1.disable_eager_execution()

        estimator_session_config = tf.compat.v1.ConfigProto(
            allow_soft_placement=True)
        estimator_session_config.gpu_options.allow_growth = False
        run_config = tf.estimator.RunConfig(
            session_config=estimator_session_config)

        self.batch_size = 1
        self.gpt2_weight_file = gpt2_weight_file
        gpt2_model_fn = gpt2_estimator.get_gpt2_model_fn(
            accumulate_gradients=5,
            learning_rate=0.1,
            length=512,
            batch_size=self.batch_size,
            temperature=0.7,
            top_k=0
        )
        hparams = gpt2_estimator.default_hparams()
        with open(os.path.join(gpt2_weight_file, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))
        self.estimator = tf.estimator.Estimator(
            gpt2_model_fn,
            model_dir=gpt2_weight_file,
            params=hparams,
            config=run_config)
        self.encoder = gpt2_estimator.encoder.get_encoder(gpt2_weight_file)

        # The embedder gets its own session, with GPU memory growth enabled.
        embed_session_config = tf.compat.v1.ConfigProto()
        embed_session_config.gpu_options.allow_growth = True
        self.embed_sess = tf.compat.v1.Session(config=embed_session_config)
        with self.embed_sess.as_default():
            self.qa_embed = QAEmbed(
                pretrained_path=pretrained_path,
                ffn_weight_file=ffn_weight_file,
                bert_ffn_weight_file=bert_ffn_weight_file,
                with_answer=False,
                load_pretrain=False
            )

        self.faiss_topk = FaissTopK(embedding_file)

    def _get_gpt2_inputs(self, question, questions, answers):
        """Build the GPT-2 prompt string.

        Each retrieved pair is prepended in turn, so the first pair in
        ``questions``/``answers`` ends up directly before the final
        ``question`` and its open ``ANSWER:`` slot.
        """
        assert len(questions) == len(answers)
        line = '`QUESTION: %s `ANSWER: ' % question
        for q, a in zip(questions, answers):
            line = '`QUESTION: %s `ANSWER: %s ' % (q, a) + line
        return line

    def predict(self, questions, search_by='answer', topk=5, answer_only=False):
        """Generate an answer for the first question in ``questions``.

        Args:
            questions: a question string or a list of question strings; only
                the first question is used for the GPT-2 prompt.
            search_by: forwarded to ``FaissTopK.predict``.
            topk: number of retrieved Q/A pairs placed in the prompt.
            answer_only: kept for backward compatibility. The prompt always
                needs both retrieved questions and answers, so this flag
                does not change the result.

        Returns:
            The generated text with the prompt pairs stripped, cut at the
            next generated ``QUESTION:`` marker.
        """
        # A bare string must be wrapped: questions[0] would otherwise be
        # its first character rather than the whole question.
        if isinstance(questions, str):
            questions = [questions]

        embedding = self.qa_embed.predict(
            questions=questions, dataset=False).eval(session=self.embed_sess)

        # Always fetch questions and answers; requesting answers only left
        # topk_question undefined.
        topk_question, topk_answer = self.faiss_topk.predict(
            embedding, search_by, topk, False)

        gpt2_input = self._get_gpt2_inputs(
            questions[0], topk_question, topk_answer)
        gpt2_pred = self.estimator.predict(
            lambda: gpt2_estimator.predict_input_fn(
                inputs=gpt2_input,
                batch_size=self.batch_size,
                checkpoint_path=self.gpt2_weight_file))
        raw_output = gpt2_estimator.predictions_parsing(
            gpt2_pred, self.encoder)

        # Remove every "`QUESTION: ... `ANSWER:" span, then keep only the
        # text before the next generated question marker.
        refine1 = re.sub('`QUESTION:.*?`ANSWER:', '',
                         str(raw_output[0]), flags=re.DOTALL)
        refine2 = refine1.split('`QUESTION: ')[0]
        return refine2
|
|
305 |
|
|
|
306 |
|
|
|
307 |
if __name__ == "__main__":
    # Smoke run: build the full pipeline and print one generated answer.
    generator = GenerateQADoc()
    generated_answer = generator.predict('my eyes hurt')
    print(generated_answer)