# modules/chatbot/inferencer.py
import faiss
import numpy as np
import pandas as pd
import tensorflow as tf
from typing import List
from nltk.translate.bleu_score import sentence_bleu
from modules.chatbot.preprocessor import preprocess


class Inferencer:
    def __init__(
        self,
        medical_qa_gpt_model: tf.keras.Model,
        bert_tokenizer: tf.keras.preprocessing.text.Tokenizer,
        gpt_tokenizer: tf.keras.preprocessing.text.Tokenizer,
        question_extractor_model: tf.keras.Model,
        df_qa: pd.DataFrame,
        answer_index: faiss.IndexFlatIP,
        answer_len: int,
    ) -> None:
        """
        Initialize the Inferencer with its models, tokenizers, and retrieval components.

        Args:
            medical_qa_gpt_model (tf.keras.Model): Medical Q&A GPT model.
            bert_tokenizer (tf.keras.preprocessing.text.Tokenizer): BERT (BioBERT) tokenizer.
            gpt_tokenizer (tf.keras.preprocessing.text.Tokenizer): GPT tokenizer.
            question_extractor_model (tf.keras.Model): Question-embedding extractor model.
            df_qa (pd.DataFrame): DataFrame containing the reference Q&A pairs.
            answer_index (faiss.IndexFlatIP): FAISS index over the answer embeddings.
            answer_len (int): Number of tokens reserved for the generated answer.
        """
        self.biobert_tokenizer = bert_tokenizer
        self.question_extractor_model = question_extractor_model
        self.answer_index = answer_index
        self.gpt_tokenizer = gpt_tokenizer
        self.medical_qa_gpt_model = medical_qa_gpt_model
        self.df_qa = df_qa
        self.answer_len = answer_len

    def get_gpt_inference_data(
        self, question: str, question_embedding: np.ndarray
    ) -> List[int]:
        """
        Build the GPT prompt by prepending retrieved Q&A pairs to the input question.

        Args:
            question (str): Input question.
            question_embedding (np.ndarray): Embedding of the question.

        Returns:
            List[int]: Token IDs of the assembled prompt, truncated to the last 1024 tokens.
        """
        # Retrieve the top-k most similar entries from the FAISS answer index.
        topk = 20
        scores, indices = self.answer_index.search(
            question_embedding.astype("float32"), topk
        )
        q_sub = self.df_qa.iloc[indices.reshape(-1)]
        # Build the prompt backwards: the input question comes last, and retrieved
        # Q&A pairs are prepended until the encoded prompt reaches 1024 tokens.
        line = "`QUESTION: %s `ANSWER: " % (question)
        encoded_len = len(self.gpt_tokenizer.encode(line))
        for _, row in q_sub.iterrows():
            line = "`QUESTION: %s `ANSWER: %s " % (row["question"], row["answer"]) + line
            line = line.replace("\n", "")
            encoded_len = len(self.gpt_tokenizer.encode(line))
            if encoded_len >= 1024:
                break
        return self.gpt_tokenizer.encode(line)[-1024:]
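
    # Note: the prompt assembled above has the form
    #   "`QUESTION: <retrieved q> `ANSWER: <retrieved a> ... `QUESTION: <input question> `ANSWER: "
    # so the GPT model is expected to continue the text by completing the final answer.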

    def get_gpt_answer(self, question: str, answer_len: int) -> str:
        """
        Generate an answer to the given question with the GPT model.

        Args:
            question (str): Input question.
            answer_len (int): Number of tokens reserved for the generated answer.

        Returns:
            str: GPT-generated answer.
        """
        # Preprocess the question and truncate it to at most 500 words.
        preprocessed_question = preprocess(question)
        truncated_question = (
            " ".join(preprocessed_question.split(" ")[:500])
            if len(preprocessed_question.split(" ")) > 500
            else preprocessed_question
        )
        # Encode with the BioBERT tokenizer and pad to the 512-token input size.
        encoded_question = self.biobert_tokenizer.encode(truncated_question)
        padded_question = tf.keras.preprocessing.sequence.pad_sequences(
            [encoded_question], maxlen=512, padding="post"
        )
        question_mask = np.where(padded_question != 0, 1, 0)
        # Embed the question and assemble the GPT prompt from retrieved Q&A pairs.
        embeddings = self.question_extractor_model(
            {"question": padded_question, "question_mask": question_mask}
        )
        gpt_input = self.get_gpt_inference_data(truncated_question, embeddings.numpy())
        # Cut the prompt just after its last occurrence of token ID 4600 (presumably the
        # answer-marker token), then trim from the left so that answer_len tokens can
        # still be generated within the 1024-token window.
        mask_start = len(gpt_input) - list(gpt_input[::-1]).index(4600) + 1
        gpt_input_ids = gpt_input[: mask_start + 1]
        if len(gpt_input_ids) > (1024 - answer_len):
            gpt_input_ids = gpt_input_ids[-(1024 - answer_len) :]
        # Generate with the GPT model and decode the full output sequence.
        gpt2_output = self.gpt_tokenizer.decode(
            self.medical_qa_gpt_model.generate(
                input_ids=tf.constant([np.array(gpt_input_ids)]),
                max_length=1024,
                temperature=0.7,
            )[0]
        )
        # Return only the text after the final "`ANSWER: " marker.
        answer_start = gpt2_output.rindex("`ANSWER: ")
        return gpt2_output[answer_start + len("`ANSWER: ") :]

    def inf_func(self, question: str) -> str:
        """
        Run inference for the given question.

        Args:
            question (str): Input question.

        Returns:
            str: Generated answer.
        """
        answer_len = self.answer_len
        return self.get_gpt_answer(question, answer_len)

    def eval_func(self, question: str, answer: str) -> float:
        """
        Generate an answer for the question and score it against the given reference answer.

        Args:
            question (str): Input question.
            answer (str): Reference answer to compare against.

        Returns:
            float: BLEU score of the generated answer against the reference.
        """
        answer_len = 20
        generated_answer = self.get_gpt_answer(question, answer_len)
        reference = [answer.split(" ")]
        candidate = generated_answer.split(" ")
        score = sentence_bleu(reference, candidate)
        return score

    def run(self, question: str, isEval: bool) -> str:
        """
        Run inference for the given question and optionally report a BLEU score.

        Args:
            question (str): Input question.
            isEval (bool): Whether to also evaluate the generated answer.

        Returns:
            str: Generated answer.
        """
        answer = self.inf_func(question)
        if isEval:
            bleu_score = self.eval_func(question, answer)
            print(f"The sentence_bleu score is {bleu_score}")
        return answer
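

# Example usage (illustrative sketch only): the variable names below are hypothetical
# placeholders; in the real pipeline the models, tokenizers, DataFrame, and FAISS index
# are loaded elsewhere and passed in ready to use.
#
#     inferencer = Inferencer(
#         medical_qa_gpt_model=gpt_model,               # fine-tuned TF GPT model
#         bert_tokenizer=biobert_tokenizer,             # BioBERT tokenizer
#         gpt_tokenizer=gpt_tokenizer,                  # GPT tokenizer
#         question_extractor_model=question_extractor,  # question-embedding model
#         df_qa=df_qa,                                  # DataFrame of reference Q&A pairs
#         answer_index=answer_index,                    # faiss.IndexFlatIP over answer embeddings
#         answer_len=20,
#     )
#     answer = inferencer.run("What are the symptoms of diabetes?", isEval=False)
#     print(answer)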