--- a
+++ b/docproduct/predictor.py
@@ -0,0 +1,309 @@
+import json
+import os
+import re
+from collections import defaultdict
+from multiprocessing import Pool, cpu_count
+from time import time
+
+import faiss
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from tqdm import tqdm
+
+import gpt2_estimator
+from docproduct.dataset import convert_text_to_feature
+from docproduct.models import MedicalQAModelwithBert
+from docproduct.tokenization import FullTokenizer
+from keras_bert.loader import checkpoint_loader
+
+
def load_weight(model, bert_ffn_weight_file=None, ffn_weight_file=None):
    """Load pretrained weights into *model*.

    A full keras weight file (``bert_ffn_weight_file``) takes precedence;
    otherwise only the two FFN layers are restored from a TF checkpoint
    via ``checkpoint_loader``. With neither file given this is a no-op.
    """
    if bert_ffn_weight_file:
        model.load_weights(bert_ffn_weight_file)
        return
    if not ffn_weight_file:
        return
    loader = checkpoint_loader(ffn_weight_file)
    # Both towers store their FFN weights under the same checkpoint layout.
    for layer_name in ('q_ffn', 'a_ffn'):
        model.get_layer(layer_name).set_weights([
            loader(layer_name + '/ffn_layer/kernel/.ATTRIBUTES/VARIABLE_VALUE'),
            loader(layer_name + '/ffn_layer/bias/.ATTRIBUTES/VARIABLE_VALUE'),
        ])
