Diff of /main.py [000000] .. [2003ef]

Switch to unified view

a b/main.py
1
import sys
2
import json
3
import requests
4
import openai
5
from utils import get_values
6
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QTextEdit, QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QMessageBox, QListWidget
7
from PyQt5.QtGui import QPixmap, QFont
8
from PyQt5.QtCore import Qt
9
10
# OpenAI Key
11
openai.api_key = "OPENAI_API_KEY"
12
13
with open("schema.txt", "r") as schema_file:
14
    schema_prompt = schema_file.read()
15
16
prime_query_prompt = "query top_n_associated_diseases {\n  search(queryString:"
17
18
19
class GeneticsGPTGUI(QMainWindow):
20
    def __init__(self):
21
        super().__init__()
22
        self.setWindowTitle("GeneticsGPT")
23
        self.setGeometry(100, 100, 800, 600)
24
25
        # main widget and layout
26
        central_widget = QWidget(self)
27
        main_layout = QHBoxLayout(central_widget)
28
        main_layout.setContentsMargins(20, 20, 20, 20)
29
        main_layout.setSpacing(20)
30
        self.setCentralWidget(central_widget)
31
32
        # left section layout
33
        left_layout = QVBoxLayout()
34
        left_layout.setSpacing(20)
35
        main_layout.addLayout(left_layout)
36
37
        # logo label
38
        logo_label = QLabel(self)
39
        logo_pixmap = QPixmap("logo.png")  # Replace with your logo image file
40
        logo_pixmap = logo_pixmap.scaledToWidth(
41
            200)  # Adjust the logo size as needed
42
        logo_label.setPixmap(logo_pixmap)
43
        logo_label.setAlignment(Qt.AlignCenter)
44
        left_layout.addWidget(logo_label)
45
46
        # question input field
47
        self.question_input_field = QLineEdit(self)
48
        self.question_input_field.setPlaceholderText("Ask a question...")
49
        self.question_input_field.setFont(QFont("Arial", 14))
50
        self.question_input_field.setStyleSheet(
51
            "padding: 10px; border-radius: 5px;")
52
        left_layout.addWidget(self.question_input_field)
53
54
        # submit button
55
        self.submit_button = QPushButton("Submit", self)
56
        self.submit_button.clicked.connect(self.handle_submit)
57
        self.submit_button.setFont(QFont("Arial", 14))
58
        self.submit_button.setStyleSheet(
59
            "padding: 10px; background-color: #4285F4; color: white; border-radius: 5px;")
60
        left_layout.addWidget(self.submit_button)
61
62
        # answer widget
63
        self.answer_widget = QTextEdit(self)
64
        self.answer_widget.setReadOnly(True)
65
        self.answer_widget.setFont(QFont("Arial", 12))
66
        self.answer_widget.setStyleSheet(
67
            "background-color: #f0f0f0; color: black; padding: 10px; border-radius: 5px;")
68
        left_layout.addWidget(QLabel("Answers:"))
69
        left_layout.addWidget(self.answer_widget)
70
71
        # FAQ list
72
        self.faq_list = QListWidget(self)
73
        self.faq_list.itemClicked.connect(self.handle_faq_click)
74
        left_layout.addWidget(QLabel("Frequently Asked Questions:"))
75
        left_layout.addWidget(self.faq_list)
76
77
        # answer text area
78
        self.answer_text_area = QTextEdit(self)
79
        self.answer_text_area.setReadOnly(True)
80
        self.answer_text_area.setFont(QFont("Arial", 12))
81
        self.answer_text_area.setStyleSheet(
82
            "background-color: #f0f0f0; color: black; padding: 10px; border-radius: 5px;")
83
        main_layout.addWidget(self.answer_text_area)
84
85
    def handle_submit(self):
86
        user_question = self.question_input_field.text()
87
88
        # prevent multiple requests
89
        self.submit_button.setEnabled(False)
90
91
        query_response = self.generate_query_response(user_question)
92
93
        self.answer_widget.setPlainText(query_response)
