|
a |
|
b/app/frontend/utils/streamlit_utils.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
|
|
|
3 |
""" |
|
|
4 |
Utils for Streamlit. |
|
|
5 |
""" |
|
|
6 |
|
|
|
7 |
import os |
|
|
8 |
import datetime |
|
|
9 |
import hydra |
|
|
10 |
import tempfile |
|
|
11 |
import streamlit as st |
|
|
12 |
import streamlit.components.v1 as components |
|
|
13 |
import pandas as pd |
|
|
14 |
import plotly.express as px |
|
|
15 |
from langsmith import Client |
|
|
16 |
from langchain_ollama import ChatOllama |
|
|
17 |
from langchain_openai import ChatOpenAI |
|
|
18 |
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings |
|
|
19 |
from langchain_openai.embeddings import OpenAIEmbeddings |
|
|
20 |
from langchain_core.language_models import BaseChatModel |
|
|
21 |
from langchain_core.embeddings import Embeddings |
|
|
22 |
from langchain_core.messages import AIMessageChunk, HumanMessage, ChatMessage, AIMessage |
|
|
23 |
from langchain_core.tracers.context import collect_runs |
|
|
24 |
from langchain.callbacks.tracers import LangChainTracer |
|
|
25 |
import networkx as nx |
|
|
26 |
import gravis |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
def submit_feedback(user_response):
    """
    Send the user's feedback on the last run to LangSmith.

    Args:
        user_response: dict: The user response containing 'score' and 'text'
    """
    # Map the thumbs-up/down emoji to a binary score
    feedback_score = 1 if user_response["score"] == "👍" else 0
    langsmith_client = Client()
    langsmith_client.create_feedback(
        st.session_state.run_id,
        key="feedback",
        score=feedback_score,
        comment=user_response["text"],
    )
    st.info("Your feedback is on its way to the developers. Thank you!", icon="🚀")
|
|
44 |
|
|
|
45 |
|
|
|
46 |
def render_table_plotly(
    uniq_msg_id, content, df_selected, x_axis_label="Time", y_axis_label="Concentration"
):
    """
    Render a plotly chart plus its backing table in the chat, each behind
    its own show/hide toggle.

    Args:
        uniq_msg_id: str: The unique message id
        content: str: The title text for the chart
        df_selected: pd.DataFrame: The selected dataframe
        x_axis_label: str: Label (units) for the x axis
        y_axis_label: str: Label (units) for the y axis
    """
    # Toggle controlling the chart visibility (shown by default, persisted)
    render_toggle(
        key=f"toggle_plotly_{uniq_msg_id}",
        toggle_text="Show Plot",
        toggle_state=True,
        save_toggle=True,
    )
    # The chart itself (persisted in chat history)
    render_plotly(
        df_selected,
        key=f"plotly_{uniq_msg_id}",
        title=content,
        y_axis_label=y_axis_label,
        x_axis_label=x_axis_label,
        save_chart=True,
    )
    # Toggle controlling the table visibility (hidden by default, persisted)
    render_toggle(
        key=f"toggle_table_{uniq_msg_id}",
        toggle_text="Show Table",
        toggle_state=False,
        save_toggle=True,
    )
    # The backing table (persisted in chat history)
    render_table(df_selected, key=f"dataframe_{uniq_msg_id}", save_table=True)
    st.empty()
|
|
83 |
|
|
|
84 |
|
|
|
85 |
def render_toggle(
    key: str, toggle_text: str, toggle_state: bool, save_toggle: bool = False
):
    """
    Render a toggle button that shows/hides an associated table or chart.

    Args:
        key: str: The key for the toggle button
        toggle_text: str: The text for the toggle button
        toggle_state: bool: The state of the toggle button
        save_toggle: bool: Flag to save the toggle button to the chat history
    """
    st.toggle(toggle_text, toggle_state, help="""Toggle to show/hide data""", key=key)
    # Guard clause: nothing to persist
    if not save_toggle:
        return
    # Persist the toggle in the chat history so it survives reruns
    st.session_state.messages.append(
        {
            "type": "toggle",
            "content": toggle_text,
            "toggle_state": toggle_state,
            "key": key,
        }
    )
