--- /dev/null
+++ b/modules/chatbot/inferencer.py
@@ -0,0 +1,183 @@
+import faiss
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from typing import List
+from nltk.translate.bleu_score import sentence_bleu
+from transformers import PreTrainedTokenizer
+from modules.chatbot.preprocessor import preprocess
+
+
+class Inferencer:
+    def __init__(
+        self,
+        medical_qa_gpt_model: tf.keras.Model,
+        bert_tokenizer: PreTrainedTokenizer,
+        gpt_tokenizer: PreTrainedTokenizer,
+        question_extractor_model: tf.keras.Model,
+        df_qa: pd.DataFrame,
+        answer_index: faiss.IndexFlatIP,
+        answer_len: int,
+    ) -> None:
+        """
+        Initialize Inferencer with necessary components.
+
+        Args:
+            medical_qa_gpt_model (tf.keras.Model): Medical Q&A GPT model.
+            bert_tokenizer (PreTrainedTokenizer): BioBERT tokenizer.
+            gpt_tokenizer (PreTrainedTokenizer): GPT tokenizer.
+            question_extractor_model (tf.keras.Model): Question extractor model.
+            df_qa (pd.DataFrame): DataFrame containing Q&A pairs.
+            answer_index (faiss.IndexFlatIP): FAISS index over answer embeddings.
+            answer_len (int): Token budget reserved for the generated answer.
+        """
+        self.biobert_tokenizer = bert_tokenizer
+        self.question_extractor_model = question_extractor_model
+        self.answer_index = answer_index
+        self.gpt_tokenizer = gpt_tokenizer
+        self.medical_qa_gpt_model = medical_qa_gpt_model
+        self.df_qa = df_qa
+        self.answer_len = answer_len
+
+    def get_gpt_inference_data(
+        self, question: str, question_embedding: np.ndarray
+    ) -> List[int]:
+        """
+        Get GPT inference data.
+
+        Args:
+            question (str): Input question.
+            question_embedding (np.ndarray): Embedding of the question.
+
+        Returns:
+            List[int]: GPT inference data.
+        """
+        topk = 20
+        scores, indices = self.answer_index.search(
+            question_embedding.astype("float32"), topk
+        )
+        q_sub = self.df_qa.iloc[indices.reshape(topk)]
+        line = "`QUESTION: %s `ANSWER: " % (question)
+        encoded_len = len(self.gpt_tokenizer.encode(line))
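+        # Prepend retrieved Q&A pairs as few-shot context, stopping once the
+        # prompt would exceed the 1024-token GPT context window.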
+        for _, row in q_sub.iterrows():
+            line = (
+                "`QUESTION: %s `ANSWER: %s " % (row["question"], row["answer"]) + line
+            )
+            line = line.replace("\n", "")
+            encoded_len = len(self.gpt_tokenizer.encode(line))
+            if encoded_len >= 1024:
+                break
+        return self.gpt_tokenizer.encode(line)[-1024:]
+
+    def get_gpt_answer(self, question: str, answer_len: int) -> str:
+        """
+        Get GPT answer.
+
+        Args:
+            question (str): Input question.
+            answer_len (int): Length of the answer.
+
+        Returns:
+            str: GPT generated answer.
+        """
+        preprocessed_question = preprocess(question)
+        truncated_question = (
+            " ".join(preprocessed_question.split(" ")[:500])
+            if len(preprocessed_question.split(" ")) > 500
+            else preprocessed_question
+        )
+        encoded_question = self.biobert_tokenizer.encode(truncated_question)
+        padded_question = tf.keras.preprocessing.sequence.pad_sequences(
+            [encoded_question], maxlen=512, padding="post"
+        )
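+        # Attention mask: 1 for real tokens, 0 for the post-padding zeros.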
+        question_mask = np.where(padded_question != 0, 1, 0)
+        embeddings = self.question_extractor_model(
+            {"question": padded_question, "question_mask": question_mask}
+        )
+        gpt_input = self.get_gpt_inference_data(truncated_question, embeddings.numpy())
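+        # NOTE: 4600 is assumed to be the tokenizer id of the "`ANSWER" marker;
+        # the slice keeps the prompt up to the final marker so the model completes it.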
+        mask_start = len(gpt_input) - list(gpt_input[::-1]).index(4600) + 1
+        input_ids = gpt_input[: mask_start + 1]
+        if len(input_ids) > (1024 - answer_len):
+            input_ids = input_ids[-(1024 - answer_len) :]
+        gpt2_output = self.gpt_tokenizer.decode(
+            self.medical_qa_gpt_model.generate(
+                input_ids=tf.constant([np.array(input_ids)]),
+                max_length=1024,
+                temperature=0.7,
+            )[0]
+        )
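+        # The decoded output echoes the few-shot prompt; keep only the text
+        # after the last "`ANSWER: " marker.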
+        answer_start = gpt2_output.rindex("`ANSWER: ")
+        return gpt2_output[answer_start + len("`ANSWER: ") :]
+
+    def inf_func(self, question: str) -> str:
+        """
+        Run inference for the given question.
+
+        Args:
+            question (str): Input question.
+
+        Returns:
+            str: Generated answer.
+        """
+        answer_len = self.answer_len
+        return self.get_gpt_answer(question, answer_len)
+
+    def eval_func(self, question: str, answer: str) -> float:
+        """
+        Evaluate generated answer against ground truth.
+
+        Args:
+            question (str): Input question.
+            answer (str): Generated answer.
+
+        Returns:
+            float: BLEU score.
+        """
+        answer_len = 20
+        generated_answer = self.get_gpt_answer(question, answer_len)
+        reference = [answer.split(" ")]
+        candidate = generated_answer.split(" ")
+        score = sentence_bleu(reference, candidate)
+        return score
+
+    def run(self, question: str, isEval: bool) -> str:
+        """
+        Run inference for the given question.
+
+        Args:
+            question (str): Input question.
+            isEval (bool): Whether to evaluate or not.
+
+        Returns:
+            str: Generated answer.
+        """
+        answer = self.inf_func(question)
+        if isEval:
+            bleu_score = self.eval_func(question, answer)
+            print(f"The sentence_bleu score is {bleu_score}")
+        return answer
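+
+
+# Minimal usage sketch, not part of this module: every loader variable below
+# is an assumption; construct the models, tokenizers, DataFrame, and FAISS
+# index however the project actually wires them up.
+#
+#     inferencer = Inferencer(
+#         medical_qa_gpt_model=gpt_model,        # TF GPT model exposing .generate()
+#         bert_tokenizer=biobert_tokenizer,      # tokenizer exposing .encode()
+#         gpt_tokenizer=gpt2_tokenizer,          # tokenizer exposing .encode()/.decode()
+#         question_extractor_model=extractor,    # maps padded ids + mask to an embedding
+#         df_qa=df_qa,                           # DataFrame with "question"/"answer" columns
+#         answer_index=faiss_index,              # faiss.IndexFlatIP over answer embeddings
+#         answer_len=20,
+#     )
+#     print(inferencer.run("What causes hypertension?", isEval=False))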