94
95
        # Generate FAQs based on the query response
96
        self.generate_faqs(query_response)
97
98
        self.submit_button.setEnabled(True)
99
100
    def handle_faq_click(self, item):
101
        faq_question = item.text()
102
103
        # Generate a response for the clicked FAQ
104
        self.generate_faq_response(faq_question)
105
106
    def generate_query_response(self, user_question):
107
        openai_response = openai.ChatCompletion.create(
108
            model="gpt-3.5-turbo",
109
            messages=[
110
                {"role": "system", "content": schema_prompt},
111
                {"role": "user", "content": user_question},
112
                {"role": "system", "content": prime_query_prompt},
113
            ],
114
            temperature=0,
115
            max_tokens=250,
116
            top_p=1,
117
            frequency_penalty=0,
118
            presence_penalty=0,
119
            stop=["###"],
120
        )
121
        generated_query = openai_response["choices"][0].message["content"]
122
123
        graphql_query = prime_query_prompt + generated_query
124
125
        # Set base URL of GraphQL API endpoint
126
        api_url = "https://api.platform.opentargets.org/api/v4/graphql"
127
128
        # Perform POST request and check status code of response
129
        try:
130
            api_response = requests.post(
131
                api_url, json={"query": graphql_query})
132
            api_response.raise_for_status()
133
        except requests.exceptions.HTTPError as err:
134
            print(err)
135
            QMessageBox.critical(
136
                self, "Error", "An error occurred while fetching data from the API.")
137
            return None
138
139
        # Transform API response from JSON into Python dictionary
140
        api_data = json.loads(api_response.text)
141
142
        try:
143
            search_hits = api_data["data"]["search"]["hits"][0]
144
        except (KeyError, IndexError):
145
            QMessageBox.warning(
146
                self, "Warning", "No results found for the given query.")
147
            return None
148
149
        diseases = get_values(search_hits, "disease")
150
        answer_text = "\n".join(
151
            f"{i+1}. {disease['name']}" for i, disease in enumerate(diseases))
152
153
        return answer_text
154
155
    def generate_faqs(self, query_response):
156
        if query_response:
157
            faq_prompt = f"Based on the following information:\n\n{query_response}\n\nGenerate 3-5 relevant frequently asked questions (FAQs) related to the diseases and genes mentioned. Provide each FAQ as a question only, without the 'Q:' prefix or any additional context."
158
159
            openai_response = openai.ChatCompletion.create(
160
                model="gpt-3.5-turbo",
161
                messages=[
162
                    {"role": "system", "content": schema_prompt},
163
                    {"role": "user", "content": faq_prompt},
164
                ],
165
                temperature=0.7,
166
                max_tokens=200,
167
                top_p=1,
168
                frequency_penalty=0,
169
                presence_penalty=0,
170
            )
171
172
            generated_faqs = openai_response["choices"][0].message["content"]
173
174
            self.faq_list.clear()
175
176
            # Add generated FAQs to the list
177
            for faq in generated_faqs.split("\n"):
178
                self.faq_list.addItem(faq)
179
180
    def generate_faq_response(self, faq_question):
181
        faq_prompt = f"Q: {faq_question}\nA: Provide a concise and informative answer to the question, focusing on the key points and avoiding unnecessary details or technical jargon."
182
183
        openai_response = openai.ChatCompletion.create(
184
            model="gpt-3.5-turbo",
185
            messages=[
186
                {"role": "system", "content": schema_prompt},
187
                {"role": "user", "content": faq_prompt},
188
            ],
189
            temperature=0.7,
190
            max_tokens=200,
191
            top_p=1,
192
            frequency_penalty=0,
193
            presence_penalty=0,
194
        )
195
196
        faq_response = openai_response["choices"][0].message["content"]
197
        self.answer_text_area.setPlainText(faq_response)
198
199
200
if __name__ == "__main__":
201
    app = QApplication(sys.argv)
202
    gui = GeneticsGPTGUI()
203
    gui.show()
204
    sys.exit(app.exec_())