|
|
109 |
|
|
|
110 |
|
|
|
111 |
def render_plotly(
    df: pd.DataFrame,
    key: str,
    title: str,
    y_axis_label: str,
    x_axis_label: str,
    save_chart: bool = False,
):
    """
    Visualize a simulation dataframe as a plotly line chart.

    Args:
        df: pd.DataFrame: The input dataframe (wide format with a 'Time' column)
        key: str: The key for the plotly chart
        title: str: The title of the plotly chart
        y_axis_label: str: Units shown on the y axis
        x_axis_label: str: Units shown on the x axis
        save_chart: bool: Flag to save the chart to the chat history
    """
    # The paired "Show Plot" toggle shares this chart key's suffix
    toggle_state = st.session_state[f'toggle_plotly_{key.split("plotly_")[1]}']
    if toggle_state:
        # Melt wide columns into (Time, Species, Concentration) rows
        long_form = df.melt(
            id_vars="Time", var_name="Species", value_name="Concentration"
        )
        fig = px.line(
            long_form,
            x="Time",
            y="Concentration",
            color="Species",
            title=title,
            height=500,
            width=600,
        )
        # Axis titles carry the units supplied by the caller
        fig.update_yaxes(title_text=f"Quantity ({y_axis_label})")
        fig.update_xaxes(title_text=f"Time ({x_axis_label})")
        st.plotly_chart(fig, use_container_width=True, key=key)
    if save_chart:
        # Persist the chart data in the chat history so it survives reruns
        st.session_state.messages.append(
            {
                "type": "plotly",
                "content": df,
                "key": key,
                "title": title,
                "y_axis_label": y_axis_label,
                "x_axis_label": x_axis_label,
            }
        )
|
|
162 |
|
|
|
163 |
|
|
|
164 |
def render_table(df: pd.DataFrame, key: str, save_table: bool = False):
    """
    Render a dataframe in the chat, honoring its paired visibility toggle.

    Args:
        df: pd.DataFrame: The input dataframe
        key: str: The key for the table
        save_table: bool: Flag to save the table to the chat history
    """
    # The paired "Show Table" toggle shares this table key's suffix
    toggle_state = st.session_state[f'toggle_table_{key.split("dataframe_")[1]}']
    if toggle_state:
        st.dataframe(df, use_container_width=True, key=key)
    if save_table:
        # Persist the table in the chat history so it survives reruns
        st.session_state.messages.append(
            {
                "type": "dataframe",
                "content": df,
                "key": key,
            }
        )
|
|
188 |
|
|
|
189 |
|
|
|
190 |
def sample_questions():
    """
    Return the sample questions shown for the Talk2Biomodels agent.
    """
    return [
        'Search for all biomodels on "Crohns Disease"',
        "Briefly describe biomodel 971 and simulate it for 50 days with an interval of 50.",
        "Bring biomodel 27 to a steady state, and then "
        "determine the Mpp concentration at the steady state.",
        "How will the concentration of Mpp change in model 27, "
        "if the initial value of MAPKK were to be changed between 1 and 100 in steps of 10?",
        "Show annotations of all interleukins in model 537",
    ]
|
|
204 |
|
|
|
205 |
|
|
|
206 |
def sample_questions_t2s():
    """
    Return the sample questions shown for the Talk2Scholars agent.
    """
    return [
        'Search articles on "Role of DNA damage response (DDR) in Cancer"',
        "Save these articles in my Zotero library under the collection 'Curiosity'",
        "Tell me more about the first article in the last search results",
        "Download the article 'Attention is All You Need'",
        "Describe the methods of the downloaded paper",
    ]
|
|
218 |
|
|
|
219 |
|
|
|
220 |
def sample_questions_t2aa4p():
    """
    Return the sample questions shown for the Talk2AIAgents4Pharma agent.
    """
    return [
        'Search for all the biomodels on "Crohns Disease"',
        "Briefly describe biomodel 537 and simulate it for 2016 hours with an interval of 100.",
        "List the drugs that target Interleukin-6",
        "What genes are associated with Crohn's disease?",
    ]
|
|
231 |
|
|
|
232 |
|
|
|
233 |
def stream_response(response):
    """
    Yield streamed AI message chunks from the agent, filtering out
    supervisor chatter once a sub-agent starts responding.

    Args:
        response: Iterable of (message, metadata) chunks from the agent stream
    """
    agent_responding = False
    for chunk in response:
        message = chunk[0]
        metadata = chunk[1]
        # Only AI message chunks are streamed to the user
        if not isinstance(message, AIMessageChunk):
            continue
        if not metadata["checkpoint_ns"].startswith("supervisor"):
            # A (sub-)agent is producing the answer
            agent_responding = True
            if "branch:to:agent" in metadata["langgraph_triggers"]:
                if message.content == "":
                    yield "\n"
                yield message.content
        elif not agent_responding:
            # No agent has responded yet and the message is from the
            # supervisor, so surface the supervisor's message
            if "branch:to:agent" in metadata["langgraph_triggers"]:
                if message.content == "":
                    yield "\n"
                yield message.content
|
|
276 |
|
|
|
277 |
|
|
|
278 |
def update_state_t2b(st):
    """
    Build the state-update dict for the Talk2Biomodels agent.

    Args:
        st: The streamlit module holding the session state

    Returns:
        dict: State updates (SBML file path and text embedding model)
    """
    return {
        "sbml_file_path": [st.session_state.sbml_file_path],
        "text_embedding_model": get_text_embedding_model(
            st.session_state.text_embedding_model
        ),
    }
