Diff of /chatbot.py [000000] .. [e1b945]

--- /dev/null
+++ b/chatbot.py
@@ -0,0 +1,118 @@
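+# Streamlit front end for the GPT-BERT medical QA chatbot.
+# Launch with: streamlit run chatbot.py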
+import streamlit as st
+import tensorflow as tf
+from streamlit_chat import message
+from transformers import GPT2Tokenizer, TFGPT2LMHeadModel, AutoTokenizer
+from modules.chatbot.inferencer import Inferencer
+from modules.chatbot.dataloader import get_bert_index, get_dataset
+from modules.chatbot.config import Config as CONF
+from utilfunction import find_path
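+# The modules.chatbot package supplies the pieces of the inference pipeline;
+# find_path appears to be a helper that searches the disk for files/folders.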
+
+# Streamlit App
+st.header("GPT-BERT-Medical-QA-Chatbot")
+
+# Load necessary models and data
+gpt2_tokenizer = GPT2Tokenizer.from_pretrained(CONF.chat_params["gpt_tok"])
+medi_qa_chatGPT2 = TFGPT2LMHeadModel.from_pretrained(CONF.chat_params["tf_gpt_model"])
+biobert_tokenizer = AutoTokenizer.from_pretrained(CONF.chat_params["bert_tok"])
+df_qa = get_dataset(CONF.chat_params["data"])
+max_answer_len = CONF.chat_params["max_answer_len"]
+isEval = CONF.chat_params["isEval"]
+answer_index = get_bert_index(df_qa, "A_FFNN_embeds")
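+# get_bert_index presumably builds a nearest-neighbour index over the
+# precomputed answer embeddings stored in the "A_FFNN_embeds" column.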
+
+
+# Load question extractor model
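+# (st.cache_resource keeps the loaded model in memory across Streamlit reruns,
+# so the tf.keras load below runs only once per server process.)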
+@st.cache_resource
+def load_tf_model(path):
+    return tf.keras.models.load_model(path)
+
+
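+# Resolve the question-extractor model path: inside Docker it is searched for
+# under the mounted folder; otherwise the configured path is used directly.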
+try:
+    if CONF.chat_params["runDocker"]:
+        tf_q_extractor_path = find_path(
+            CONF.chat_params["container_mounted_folder_path"],
+            "folder",
+            "question_extractor_model",
+        )
+        question_extractor_model_v1 = load_tf_model(tf_q_extractor_path[0])
+    else:
+        question_extractor_model_v1 = load_tf_model(CONF.chat_params["tf_q_extractor"])
+except Exception:
+    # Fall back to searching the working directory if the configured path fails.
+    tf_q_extractor_path = find_path("./", "folder", "question_extractor_model")
+    question_extractor_model_v1 = load_tf_model(tf_q_extractor_path[0])
+
+# Initialize chatbot inferencer
+chatbot = Inferencer(
+    medi_qa_chatGPT2,
+    biobert_tokenizer,
+    gpt2_tokenizer,
+    question_extractor_model_v1,
+    df_qa,
+    answer_index,
+    max_answer_len,
+)
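+# The Inferencer presumably embeds the question with BioBERT, retrieves the
+# closest stored answer via answer_index, and lets GPT-2 generate the reply.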
+
+
+# Function to get model's answer
+def get_model_answer(chatbot, user_input):
+    return chatbot.run(user_input, isEval)
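+# (isEval comes from the config and presumably toggles evaluation-time
+# behaviour inside Inferencer.run.)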
+
+
+# Get the model's answer for the given user text and append it to the history.
+def chatgpt(user_text, history):
+    history = history or []
+    output = get_model_answer(chatbot, user_text)
+    history.append(output)
+    return history
+
+
+# Chat history must live in st.session_state: Streamlit reruns this script on
+# every interaction, so a module-level list would be reset each time.
+if "history_input" not in st.session_state:
+    st.session_state["history_input"] = []
+if "generated" not in st.session_state:
+    st.session_state["generated"] = []
+if "past" not in st.session_state:
+    st.session_state["past"] = []
+
+
+# Function to get user input
+def get_text():
+    input_text = st.text_input("You: ", key="input")
+    return input_text
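+# (key="input" ties the widget's value to st.session_state["input"].)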
+
+
+# Handle the latest user message
+user_input = get_text()
+
+if user_input:
+    history = chatgpt(user_input, st.session_state["history_input"])
+    st.session_state.past.append(user_input)
+    st.session_state.generated.append(history[-1])
+
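+# Render the conversation newest-first: each answer appears directly above the
+# user message that prompted it.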
+if st.session_state["generated"]:
+    for i in range(len(st.session_state["generated"]) - 1, -1, -1):
+        message(st.session_state["generated"][i], key=str(i), avatar_style="thumbs")
+        message(
+            st.session_state["past"][i],
+            is_user=True,
+            key=str(i) + "_user",
+            avatar_style="thumbs",
+        )