# chatbot.py
import streamlit as st
import tensorflow as tf
from streamlit_chat import message
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel, AutoTokenizer
from modules.chatbot.inferencer import Inferencer
from modules.chatbot.dataloader import get_bert_index, get_dataset
from modules.chatbot.config import Config as CONF
from utilfunction import find_path
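
# Launch with `streamlit run chatbot.py`. Streamlit re-executes this script
# from top to bottom on every user interaction, which is why persistent
# conversation state is kept in st.session_state further below.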

# Streamlit app
st.header("GPT-BERT-Medical-QA-Chatbot")

# Load the tokenizers, the GPT-2 answer generator, the QA dataset, and the
# BERT index built from the dataset's "A_FFNN_embeds" column
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(CONF.chat_params["gpt_tok"])
medi_qa_chatGPT2 = TFGPT2LMHeadModel.from_pretrained(CONF.chat_params["tf_gpt_model"])
biobert_tokenizer = AutoTokenizer.from_pretrained(CONF.chat_params["bert_tok"])
df_qa = get_dataset(CONF.chat_params["data"])
max_answer_len = CONF.chat_params["max_answer_len"]
isEval = CONF.chat_params["isEval"]
answer_index = get_bert_index(df_qa, "A_FFNN_embeds")


# Load the question extractor model; @st.cache_resource keeps one instance
# per session instead of reloading the model on every rerun
@st.cache_resource
def load_tf_model(path):
    return tf.keras.models.load_model(path)


try:
    if CONF.chat_params["runDocker"]:
        # Running inside Docker: locate the model under the mounted folder
        tf_q_extractor_path = find_path(
            CONF.chat_params["container_mounted_folder_path"],
            "folder",
            "question_extractor_model",
        )
        question_extractor_model_v1 = load_tf_model(tf_q_extractor_path[0])
    else:
        question_extractor_model_v1 = load_tf_model(CONF.chat_params["tf_q_extractor"])
except Exception:
    # Fallback: search the working directory for the model folder
    tf_q_extractor_path = find_path("./", "folder", "question_extractor_model")
    question_extractor_model_v1 = load_tf_model(tf_q_extractor_path[0])

# Initialize the chatbot inferencer
chatbot = Inferencer(
    medi_qa_chatGPT2,
    biobert_tokenizer,
    gpt2_tokenizer,
    question_extractor_model_v1,
    df_qa,
    answer_index,
    max_answer_len,
)
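
# Presumed pipeline, inferred from the arguments above rather than from the
# Inferencer internals: BioBERT embeds the user question, answer_index is
# used to retrieve the closest stored QA pair, and GPT-2 generates the final
# answer of at most max_answer_len tokens.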


# Return the model's answer for a single user query
def get_model_answer(chatbot, user_input):
    return chatbot.run(user_input, isEval)


# Run one chat turn: generate an answer and append it to the history
def chatgpt(user_message, history):
    history = history or []
    output = get_model_answer(chatbot, user_message)
    history.append(output)
    return history
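
# Example: chatgpt("What causes a fever?", []) returns a list whose last
# element is the newly generated answer.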


# Maintain conversation history
history_input = []
if "generated" not in st.session_state:
    st.session_state["generated"] = []
if "past" not in st.session_state:
    st.session_state["past"] = []
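
# Only st.session_state survives Streamlit reruns; module-level variables
# such as history_input are re-created on every interaction.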


# Read the user's message from the text box
def get_text():
    input_text = st.text_input("You: ", key="input")
    return input_text


# Main interaction loop
user_input = get_text()

if user_input:
    # chatgpt() appends the new answer to history_input in place
    output = chatgpt(user_input, history_input)
    st.session_state.past.append(user_input)
    st.session_state.generated.append(output[-1])

# Render the conversation, newest exchange first
if st.session_state["generated"]:
    for i in range(len(st.session_state["generated"]) - 1, -1, -1):
        message(st.session_state["generated"][i], key=str(i), avatar_style="thumbs")
        message(
            st.session_state["past"][i],
            is_user=True,
            key=str(i) + "_user",
            avatar_style="thumbs",
        )
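
# Note: Streamlit requires unique widget keys; str(i) and str(i) + "_user"
# keep the bot and user messages distinct across the rendered history.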