|
|
286 |
|
|
|
287 |
|
|
|
288 |
def update_state_t2kg(st):
    """
    Build the state-update dict for the Talk2KnowledgeGraphs agent.

    Args:
        st: The streamlit module holding the session state

    Returns:
        dict: State updates (embedding model, uploads, top-k limits,
        and the knowledge-graph source description)
    """
    kg_config = st.session_state.config
    return {
        "embedding_model": get_text_embedding_model(
            st.session_state.text_embedding_model
        ),
        "uploaded_files": st.session_state.uploaded_files,
        "topk_nodes": st.session_state.topk_nodes,
        "topk_edges": st.session_state.topk_edges,
        "dic_source_graph": [
            {
                "name": kg_config["kg_name"],
                "kg_pyg_path": kg_config["kg_pyg_path"],
                "kg_text_path": kg_config["kg_text_path"],
            }
        ],
    }
|
|
305 |
|
|
|
306 |
|
|
|
307 |
def get_ai_messages(current_state):
    """
    Collect the AI messages produced since the user's last message.

    Supervisor messages exist for agents that have sub-agents; they are
    skipped unless the supervisor was the only responder (i.e. the
    second-to-last message is still the human's).

    Args:
        current_state: The agent's graph state snapshot

    Returns:
        str: The AI message contents joined by newlines, oldest first
    """
    messages = current_state.values["messages"]
    # True when only the supervisor answered, i.e. no sub-agent was called
    # NOTE(review): assumes at least two messages exist — confirm upstream
    supervisor_only = isinstance(messages[-2], HumanMessage)
    collected = []
    # Walk backwards until the user's last message
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            break
        # Only non-empty AI messages are considered
        if not (isinstance(msg, AIMessage) and msg.content != ""):
            continue
        # Drop supervisor messages when a sub-agent produced the answer
        if msg.name == "supervisor" and not supervisor_only:
            continue
        collected.append(msg.content)
    # Restore chronological order and join into a single string
    return "\n".join(reversed(collected))
