# main.py
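"""Interactive command-line demo for MediChatBot.

Loads the GPT-2 answer model, BioBERT tokenizer, and question-extractor model
configured in modules.chatbot.config, builds an Inferencer over the QA dataset,
and runs a simple read-print chat loop in the terminal.
"""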
import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel, AutoTokenizer, TFAutoModel
from modules.chatbot.inferencer import Inferencer
from modules.chatbot.dataloader import convert, get_bert_index, get_dataset
from modules.chatbot.config import Config as CONF
from colorama import Fore, Back, Style
import warnings
import logging

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
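
# CONF.chat_params is expected to provide at least the keys used below; the
# descriptions are inferred from how main() uses them, not from the repository docs:
#   "gpt_tok"        - name/path of the GPT-2 tokenizer
#   "tf_gpt_model"   - name/path of the TF GPT-2 LM-head model
#   "bert_tok"       - name/path of the BioBERT tokenizer
#   "tf_q_extractor" - path of the saved Keras question-extractor model
#   "data"           - path of the QA dataset loaded by get_dataset
#   "max_answer_len" - maximum answer length passed to the Inferencer
#   "isEval"         - evaluation flag passed to Inferencer.run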
def main():
    # Load the tokenizers and models specified in the config.
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(CONF.chat_params["gpt_tok"])
    medi_qa_chatGPT2 = TFGPT2LMHeadModel.from_pretrained(
        CONF.chat_params["tf_gpt_model"]
    )
    biobert_tokenizer = AutoTokenizer.from_pretrained(CONF.chat_params["bert_tok"])
    try:
        question_extractor_model_v1 = tf.keras.models.load_model(
            CONF.chat_params["tf_q_extractor"]
        )
    except Exception as e:
        # The chatbot cannot run without the question extractor, so report and exit.
        print(e)
        return

    df_qa = get_dataset(CONF.chat_params["data"])
    max_answer_len = CONF.chat_params["max_answer_len"]
    isEval = CONF.chat_params["isEval"]

    # Build the answer retrieval index from the "A_FFNN_embeds" embedding column.
    answer_index = get_bert_index(df_qa, "A_FFNN_embeds")

    # Build the chatbot inference object.
    chatbot = Inferencer(
        medi_qa_chatGPT2,
        biobert_tokenizer,
        gpt2_tokenizer,
        question_extractor_model_v1,
        df_qa,
        answer_index,
        max_answer_len,
    )
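
    # Sketch of the intended flow (inferred from the arguments above, not from
    # Inferencer's source): the question-extractor model embeds the user question,
    # answer_index retrieves the closest QA entries from df_qa, and the GPT-2 model
    # generates the final reply. See modules/chatbot/inferencer.py for the actual logic.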

    # Start chatbot: print the welcome banner and usage notes.
    print("========================================")
    print(Back.BLUE + "          Welcome to MediChatBot        " + Back.RESET)
    print("========================================")
    print("Enter quit, q, or stop to end the chat.")
    print(
        "MediChatBot v1 is not an official service and assumes no responsibility for how it is used."
    )
    print(
        "Please enter your message below.\nThis chatbot is not sufficiently trained and its dataset is not fully cleaned, so it is intended as a demo only."
    )

    # Chat loop: read user input, stop on a quit command, otherwise print the reply.
    while True:
        user_input = input(Fore.BLUE + "You: " + Fore.RESET)
        if user_input.lower() in ["quit", "q", "stop"]:
            print("========================================")
            print(
                Fore.RED
                + "              Chat Ended.          "
                + Fore.RESET
                + "\n\nThank you for using DSDanielPark's chatbot. Please visit our GitHub and Hugging Face for more information. \n\n - github: https://github.com/DSDanielPark/GPT-BERT-Medical-QA-Chatbot \n - hugging-face: https://huggingface.co/datasets/danielpark/MQuAD-v1 "
            )
            print("========================================")
            break

        response = chatbot.run(user_input, isEval)
        print(
            Fore.BLUE
            + Style.BRIGHT
            + "MediChatBot: "
            + response
            + Fore.RESET
            + Style.RESET_ALL
        )

if __name__ == "__main__":
    main()
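
# To run the demo locally (assuming the repository's dependencies such as tensorflow,
# transformers, and colorama are installed, and the paths in modules/chatbot/config.py
# point at the downloaded models and dataset):
#
#     python main.py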