+
+
class QAEmbed(object):
    """Embed medical questions and/or answers with a BERT + FFN model.

    Wraps ``MedicalQAModelwithBert`` and a ``FullTokenizer`` behind a
    single ``predict`` entry point that maps raw strings to embedding
    vectors.
    """

    def __init__(
            self,
            hidden_size=768,
            dropout=0.2,
            residual=True,
            pretrained_path=None,
            batch_size=128,
            max_seq_length=256,
            ffn_weight_file=None,
            bert_ffn_weight_file=None,
            load_pretrain=True,
            with_question=True,
            with_answer=True):
        """Build the model and tokenizer, then load checkpoint weights.

        Args:
            hidden_size: FFN hidden size forwarded to the model.
            dropout: dropout rate forwarded to the model.
            residual: whether the FFN uses a residual connection.
            pretrained_path: directory containing ``bert_config.json``,
                ``vocab.txt`` and (optionally) ``biobert_model.ckpt``.
            batch_size: batch size for dataset-mode prediction.
            max_seq_length: maximum token length per input text.
            ffn_weight_file: TF checkpoint holding only the FFN layers.
            bert_ffn_weight_file: full keras weight file (takes precedence).
            load_pretrain: when True, point the model at the BioBERT
                checkpoint under ``pretrained_path``.
            with_question: trace the question tower during the build pass.
            with_answer: trace the answer tower during the build pass.
        """
        super(QAEmbed, self).__init__()

        config_file = os.path.join(pretrained_path, 'bert_config.json')
        if load_pretrain:
            checkpoint_file = os.path.join(
                pretrained_path, 'biobert_model.ckpt')
        else:
            checkpoint_file = None

        # the ffn model takes 2nd to last layer
        if bert_ffn_weight_file is None:
            layer_ind = -2
        else:
            layer_ind = -1

        # BUGFIX: forward the constructor arguments instead of hard-coding
        # hidden_size=768, dropout=0.2, residual=True, which silently
        # ignored whatever the caller passed in.
        self.model = MedicalQAModelwithBert(
            hidden_size=hidden_size,
            dropout=dropout,
            residual=residual,
            config_file=config_file,
            checkpoint_file=checkpoint_file,
            layer_ind=layer_ind)
        self.batch_size = batch_size
        self.tokenizer = FullTokenizer(
            os.path.join(pretrained_path, 'vocab.txt'))
        self.max_seq_length = max_seq_length

        # Run one fake example through the model so all variables are
        # created before checkpoint weights are assigned.
        question = 'fake' if with_question else None
        answer = 'fake' if with_answer else None
        self.predict(questions=question, answers=answer, dataset=False)
        load_weight(self.model, bert_ffn_weight_file, ffn_weight_file)

    def _type_check(self, inputs):
        """Normalize *inputs* to a list of str.

        Returns None when *inputs* is None; raises TypeError for anything
        that is neither str nor list.
        """
        if inputs is None:
            return None
        if isinstance(inputs, str):
            return [inputs]
        if isinstance(inputs, list):
            return inputs
        raise TypeError(
            'inputs are supposed to be str of list of str, got {0} instead.'.format(type(inputs)))

    def _make_inputs(self, questions=None, answers=None, dataset=True):
        """Tokenize questions/answers into model input tensors.

        Args:
            questions: list of question strings, or None.
            answers: list of answer strings, or None.
            dataset: when True, wrap the tensors in a batched
                ``tf.data.Dataset``; otherwise return a dict of tensors.

        Raises:
            ValueError: when neither questions nor answers are given
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if not questions and not answers:
            raise ValueError(
                'At least one of questions and answers must be provided.')

        if questions:
            q_feature_dict = defaultdict(list)
            for q in questions:
                q_feature = convert_text_to_feature(
                    q, tokenizer=self.tokenizer, max_seq_length=self.max_seq_length)
                q_feature_dict['q_input_ids'].append(q_feature[0])
                q_feature_dict['q_input_masks'].append(q_feature[1])
                q_feature_dict['q_segment_ids'].append(q_feature[2])

        if answers:
            a_feature_dict = defaultdict(list)
            for a in answers:
                a_feature = convert_text_to_feature(
                    a, tokenizer=self.tokenizer, max_seq_length=self.max_seq_length)
                a_feature_dict['a_input_ids'].append(a_feature[0])
                a_feature_dict['a_input_masks'].append(a_feature[1])
                a_feature_dict['a_segment_ids'].append(a_feature[2])

        if questions and answers:
            q_feature_dict.update(a_feature_dict)
            model_inputs = q_feature_dict
        elif questions:
            model_inputs = q_feature_dict
        else:
            model_inputs = a_feature_dict

        model_inputs = {k: tf.convert_to_tensor(
            np.stack(v, axis=0)) for k, v in model_inputs.items()}
        if dataset:
            model_inputs = tf.data.Dataset.from_tensor_slices(model_inputs)
            model_inputs = model_inputs.batch(self.batch_size)

        return model_inputs

    def predict(self, questions=None, answers=None, dataset=True):
        """Embed the given texts.

        Args:
            questions: str or list of str (optional).
            answers: str or list of str (optional); at least one of the
                two must be provided.
            dataset: when True, stream batches through a tf.data pipeline
                and return a stacked numpy array; when False, do a single
                forward pass and return the raw model output.
        """
        questions = self._type_check(questions)
        answers = self._type_check(answers)

        if questions is not None and answers is not None:
            assert len(questions) == len(answers)

        model_inputs = self._make_inputs(questions, answers, dataset)

        if not dataset:
            return self.model(model_inputs)

        # BUGFIX: size the progress bar from whichever input is present;
        # the old code called len(questions) and crashed for answers-only
        # prediction.
        data_size = len(questions) if questions is not None else len(answers)
        model_outputs = []
        for batch in tqdm(iter(model_inputs), total=int(data_size / self.batch_size)):
            model_outputs.append(self.model(batch))
        return np.concatenate(model_outputs, axis=0)
+
+
class FaissTopK(object):
    """Nearest-neighbour lookup over precomputed QA embeddings via Faiss."""

    def __init__(self, embedding_file):
        """Load the embedding dataframe (pickle or parquet, chosen by
        extension) and build the inner-product indexes."""
        super(FaissTopK, self).__init__()
        self.embedding_file = embedding_file
        _, ext = os.path.splitext(self.embedding_file)
        if ext == '.pkl':
            self.df = pd.read_pickle(self.embedding_file)
        else:
            self.df = pd.read_parquet(self.embedding_file)
        self._get_faiss_index()

    def _get_faiss_index(self):
        """Move the embedding columns out of the dataframe into two
        ``IndexFlatIP`` indexes, leaving only the text columns in memory."""
        q_vectors = self.df["Q_FFNN_embeds"].tolist()
        self.df.drop(columns=["Q_FFNN_embeds"], inplace=True)
        a_vectors = self.df["A_FFNN_embeds"].tolist()
        self.df.drop(columns=["A_FFNN_embeds"], inplace=True)
        q_matrix = np.array(q_vectors, dtype='float32')
        a_matrix = np.array(a_vectors, dtype='float32')

        # Inner-product (cosine-style) flat indexes, one per tower.
        self.answer_index = faiss.IndexFlatIP(a_matrix.shape[-1])
        self.question_index = faiss.IndexFlatIP(q_matrix.shape[-1])
        self.answer_index.add(a_matrix)
        self.question_index.add(q_matrix)

        # Free the dense matrices; the indexes hold their own copies.
        del a_matrix, q_matrix

    def predict(self, q_embedding, search_by='answer', topk=5, answer_only=True):
        """Return the top-k matches for the first query embedding.

        Args:
            q_embedding: query embedding array (first row is used).
            search_by: 'answer' searches the answer index, anything else
                searches the question index.
            topk: number of neighbours to return.
            answer_only: when True return only the answers, otherwise a
                (questions, answers) tuple.
        """
        index = self.answer_index if search_by == 'answer' else self.question_index
        _, hits = index.search(q_embedding.astype('float32'), topk)
        matches = self.df.iloc[hits[0], :]
        answers = matches.answer.tolist()
        if answer_only:
            return answers
        return (matches.question.tolist(), answers)
+
+
class RetreiveQADoc(object):
    """Retrieval-only pipeline: embed a question and fetch top-k QA pairs."""

    def __init__(self,
                 pretrained_path=None,
                 ffn_weight_file=None,
                 bert_ffn_weight_file='models/bertffn_crossentropy/bertffn',
                 embedding_file='qa_embeddings/bertffn_crossentropy.zip'
                 ):
        """Construct the embedding model and the Faiss search backend."""
        super(RetreiveQADoc, self).__init__()
        self.qa_embed = QAEmbed(
            pretrained_path=pretrained_path,
            ffn_weight_file=ffn_weight_file,
            bert_ffn_weight_file=bert_ffn_weight_file
        )
        self.faiss_topk = FaissTopK(embedding_file)

    def predict(self, questions, search_by='answer', topk=5, answer_only=True):
        """Embed *questions* and return the top-k retrieved results."""
        question_vectors = self.qa_embed.predict(questions=questions)
        return self.faiss_topk.predict(
            question_vectors, search_by, topk, answer_only)

    def getEmbedding(self, questions, search_by='answer', topk=5, answer_only=True):
        """Return only the question embeddings.

        ``search_by``/``topk``/``answer_only`` are accepted for signature
        parity with ``predict`` but do not affect the embedding.
        """
        return self.qa_embed.predict(questions=questions)
+
+
class GenerateQADoc(object):
    """Generative pipeline: retrieve similar QA pairs with Faiss, then let
    a GPT-2 estimator write an answer conditioned on them."""

    def __init__(self,
                 pretrained_path='models/pubmed_pmc_470k/',
                 ffn_weight_file=None,
                 bert_ffn_weight_file='models/bertffn_crossentropy/bertffn',
                 gpt2_weight_file='models/gpt2',
                 embedding_file='qa_embeddings/bertffn_crossentropy.zip'
                 ):
        """Set up the GPT-2 estimator, the embedding model and the Faiss
        index.

        NOTE: eager execution is disabled process-wide because the
        embedding output is evaluated through an explicit TF1-style
        session (``self.embed_sess``).
        """
        super(GenerateQADoc, self).__init__()
        tf.compat.v1.disable_eager_execution()
        session_config = tf.compat.v1.ConfigProto(
            allow_soft_placement=True)
        session_config.gpu_options.allow_growth = False
        config = tf.estimator.RunConfig(
            session_config=session_config)
        self.batch_size = 1
        self.gpt2_weight_file = gpt2_weight_file
        gpt2_model_fn = gpt2_estimator.get_gpt2_model_fn(
            accumulate_gradients=5,
            learning_rate=0.1,
            length=512,
            batch_size=self.batch_size,
            temperature=0.7,
            top_k=0
        )
        # Hyper-parameters ship alongside the GPT-2 checkpoint.
        hparams = gpt2_estimator.default_hparams()
        with open(os.path.join(gpt2_weight_file, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))
        self.estimator = tf.estimator.Estimator(
            gpt2_model_fn,
            model_dir=gpt2_weight_file,
            params=hparams,
            config=config)
        self.encoder = gpt2_estimator.encoder.get_encoder(gpt2_weight_file)

        # Separate session (with GPU memory growth) for the embedding model.
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        self.embed_sess = tf.compat.v1.Session(config=config)
        with self.embed_sess.as_default():
            self.qa_embed = QAEmbed(
                pretrained_path=pretrained_path,
                ffn_weight_file=ffn_weight_file,
                bert_ffn_weight_file=bert_ffn_weight_file,
                with_answer=False,
                load_pretrain=False
            )

        self.faiss_topk = FaissTopK(embedding_file)

    def _get_gpt2_inputs(self, question, questions, answers):
        """Build the GPT-2 prompt: retrieved QA pairs in reverse order,
        terminated by the user's question with an empty answer slot."""
        assert len(questions) == len(answers)
        line = '`QUESTION: %s `ANSWER: ' % question
        for q, a in zip(questions, answers):
            line = '`QUESTION: %s `ANSWER: %s ' % (q, a) + line
        return line

    def predict(self, questions, search_by='answer', topk=5, answer_only=False):
        """Generate an answer for *questions* (str or one-element list).

        ``answer_only`` is kept for interface compatibility; retrieval
        always fetches the questions too, because the GPT-2 prompt needs
        them.
        """
        # BUGFIX: normalize a bare string; the old code later indexed
        # questions[0] on a str and fed a single character to GPT-2.
        if isinstance(questions, str):
            questions = [questions]
        embedding = self.qa_embed.predict(
            questions=questions, dataset=False).eval(session=self.embed_sess)
        # BUGFIX: always retrieve (questions, answers); with
        # answer_only=True the old code raised NameError on the undefined
        # topk_question used below.
        topk_question, topk_answer = self.faiss_topk.predict(
            embedding, search_by, topk, answer_only=False)

        gpt2_input = self._get_gpt2_inputs(
            questions[0], topk_question, topk_answer)
        gpt2_pred = self.estimator.predict(
            lambda: gpt2_estimator.predict_input_fn(
                inputs=gpt2_input,
                batch_size=self.batch_size,
                checkpoint_path=self.gpt2_weight_file))
        raw_output = gpt2_estimator.predictions_parsing(
            gpt2_pred, self.encoder)
        # Strip the echoed `QUESTION/`ANSWER pairs from the generation and
        # keep only the text produced for the final answer slot.
        refine1 = re.sub('`QUESTION:.*?`ANSWER:', '',
                         str(raw_output[0]), flags=re.DOTALL)
        refine2 = refine1.split('`QUESTION: ')[0]
        return refine2
+
+
if __name__ == "__main__":
    # Smoke test: build the full generative pipeline (requires the local
    # model checkpoints and embedding file at their default paths) and
    # print a generated answer for a sample complaint.
    gen = GenerateQADoc()
    print(gen.predict('my eyes hurt'))