|
|
341 |
|
|
|
342 |
|
|
|
343 |
def get_response(agent, graphs_visuals, app, st, prompt):
    """
    Invoke the selected agent with the user's prompt, display its answer,
    and render any tool artifacts (tables, charts, graphs) produced since
    the user's last message.

    Args:
        agent: str: The agent name ("T2AA4P", "T2B", or "T2KG")
        graphs_visuals: list: Collector for extracted subgraphs to render later
        app: The compiled langgraph application
        st: The streamlit module holding the session state
        prompt: str: The user's prompt
    """
    # Create config for the agent
    config = {"configurable": {"thread_id": st.session_state.unique_id}}
    # Update the agent state with the selected LLM model
    current_state = app.get_state(config)
    app.update_state(
        config, {"llm_model": get_base_chat_model(st.session_state.llm_model)}
    )
    # Push agent-specific state (SBML path, embedding models, KG config, ...)
    if agent == "T2AA4P":
        app.update_state(config, update_state_t2b(st) | update_state_t2kg(st))
    elif agent == "T2B":
        app.update_state(config, update_state_t2b(st))
    elif agent == "T2KG":
        app.update_state(config, update_state_t2kg(st))

    with collect_runs() as cb:
        # Add Langsmith tracer
        tracer = LangChainTracer(project_name=st.session_state.project_name)
        # Get response from the agent; NVIDIA models are invoked in one
        # shot, other providers are streamed token by token
        if current_state.values["llm_model"]._llm_type == "chat-nvidia-ai-playground":
            response = app.invoke(
                {"messages": [HumanMessage(content=prompt)]},
                config=config | {"callbacks": [tracer]},
            )
            # Get the current state of the graph
            current_state = app.get_state(config)
            # Get last response's AI messages
            assistant_content = get_ai_messages(current_state)
            st.write(assistant_content)
        else:
            response = app.stream(
                {"messages": [HumanMessage(content=prompt)]},
                config=config | {"callbacks": [tracer]},
                stream_mode="messages",
            )
            st.write_stream(stream_response(response))
        # Save the run id and use it later to attach feedback
        st.session_state.run_id = cb.traced_runs[-1].id

    # Get the current state of the graph
    current_state = app.get_state(config)
    # Get last response's AI messages
    assistant_content = get_ai_messages(current_state)
    # Add response to chat history
    assistant_msg = ChatMessage(
        assistant_content,
        role="assistant",
    )
    st.session_state.messages.append({"type": "message", "content": assistant_msg})
    st.empty()
    # Get the current state of the graph
    current_state = app.get_state(config)
    # Get the messages from the current state
    # and reverse the order
    reversed_messages = current_state.values["messages"][::-1]
    # Loop through the reversed messages until a
    # HumanMessage is found i.e. the last message
    # from the user. This is to display the results
    # of the tool calls made by the agent since the
    # last message from the user.
    for msg in reversed_messages:
        # Break the loop if the message is a HumanMessage
        # i.e. the last message from the user
        if isinstance(msg, HumanMessage):
            break
        # Skip the message if it is an AIMessage
        # i.e. a message from the agent. An agent
        # may make multiple tool calls before the
        # final response to the user.
        if isinstance(msg, AIMessage):
            continue
        # From here on the message is a ToolMessage; these may carry
        # additional visuals that need to be displayed to the user.
        # Skip the Tool message if it is an error message
        if msg.status == "error":
            continue
        # Create a unique message id to identify the tool call
        # msg.name is the name of the tool
        # msg.tool_call_id is the unique id of the tool call
        # st.session_state.run_id is the unique id of the run
        uniq_msg_id = (
            msg.name + "_" + msg.tool_call_id + "_" + str(st.session_state.run_id)
        )
        print(uniq_msg_id)
        if msg.name in ["simulate_model", "custom_plotter"]:
            if msg.name == "simulate_model":
                print(
                    "-",
                    len(current_state.values["dic_simulated_data"]),
                    "simulate_model",
                )
                # Convert the simulated data to a single dictionary
                dic_simulated_data = {}
                for data in current_state.values["dic_simulated_data"]:
                    for key in data:
                        if key not in dic_simulated_data:
                            dic_simulated_data[key] = []
                        dic_simulated_data[key] += [data[key]]
                # Create a pandas dataframe from the dictionary
                df_simulated_data = pd.DataFrame.from_dict(dic_simulated_data)
                # Get the simulated data for the current tool call
                df_simulated = pd.DataFrame(
                    df_simulated_data[
                        df_simulated_data["tool_call_id"] == msg.tool_call_id
                    ]["data"].iloc[0]
                )
                df_selected = df_simulated
            elif msg.name == "custom_plotter":
                if msg.artifact:
                    df_selected = pd.DataFrame.from_dict(msg.artifact["dic_data"])
                else:
                    continue
            # Display the table and plotly chart
            render_table_plotly(
                uniq_msg_id,
                msg.content,
                df_selected,
                x_axis_label=msg.artifact["x_axis_label"],
                y_axis_label=msg.artifact["y_axis_label"],
            )
        elif msg.name == "steady_state":
            if not msg.artifact:
                continue
            # Create a pandas dataframe from the dictionary
            df_selected = pd.DataFrame.from_dict(msg.artifact["dic_data"])
            # Make column 'species_name' the index
            df_selected.set_index("species_name", inplace=True)
            # Display the toggle button to suppress the table
            render_toggle(
                key="toggle_table_" + uniq_msg_id,
                toggle_text="Show Table",
                toggle_state=True,
                save_toggle=True,
            )
            # Display the table
            render_table(df_selected, key="dataframe_" + uniq_msg_id, save_table=True)
        elif msg.name == "search_models":
            if not msg.artifact:
                continue
            # Create a pandas dataframe from the dictionary
            df_selected = pd.DataFrame.from_dict(msg.artifact["dic_data"])
            # Pick selected columns
            df_selected = df_selected[["url", "name", "format", "submissionDate"]]
            # Display the toggle button to suppress the table
            render_toggle(
                key="toggle_table_" + uniq_msg_id,
                toggle_text="Show Table",
                toggle_state=True,
                save_toggle=True,
            )
            # Display the table
            st.dataframe(
                df_selected,
                use_container_width=True,
                key="dataframe_" + uniq_msg_id,
                hide_index=True,
                column_config={
                    "url": st.column_config.LinkColumn(
                        label="ID",
                        help="Click to open the link associated with the Id",
                        validate=r"^http://.*$",  # Ensure the link is valid
                        display_text=r"^https://www.ebi.ac.uk/biomodels/(.*?)$",
                    ),
                    "name": st.column_config.TextColumn("Name"),
                    "format": st.column_config.TextColumn("Format"),
                    "submissionDate": st.column_config.TextColumn("Submission Date"),
                },
            )
            # Add data to the chat history
            st.session_state.messages.append(
                {
                    "type": "dataframe",
                    "content": df_selected,
                    "key": "dataframe_" + uniq_msg_id,
                    "tool_name": msg.name,
                }
            )

        elif msg.name == "parameter_scan":
            # Convert the scanned data to a single dictionary
            dic_scanned_data = {}
            for data in current_state.values["dic_scanned_data"]:
                for key in data:
                    if key not in dic_scanned_data:
                        dic_scanned_data[key] = []
                    dic_scanned_data[key] += [data[key]]
            # Create a pandas dataframe from the dictionary
            df_scanned_data = pd.DataFrame.from_dict(dic_scanned_data)
            # Get the scanned data for the current tool call
            df_scanned_current_tool_call = pd.DataFrame(
                df_scanned_data[df_scanned_data["tool_call_id"] == msg.tool_call_id]
            )
            # One table+chart pair per scanned parameter value
            for count in range(0, len(df_scanned_current_tool_call.index)):
                # Get the scanned data for the current tool call
                df_selected = pd.DataFrame(
                    df_scanned_data[
                        df_scanned_data["tool_call_id"] == msg.tool_call_id
                    ]["data"].iloc[count]
                )
                # Display the table and plotly chart
                render_table_plotly(
                    uniq_msg_id + "_" + str(count),
                    df_scanned_current_tool_call["name"].iloc[count],
                    df_selected,
                    x_axis_label=msg.artifact["x_axis_label"],
                    y_axis_label=msg.artifact["y_axis_label"],
                )
        elif msg.name in ["get_annotation"]:
            if not msg.artifact:
                continue
            # Convert the annotated data to a single dictionary
            dic_annotations_data = {}
            for data in current_state.values["dic_annotations_data"]:
                for key in data:
                    if key not in dic_annotations_data:
                        dic_annotations_data[key] = []
                    dic_annotations_data[key] += [data[key]]
            df_annotations_data = pd.DataFrame.from_dict(dic_annotations_data)
            # Get the annotated data for the current tool call
            df_selected = pd.DataFrame(
                df_annotations_data[
                    df_annotations_data["tool_call_id"] == msg.tool_call_id
                ]["data"].iloc[0]
            )
            # Directly use the "Link" column for the "Id" column
            df_selected["Id"] = df_selected.apply(
                lambda row: row["Link"], axis=1  # Ensure "Id" has the correct links
            )
            df_selected = df_selected.drop(columns=["Link"])
            render_toggle(
                key="toggle_table_" + uniq_msg_id,
                toggle_text="Show Table",
                toggle_state=True,
                save_toggle=True,
            )
            st.dataframe(
                df_selected,
                use_container_width=True,
                key="dataframe_" + uniq_msg_id,
                hide_index=True,
                column_config={
                    "Id": st.column_config.LinkColumn(
                        label="Id",
                        help="Click to open the link associated with the Id",
                        validate=r"^http://.*$",  # Ensure the link is valid
                        display_text=r"^http://identifiers\.org/(.*?)$",
                    ),
                    "Species Name": st.column_config.TextColumn("Species Name"),
                    "Description": st.column_config.TextColumn("Description"),
                    "Database": st.column_config.TextColumn("Database"),
                },
            )
            # Add data to the chat history
            st.session_state.messages.append(
                {
                    "type": "dataframe",
                    "content": df_selected,
                    "key": "dataframe_" + uniq_msg_id,
                    "tool_name": msg.name,
                }
            )
        elif msg.name in ["subgraph_extraction"]:
            print(
                "-",
                len(current_state.values["dic_extracted_graph"]),
                "subgraph_extraction",
            )
            # Add the graph into the visuals list. Check for emptiness
            # BEFORE indexing the latest graph (the original indexed [-1]
            # first, risking IndexError on an empty list).
            if current_state.values["dic_extracted_graph"]:
                latest_graph = current_state.values["dic_extracted_graph"][-1]
                graphs_visuals.append(
                    {
                        "content": latest_graph["graph_dict"],
                        "key": "subgraph_" + uniq_msg_id,
                    }
                )
        elif msg.name in ["display_results"]:
            # This is a tool of T2S agent's sub-agent S2
            dic_papers = msg.artifact
            if not dic_papers:
                continue
            df_papers = pd.DataFrame.from_dict(dic_papers, orient="index")
            # Add index as a column "key"
            df_papers["Key"] = df_papers.index
            # Drop index
            df_papers.reset_index(drop=True, inplace=True)
            # Define the columns to drop (internal identifiers and abstract)
            columns_to_drop = [
                "Abstract",
                "Key",
                "arxiv_id",
                "semantic_scholar_paper_id",
            ]

            # Check if columns exist before dropping
            existing_columns = [
                col for col in columns_to_drop if col in df_papers.columns
            ]

            if existing_columns:
                df_papers.drop(columns=existing_columns, inplace=True)

            # Normalize Year to a plain integer string (or None)
            if "Year" in df_papers.columns:
                df_papers["Year"] = df_papers["Year"].apply(
                    lambda x: (
                        str(int(x)) if pd.notna(x) and str(x).isdigit() else None
                    )
                )

            # Normalize Date to YYYY-MM-DD (or None when unparseable)
            if "Date" in df_papers.columns:
                df_papers["Date"] = df_papers["Date"].apply(
                    lambda x: (
                        pd.to_datetime(x, errors="coerce").strftime("%Y-%m-%d")
                        if pd.notna(pd.to_datetime(x, errors="coerce"))
                        else None
                    )
                )

            st.dataframe(
                df_papers,
                hide_index=True,
                column_config={
                    "URL": st.column_config.LinkColumn(
                        display_text="Open",
                    ),
                },
            )
            # Add data to the chat history
            st.session_state.messages.append(
                {
                    "type": "dataframe",
                    "content": df_papers,
                    "key": "dataframe_" + uniq_msg_id,
                    "tool_name": msg.name,
                }
            )
            st.empty()
