modules/chatbot/inferencer.py
import faiss
import numpy as np
import pandas as pd
import tensorflow as tf
from typing import List
from nltk.translate.bleu_score import sentence_bleu

from modules.chatbot.preprocessor import preprocess


class Inferencer:
    def __init__(
        self,
        medical_qa_gpt_model: tf.keras.Model,
        bert_tokenizer: tf.keras.preprocessing.text.Tokenizer,
        gpt_tokenizer: tf.keras.preprocessing.text.Tokenizer,
        question_extractor_model: tf.keras.Model,
        df_qa: pd.DataFrame,
        answer_index: faiss.IndexFlatIP,
        answer_len: int,
    ) -> None:
        """
        Initialize the Inferencer with the components it needs.

        Args:
            medical_qa_gpt_model (tf.keras.Model): Medical Q&A GPT model.
            bert_tokenizer (tf.keras.preprocessing.text.Tokenizer): BERT (BioBERT) tokenizer.
            gpt_tokenizer (tf.keras.preprocessing.text.Tokenizer): GPT tokenizer.
            question_extractor_model (tf.keras.Model): Question extractor (embedding) model.
            df_qa (pd.DataFrame): DataFrame containing Q&A pairs.
            answer_index (faiss.IndexFlatIP): FAISS index used to retrieve similar Q&A pairs.
            answer_len (int): Number of tokens reserved for the generated answer.
        """
        self.biobert_tokenizer = bert_tokenizer
        self.question_extractor_model = question_extractor_model
        self.answer_index = answer_index
        self.gpt_tokenizer = gpt_tokenizer
        self.medical_qa_gpt_model = medical_qa_gpt_model
        self.df_qa = df_qa
        self.answer_len = answer_len

    def get_gpt_inference_data(
        self, question: str, question_embedding: np.ndarray
    ) -> List[int]:
        """
        Build the GPT input sequence for a question.

        The question embedding is used to retrieve similar Q&A pairs from the
        FAISS index; those pairs are prepended to the question as context, up
        to the 1024-token GPT context window.

        Args:
            question (str): Input question.
            question_embedding (np.ndarray): Embedding of the question.

        Returns:
            List[int]: Token ids to feed to the GPT model.
        """
        topk = 20
        _, indices = self.answer_index.search(
            question_embedding.astype("float32"), topk
        )
        q_sub = self.df_qa.iloc[indices.reshape(topk)]
        line = "`QUESTION: %s `ANSWER: " % (question)
        encoded_len = len(self.gpt_tokenizer.encode(line))
        for _, row in q_sub.iterrows():
            line = (
                "`QUESTION: %s `ANSWER: %s " % (row["question"], row["answer"]) + line
            )
            line = line.replace("\n", "")
            encoded_len = len(self.gpt_tokenizer.encode(line))
            if encoded_len >= 1024:
                break
        return self.gpt_tokenizer.encode(line)[-1024:]

    def get_gpt_answer(self, question: str, answer_len: int) -> str:
        """
        Generate an answer for a question with the GPT model.

        Args:
            question (str): Input question.
            answer_len (int): Number of tokens reserved for the generated answer.

        Returns:
            str: GPT-generated answer.
        """
        preprocessed_question = preprocess(question)
        # Truncate very long questions to 500 whitespace-separated tokens.
        truncated_question = (
            " ".join(preprocessed_question.split(" ")[:500])
            if len(preprocessed_question.split(" ")) > 500
            else preprocessed_question
        )
        encoded_question = self.biobert_tokenizer.encode(truncated_question)
        padded_question = tf.keras.preprocessing.sequence.pad_sequences(
            [encoded_question], maxlen=512, padding="post"
        )
        question_mask = np.where(padded_question != 0, 1, 0)
        embeddings = self.question_extractor_model(
            {"question": padded_question, "question_mask": question_mask}
        )
        gpt_input = self.get_gpt_inference_data(truncated_question, embeddings.numpy())
        # 4600 is assumed to be the token id of the `ANSWER marker in this GPT
        # vocabulary; cut the prompt just after its last occurrence so the
        # model continues from the final answer slot.
        mask_start = len(gpt_input) - list(gpt_input[::-1]).index(4600) + 1
        input_ids = gpt_input[: mask_start + 1]
        # Keep room for answer_len generated tokens within the 1024-token window.
        if len(input_ids) > (1024 - answer_len):
            input_ids = input_ids[-(1024 - answer_len) :]
        gpt2_output = self.gpt_tokenizer.decode(
            self.medical_qa_gpt_model.generate(
                input_ids=tf.constant([np.array(input_ids)]),
                max_length=1024,
                temperature=0.7,
            )[0]
        )
        # The answer is whatever follows the last "`ANSWER: " marker.
        answer_start = gpt2_output.rindex("`ANSWER: ")
        return gpt2_output[answer_start + len("`ANSWER: ") :]

    def inf_func(self, question: str) -> str:
        """
        Run inference for the given question.

        Args:
            question (str): Input question.

        Returns:
            str: Generated answer.
        """
        answer_len = self.answer_len
        return self.get_gpt_answer(question, answer_len)

    def eval_func(self, question: str, answer: str) -> float:
        """
        Evaluate a generated answer against a reference answer.

        Args:
            question (str): Input question.
            answer (str): Reference answer to score against.

        Returns:
            float: Sentence-level BLEU score.
        """
        answer_len = 20
        generated_answer = self.get_gpt_answer(question, answer_len)
        reference = [answer.split(" ")]
        candidate = generated_answer.split(" ")
        score = sentence_bleu(reference, candidate)
        return score

    def run(self, question: str, isEval: bool) -> str:
        """
        Run inference for the given question, optionally reporting a BLEU score.

        Args:
            question (str): Input question.
            isEval (bool): Whether to also compute and print a BLEU score.

        Returns:
            str: Generated answer.
        """
        answer = self.inf_func(question)
        if isEval:
            bleu_score = self.eval_func(question, answer)
            print(f"The sentence_bleu score is {bleu_score}")
        return answer
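
For context, a minimal usage sketch of this class follows. The artifact paths and the two tokenizer-loading helpers are hypothetical placeholders for whatever the project actually provides; only the Inferencer API above comes from the module itself.

import faiss
import pandas as pd
import tensorflow as tf

from modules.chatbot.inferencer import Inferencer

# All paths and loading helpers below are hypothetical placeholders.
df_qa = pd.read_pickle("data/qa_pairs.pkl")
answer_index = faiss.read_index("data/answer_index.faiss")
medical_qa_gpt_model = tf.keras.models.load_model("models/medical_qa_gpt")
question_extractor_model = tf.keras.models.load_model("models/question_extractor")
bert_tokenizer = load_bert_tokenizer()  # hypothetical helper returning the BioBERT tokenizer
gpt_tokenizer = load_gpt_tokenizer()    # hypothetical helper returning the GPT tokenizer

inferencer = Inferencer(
    medical_qa_gpt_model=medical_qa_gpt_model,
    bert_tokenizer=bert_tokenizer,
    gpt_tokenizer=gpt_tokenizer,
    question_extractor_model=question_extractor_model,
    df_qa=df_qa,
    answer_index=answer_index,
    answer_len=20,
)

answer = inferencer.run("What are the symptoms of anemia?", isEval=False)
print(answer)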