|
a |
|
b/main.py |
|
|
1 |
import sys |
|
|
2 |
import json |
|
|
3 |
import requests |
|
|
4 |
import openai |
|
|
5 |
from utils import get_values |
|
|
6 |
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QTextEdit, QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QMessageBox, QListWidget |
|
|
7 |
from PyQt5.QtGui import QPixmap, QFont |
|
|
8 |
from PyQt5.QtCore import Qt |
|
|
9 |
|
|
|
10 |
# OpenAI Key |
|
|
11 |
openai.api_key = "OPENAI_API_KEY" |
|
|
12 |
|
|
|
13 |
with open("schema.txt", "r") as schema_file: |
|
|
14 |
schema_prompt = schema_file.read() |
|
|
15 |
|
|
|
16 |
prime_query_prompt = "query top_n_associated_diseases {\n search(queryString:" |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
class GeneticsGPTGUI(QMainWindow): |
|
|
20 |
def __init__(self): |
|
|
21 |
super().__init__() |
|
|
22 |
self.setWindowTitle("GeneticsGPT") |
|
|
23 |
self.setGeometry(100, 100, 800, 600) |
|
|
24 |
|
|
|
25 |
# main widget and layout |
|
|
26 |
central_widget = QWidget(self) |
|
|
27 |
main_layout = QHBoxLayout(central_widget) |
|
|
28 |
main_layout.setContentsMargins(20, 20, 20, 20) |
|
|
29 |
main_layout.setSpacing(20) |
|
|
30 |
self.setCentralWidget(central_widget) |
|
|
31 |
|
|
|
32 |
# left section layout |
|
|
33 |
left_layout = QVBoxLayout() |
|
|
34 |
left_layout.setSpacing(20) |
|
|
35 |
main_layout.addLayout(left_layout) |
|
|
36 |
|
|
|
37 |
# logo label |
|
|
38 |
logo_label = QLabel(self) |
|
|
39 |
logo_pixmap = QPixmap("logo.png") # Replace with your logo image file |
|
|
40 |
logo_pixmap = logo_pixmap.scaledToWidth( |
|
|
41 |
200) # Adjust the logo size as needed |
|
|
42 |
logo_label.setPixmap(logo_pixmap) |
|
|
43 |
logo_label.setAlignment(Qt.AlignCenter) |
|
|
44 |
left_layout.addWidget(logo_label) |
|
|
45 |
|
|
|
46 |
# question input field |
|
|
47 |
self.question_input_field = QLineEdit(self) |
|
|
48 |
self.question_input_field.setPlaceholderText("Ask a question...") |
|
|
49 |
self.question_input_field.setFont(QFont("Arial", 14)) |
|
|
50 |
self.question_input_field.setStyleSheet( |
|
|
51 |
"padding: 10px; border-radius: 5px;") |
|
|
52 |
left_layout.addWidget(self.question_input_field) |
|
|
53 |
|
|
|
54 |
# submit button |
|
|
55 |
self.submit_button = QPushButton("Submit", self) |
|
|
56 |
self.submit_button.clicked.connect(self.handle_submit) |
|
|
57 |
self.submit_button.setFont(QFont("Arial", 14)) |
|
|
58 |
self.submit_button.setStyleSheet( |
|
|
59 |
"padding: 10px; background-color: #4285F4; color: white; border-radius: 5px;") |
|
|
60 |
left_layout.addWidget(self.submit_button) |
|
|
61 |
|
|
|
62 |
# answer widget |
|
|
63 |
self.answer_widget = QTextEdit(self) |
|
|
64 |
self.answer_widget.setReadOnly(True) |
|
|
65 |
self.answer_widget.setFont(QFont("Arial", 12)) |
|
|
66 |
self.answer_widget.setStyleSheet( |
|
|
67 |
"background-color: #f0f0f0; color: black; padding: 10px; border-radius: 5px;") |
|
|
68 |
left_layout.addWidget(QLabel("Answers:")) |
|
|
69 |
left_layout.addWidget(self.answer_widget) |
|
|
70 |
|
|
|
71 |
# FAQ list |
|
|
72 |
self.faq_list = QListWidget(self) |
|
|
73 |
self.faq_list.itemClicked.connect(self.handle_faq_click) |
|
|
74 |
left_layout.addWidget(QLabel("Frequently Asked Questions:")) |
|
|
75 |
left_layout.addWidget(self.faq_list) |
|
|
76 |
|
|
|
77 |
# answer text area |
|
|
78 |
self.answer_text_area = QTextEdit(self) |
|
|
79 |
self.answer_text_area.setReadOnly(True) |
|
|
80 |
self.answer_text_area.setFont(QFont("Arial", 12)) |
|
|
81 |
self.answer_text_area.setStyleSheet( |
|
|
82 |
"background-color: #f0f0f0; color: black; padding: 10px; border-radius: 5px;") |
|
|
83 |
main_layout.addWidget(self.answer_text_area) |
|
|
84 |
|
|
|
85 |
def handle_submit(self): |
|
|
86 |
user_question = self.question_input_field.text() |
|
|
87 |
|
|
|
88 |
# prevent multiple requests |
|
|
89 |
self.submit_button.setEnabled(False) |
|
|
90 |
|
|
|
91 |
query_response = self.generate_query_response(user_question) |
|
|
92 |
|
|
|
93 |
self.answer_widget.setPlainText(query_response) |
|
|
94 |
|
|
|
95 |
# Generate FAQs based on the query response |
|
|
96 |
self.generate_faqs(query_response) |
|
|
97 |
|
|
|
98 |
self.submit_button.setEnabled(True) |
|
|
99 |
|
|
|
100 |
def handle_faq_click(self, item): |
|
|
101 |
faq_question = item.text() |
|
|
102 |
|
|
|
103 |
# Generate a response for the clicked FAQ |
|
|
104 |
self.generate_faq_response(faq_question) |
|
|
105 |
|
|
|
106 |
def generate_query_response(self, user_question): |
|
|
107 |
openai_response = openai.ChatCompletion.create( |
|
|
108 |
model="gpt-3.5-turbo", |
|
|
109 |
messages=[ |
|
|
110 |
{"role": "system", "content": schema_prompt}, |
|
|
111 |
{"role": "user", "content": user_question}, |
|
|
112 |
{"role": "system", "content": prime_query_prompt}, |
|
|
113 |
], |
|
|
114 |
temperature=0, |
|
|
115 |
max_tokens=250, |
|
|
116 |
top_p=1, |
|
|
117 |
frequency_penalty=0, |
|
|
118 |
presence_penalty=0, |
|
|
119 |
stop=["###"], |
|
|
120 |
) |
|
|
121 |
generated_query = openai_response["choices"][0].message["content"] |
|
|
122 |
|
|
|
123 |
graphql_query = prime_query_prompt + generated_query |
|
|
124 |
|
|
|
125 |
# Set base URL of GraphQL API endpoint |
|
|
126 |
api_url = "https://api.platform.opentargets.org/api/v4/graphql" |
|
|
127 |
|
|
|
128 |
# Perform POST request and check status code of response |
|
|
129 |
try: |
|
|
130 |
api_response = requests.post( |
|
|
131 |
api_url, json={"query": graphql_query}) |
|
|
132 |
api_response.raise_for_status() |
|
|
133 |
except requests.exceptions.HTTPError as err: |
|
|
134 |
print(err) |
|
|
135 |
QMessageBox.critical( |
|
|
136 |
self, "Error", "An error occurred while fetching data from the API.") |
|
|
137 |
return None |
|
|
138 |
|
|
|
139 |
# Transform API response from JSON into Python dictionary |
|
|
140 |
api_data = json.loads(api_response.text) |
|
|
141 |
|
|
|
142 |
try: |
|
|
143 |
search_hits = api_data["data"]["search"]["hits"][0] |
|
|
144 |
except (KeyError, IndexError): |
|
|
145 |
QMessageBox.warning( |
|
|
146 |
self, "Warning", "No results found for the given query.") |
|
|
147 |
return None |
|
|
148 |
|
|
|
149 |
diseases = get_values(search_hits, "disease") |
|
|
150 |
answer_text = "\n".join( |
|
|
151 |
f"{i+1}. {disease['name']}" for i, disease in enumerate(diseases)) |
|
|
152 |
|
|
|
153 |
return answer_text |
|
|
154 |
|
|
|
155 |
def generate_faqs(self, query_response): |
|
|
156 |
if query_response: |
|
|
157 |
faq_prompt = f"Based on the following information:\n\n{query_response}\n\nGenerate 3-5 relevant frequently asked questions (FAQs) related to the diseases and genes mentioned. Provide each FAQ as a question only, without the 'Q:' prefix or any additional context." |
|
|
158 |
|
|
|
159 |
openai_response = openai.ChatCompletion.create( |
|
|
160 |
model="gpt-3.5-turbo", |
|
|
161 |
messages=[ |
|
|
162 |
{"role": "system", "content": schema_prompt}, |
|
|
163 |
{"role": "user", "content": faq_prompt}, |
|
|
164 |
], |
|
|
165 |
temperature=0.7, |
|
|
166 |
max_tokens=200, |
|
|
167 |
top_p=1, |
|
|
168 |
frequency_penalty=0, |
|
|
169 |
presence_penalty=0, |
|
|
170 |
) |
|
|
171 |
|
|
|
172 |
generated_faqs = openai_response["choices"][0].message["content"] |
|
|
173 |
|
|
|
174 |
self.faq_list.clear() |
|
|
175 |
|
|
|
176 |
# Add generated FAQs to the list |
|
|
177 |
for faq in generated_faqs.split("\n"): |
|
|
178 |
self.faq_list.addItem(faq) |
|
|
179 |
|
|
|
180 |
def generate_faq_response(self, faq_question): |
|
|
181 |
faq_prompt = f"Q: {faq_question}\nA: Provide a concise and informative answer to the question, focusing on the key points and avoiding unnecessary details or technical jargon." |
|
|
182 |
|
|
|
183 |
openai_response = openai.ChatCompletion.create( |
|
|
184 |
model="gpt-3.5-turbo", |
|
|
185 |
messages=[ |
|
|
186 |
{"role": "system", "content": schema_prompt}, |
|
|
187 |
{"role": "user", "content": faq_prompt}, |
|
|
188 |
], |
|
|
189 |
temperature=0.7, |
|
|
190 |
max_tokens=200, |
|
|
191 |
top_p=1, |
|
|
192 |
frequency_penalty=0, |
|
|
193 |
presence_penalty=0, |
|
|
194 |
) |
|
|
195 |
|
|
|
196 |
faq_response = openai_response["choices"][0].message["content"] |
|
|
197 |
self.answer_text_area.setPlainText(faq_response) |
|
|
198 |
|
|
|
199 |
|
|
|
200 |
if __name__ == "__main__": |
|
|
201 |
app = QApplication(sys.argv) |
|
|
202 |
gui = GeneticsGPTGUI() |
|
|
203 |
gui.show() |
|
|
204 |
sys.exit(app.exec_()) |