|
|
733 |
|
|
|
734 |
|
|
|
735 |
def render_graph(graph_dict: dict, key: str, save_graph: bool = False):
    """
    Render a directed graph in the chat using gravis.

    Args:
        graph_dict: The graph dictionary with "nodes" and "edges" entries
        key: The key for the graph
        save_graph: Whether to save the graph in the chat history
    """
    # Build a directed graph from the (node, attrs) and
    # (source, target, attrs) tuples in the dictionary
    digraph = nx.DiGraph()
    for node_id, node_attrs in graph_dict["nodes"]:
        digraph.add_node(node_id, **node_attrs)
    for src, dst, edge_attrs in graph_dict["edges"]:
        digraph.add_edge(src, dst, **edge_attrs)

    # Render the interactive d3 figure and embed it in the page
    fig = gravis.d3(
        digraph,
        node_size_factor=3.0,
        show_edge_label=True,
        edge_label_data_source="label",
        edge_curvature=0.25,
        zoom_factor=1.0,
        many_body_force_strength=-500,
        many_body_force_theta=0.3,
        node_hover_neighborhood=True,
    )
    components.html(fig.to_html(), height=475)

    if save_graph:
        # Persist the graph in the chat history so it survives reruns
        st.session_state.messages.append(
            {
                "type": "graph",
                "content": graph_dict,
                "key": key,
            }
        )
|
|
779 |
|
|
|
780 |
|
|
|
781 |
def get_text_embedding_model(model_name) -> Embeddings:
    """
    Construct the text embedding model for a given display name.

    Args:
        model_name: str: The name of the model

    Returns:
        Embeddings: The text embedding model

    Raises:
        KeyError: If the display name is not in the known mapping.
    """
    # Map display names to provider model identifiers
    dic_text_embedding_models = {
        "NVIDIA/llama-3.2-nv-embedqa-1b-v2": "nvidia/llama-3.2-nv-embedqa-1b-v2",
        "OpenAI/text-embedding-ada-002": "text-embedding-ada-002",
    }
    model_id = dic_text_embedding_models[model_name]
    if model_name.startswith("NVIDIA"):
        return NVIDIAEmbeddings(model=model_id)
    return OpenAIEmbeddings(model=model_id)
|
|
798 |
|
|
|
799 |
|
|
|
800 |
def get_base_chat_model(model_name: str) -> BaseChatModel:
    """
    Return the base chat model instance for a UI model name.

    Args:
        model_name: Display name of the LLM, prefixed with the provider
            (e.g. "NVIDIA/llama-3.3-70b-instruct", "OpenAI/gpt-4o-mini").

    Returns:
        BaseChatModel: The instantiated chat model (temperature fixed to 0
            for deterministic agent behavior).
    """
    # Map UI display names to provider-specific model identifiers.
    dic_llm_models = {
        "NVIDIA/llama-3.3-70b-instruct": "meta/llama-3.3-70b-instruct",
        "NVIDIA/llama-3.1-405b-instruct": "meta/llama-3.1-405b-instruct",
        "NVIDIA/llama-3.1-70b-instruct": "meta/llama-3.1-70b-instruct",
        "OpenAI/gpt-4o-mini": "gpt-4o-mini",
    }
    if model_name.startswith("Llama"):
        # BUG FIX: the mapping contains no "Llama/..." key, so the previous
        # direct dict lookup always raised KeyError on this branch. Fall back
        # to the raw display name so local Ollama models remain selectable.
        return ChatOllama(
            model=dic_llm_models.get(model_name, model_name), temperature=0
        )
    if model_name.startswith("NVIDIA"):
        return ChatNVIDIA(model=dic_llm_models[model_name], temperature=0)
    # Default provider is OpenAI.
    return ChatOpenAI(model=dic_llm_models[model_name], temperature=0)
|
|
821 |
|
|
|
822 |
|
|
|
823 |
@st.dialog("Warning ⚠️")
def update_llm_model():
    """
    Confirm and apply a change of the selected LLM model.

    Shows a warning dialog; on "Continue", clears the conversation and the
    cached agent from session state and reruns the app so it is rebuilt
    with the newly selected model.
    """
    llm_model = st.session_state.llm_model
    st.warning(
        f"Clicking 'Continue' will reset all agents, \
set the selected LLM to {llm_model}. \
This action will reset the entire app, \
and agents will lose access to the \
conversation history. Are you sure \
you want to proceed?"
    )
    if st.button("Continue"):
        # BUG FIX: the previous code iterated st.session_state.keys() while
        # deleting entries, mutating the mapping during iteration. Iterate a
        # fixed tuple of the keys we actually want to drop instead.
        for key in ("messages", "app"):
            if key in st.session_state:
                del st.session_state[key]
        st.rerun()
|
|
845 |
|
|
|
846 |
|
|
|
847 |
def update_text_embedding_model(app):
    """
    Push the currently selected text embedding model into the agent state.

    Args:
        app: The LangGraph app whose state is updated.
    """
    # The thread id scopes the state update to the current user session.
    thread_config = {"configurable": {"thread_id": st.session_state.unique_id}}
    embedding_model = get_text_embedding_model(
        st.session_state.text_embedding_model
    )
    app.update_state(thread_config, {"text_embedding_model": embedding_model})
|
|
863 |
|
|
|
864 |
|
|
|
865 |
@st.dialog("Get started with Talk2Biomodels 🚀")
def help_button():
    """
    Display the onboarding help dialog.

    Renders a static markdown walkthrough of the agent's capabilities
    (model search, description, simulation, plotting, steady-state analysis,
    parameter scans, docs links, feedback) with copy-pasteable example
    prompts.
    """
    # NOTE(review): this is user-facing text — keep the numbered examples in
    # sync with the agent's actual capabilities when tools change.
    st.markdown(
        """I am an AI agent designed to assist you with biological
modeling and simulations. I can assist with tasks such as:
1. Search specific models in the BioModels database.

```
Search models on Crohns disease
```

2. Extract information about models, including species, parameters, units,
name and descriptions.

```
Briefly describe model 537 and
its parameters related to drug dosage
```

3. Simulate models:
    - Run simulations of models to see how they behave over time.
    - Set the duration and the interval.
    - Specify which species/parameters you want to include and their starting concentrations/values.
    - Include recurring events.

```
Simulate the model 537 for 2016 hours and
intervals 300 with an initial value
of `DoseQ2W` set to 300 and `Dose` set to 0.
```

4. Answer questions about simulation results.

```
What is the concentration of species IL6 in serum
at the end of simulation?
```

5. Create custom plots to visualize the simulation results.

```
Plot the concentration of all
the interleukins over time.
```

6. Bring a model to a steady state and determine the concentration of a species at the steady state.

```
Bring BioModel 27 to a steady state,
and then determine the Mpp concentration
at the steady state.
```

7. Perform parameter scans to determine the effect of changing parameters on the model behavior.

```
How does the value of Pyruvate change in
model 64 if the concentration of Extracellular Glucose
is changed from 10 to 100 with a step size of 10?
The simulation should run for 5 time units with an
interval of 10.
```

8. Check out the [Use Cases](https://virtualpatientengine.github.io/AIAgents4Pharma/talk2biomodels/cases/Case_1/)
for more examples, and the [FAQs](https://virtualpatientengine.github.io/AIAgents4Pharma/talk2biomodels/faq/)
for common questions.

9. Provide feedback to the developers by clicking on the feedback button.

"""
    )
|
|
939 |
|
|
|
940 |
|
|
|
941 |
def apply_css():
    """
    Apply custom CSS tweaks to the Streamlit app.

    Hides the file-uploader's file list, its pagination control, and one
    auto-generated element so uploads are displayed only through the
    app's own file list UI.
    """
    # Styling using CSS
    # NOTE(review): ".st-emotion-cache-wbtvu4" is an auto-generated emotion
    # class name that can change between Streamlit releases — verify it still
    # targets the intended element after upgrading Streamlit.
    st.markdown(
        """<style>
        .stFileUploaderFile { display: none;}
        #stFileUploaderPagination { display: none;}
        .st-emotion-cache-wbtvu4 { display: none;}
        </style>
        """,
        unsafe_allow_html=True,
    )
|
|
955 |
|
|
|
956 |
|
|
|
957 |
def get_file_type_icon(file_type: str) -> str:
    """
    Look up the emoji icon for an uploaded file's type.

    Args:
        file_type (str): One of "drug_data", "endotype", or "sbml_file".

    Returns:
        str: The matching emoji, or None for an unknown type.
    """
    icons = {
        "drug_data": "💊",
        "endotype": "🧬",
        "sbml_file": "📜",
    }
    return icons.get(file_type)
|
|
968 |
|
|
|
969 |
|
|
|
970 |
@st.fragment
def get_t2b_uploaded_files(app):
    """
    Render the upload widgets for the T2B agent.

    Offers an XML/SBML model upload and a PDF article upload; when an
    article is provided, it is persisted to a temporary file and its path
    is written into the agent state.

    Args:
        app: The LangGraph app whose state receives the PDF file path.

    Returns:
        The uploaded XML/SBML file object (or None if nothing was uploaded).
    """
    # Upload the XML/SBML model file
    uploaded_sbml_file = st.file_uploader(
        "Upload an XML/SBML file",
        accept_multiple_files=False,
        type=["xml", "sbml"],
        help="Upload a QSP as an XML/SBML file",
    )

    # Upload a PDF article for question answering
    article = st.file_uploader(
        "Upload an article",
        help="Upload a PDF article to ask questions.",
        accept_multiple_files=False,
        type=["pdf"],
        key="article",
    )
    if article:
        # Persist the PDF so the agent can read it later; delete=False keeps
        # the file on disk after the handle is closed.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_pdf:
            tmp_pdf.write(article.read())
        thread_config = {"configurable": {"thread_id": st.session_state.unique_id}}
        # Point the agent state at the persisted PDF
        app.update_state(thread_config, {"pdf_file_name": tmp_pdf.name})
    return uploaded_sbml_file
|
|
1002 |
|
|
|
1003 |
|
|
|
1004 |
@st.fragment
def get_uploaded_files(cfg: hydra.core.config_store.ConfigStore) -> None:
    """
    Upload files to a directory set in cfg.upload_data_dir, and display them in the UI.

    Renders two multi-file uploaders (drug data, endotype data), writes any
    newly uploaded files to cfg.upload_data_dir, records them in
    st.session_state.uploaded_files, and lists them with a per-file remove
    button that deletes the file from disk and from session state.

    Args:
        cfg: The configuration object. Reads upload_data_dir,
            data_package_allowed_file_types, and endotype_allowed_file_types.
    """
    # sbml_file = st.file_uploader("📜 Upload SBML file",
    #     accept_multiple_files=False,
    #     help='Upload an ODE model in SBML format.',
    #     type=["xml", "sbml"],
    #     key=f"uploader_sbml_file_{st.session_state.sbml_key}")

    # Widget keys embed a counter so incrementing the counter (on removal
    # below) forces Streamlit to recreate the uploader with a cleared state.
    data_package_files = st.file_uploader(
        "💊 Upload pre-clinical drug data",
        help="Free-form text. Must contain atleast drug targets and kinetic parameters",
        accept_multiple_files=True,
        type=cfg.data_package_allowed_file_types,
        key=f"uploader_{st.session_state.data_package_key}",
    )

    endotype_files = st.file_uploader(
        "🧬 Upload endotype data",
        help="Free-form text. List of differentially expressed genes",
        accept_multiple_files=True,
        type=cfg.endotype_allowed_file_types,
        key=f"uploader_endotype_{st.session_state.endotype_key}",
    )

    # Merge the uploaded files
    # NOTE(review): assumes file_uploader with accept_multiple_files=True
    # returns a list (possibly empty); .copy() would raise on None — confirm
    # against the Streamlit version in use.
    uploaded_files = data_package_files.copy()
    if endotype_files:
        uploaded_files += endotype_files.copy()
    # if sbml_file:
    #     uploaded_files += [sbml_file]

    with st.spinner("Storing uploaded file(s) ..."):
        # for uploaded_file in data_package_files:
        for uploaded_file in uploaded_files:
            # Skip files already recorded in session state (dedupe by name).
            if uploaded_file.name not in [
                uf["file_name"] for uf in st.session_state.uploaded_files
            ]:
                current_timestamp = datetime.datetime.now().strftime(
                    "%Y-%m-%d %H:%M:%S"
                )
                # Metadata is attached as ad-hoc attributes on the
                # UploadedFile object before being copied into the dict below.
                uploaded_file.file_name = uploaded_file.name
                uploaded_file.file_path = (
                    f"{cfg.upload_data_dir}/{uploaded_file.file_name}"
                )
                uploaded_file.current_user = st.session_state.current_user
                uploaded_file.timestamp = current_timestamp
                # Classify by which uploader the file came from; anything
                # else falls through to "sbml_file".
                if uploaded_file.name in [uf.name for uf in data_package_files]:
                    uploaded_file.file_type = "drug_data"
                elif uploaded_file.name in [uf.name for uf in endotype_files]:
                    uploaded_file.file_type = "endotype"
                else:
                    uploaded_file.file_type = "sbml_file"
                st.session_state.uploaded_files.append(
                    {
                        "file_name": uploaded_file.file_name,
                        "file_path": uploaded_file.file_path,
                        "file_type": uploaded_file.file_type,
                        "uploaded_by": uploaded_file.current_user,
                        "uploaded_timestamp": uploaded_file.timestamp,
                    }
                )
                # Persist the raw bytes to the upload directory.
                with open(
                    os.path.join(cfg.upload_data_dir, uploaded_file.file_name), "wb"
                ) as f:
                    f.write(uploaded_file.getbuffer())
                uploaded_file = None

    # Display uploaded files and provide a remove button
    for uploaded_file in st.session_state.uploaded_files:
        col1, col2 = st.columns([4, 1])
        with col1:
            st.write(
                get_file_type_icon(uploaded_file["file_type"])
                + uploaded_file["file_name"]
            )
        with col2:
            if st.button("🗑️", key=uploaded_file["file_name"]):
                with st.spinner("Removing uploaded file ..."):
                    if os.path.isfile(
                        f"{cfg.upload_data_dir}/{uploaded_file['file_name']}"
                    ):
                        os.remove(f"{cfg.upload_data_dir}/{uploaded_file['file_name']}")
                # NOTE(review): removing from the list while iterating it is
                # tolerated only because st.rerun() below aborts the loop
                # immediately — do not add code after it in this branch.
                st.session_state.uploaded_files.remove(uploaded_file)
                st.cache_data.clear()
                # Bump the widget keys so both uploaders are recreated empty.
                st.session_state.data_package_key += 1
                st.session_state.endotype_key += 1
                st.rerun(scope="fragment")