a b/app/frontend/utils/streamlit_utils.py
1
#!/usr/bin/env python3
2
3
"""
4
Utils for Streamlit.
5
"""
6
7
import os
8
import datetime
9
import hydra
10
import tempfile
11
import streamlit as st
12
import streamlit.components.v1 as components
13
import pandas as pd
14
import plotly.express as px
15
from langsmith import Client
16
from langchain_ollama import ChatOllama
17
from langchain_openai import ChatOpenAI
18
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
19
from langchain_openai.embeddings import OpenAIEmbeddings
20
from langchain_core.language_models import BaseChatModel
21
from langchain_core.embeddings import Embeddings
22
from langchain_core.messages import AIMessageChunk, HumanMessage, ChatMessage, AIMessage
23
from langchain_core.tracers.context import collect_runs
24
from langchain.callbacks.tracers import LangChainTracer
25
import networkx as nx
26
import gravis
27
28
29
def submit_feedback(user_response):
    """
    Send the user's rating and comment to LangSmith as feedback on the last run.

    The feedback is attached to the run id stored in
    ``st.session_state.run_id``; a confirmation banner is then shown.

    Args:
        user_response: dict: The user response; ``"score"`` holds the emoji
            the user picked (a thumbs-up maps to score 1, anything else to 0)
            and ``"text"`` holds the free-text comment.
    """
    langsmith_client = Client()
    langsmith_client.create_feedback(
        st.session_state.run_id,
        key="feedback",
        score=int(user_response["score"] == "👍"),
        comment=user_response["text"],
    )
    st.info("Your feedback is on its way to the developers. Thank you!", icon="🚀")
44
45
46
def render_table_plotly(
    uniq_msg_id, content, df_selected, x_axis_label="Time", y_axis_label="Concentration"
):
    """
    Function to render the table and plotly chart in the chat.

    Renders, in order: a "Show Plot" toggle (on by default), the plotly
    chart, a "Show Table" toggle (off by default) and the table. Each element
    is rendered with its "save to chat history" flag set.

    Args:
        uniq_msg_id: str: The unique message id, used to build the widget keys
        content: str: The title of the plotly chart
        df_selected: pd.DataFrame: The selected dataframe
        x_axis_label: str: The x axis label passed on to ``render_plotly``
        y_axis_label: str: The y axis label passed on to ``render_plotly``
    """
    # Display the toggle button to show/hide the plot
    render_toggle(
        key="toggle_plotly_" + uniq_msg_id,
        toggle_text="Show Plot",
        toggle_state=True,
        save_toggle=True,
    )
    # Display the plotly chart (drawn only if the toggle above is on)
    render_plotly(
        df_selected,
        key="plotly_" + uniq_msg_id,
        title=content,
        y_axis_label=y_axis_label,
        x_axis_label=x_axis_label,
        save_chart=True,
    )
    # Display the toggle button to show/hide the table
    render_toggle(
        key="toggle_table_" + uniq_msg_id,
        toggle_text="Show Table",
        toggle_state=False,
        save_toggle=True,
    )
    # Display the table (drawn only if the toggle above is on)
    render_table(df_selected, key="dataframe_" + uniq_msg_id, save_table=True)
    st.empty()
83
84
85
def render_toggle(
    key: str, toggle_text: str, toggle_state: bool, save_toggle: bool = False
):
    """
    Draw a Streamlit toggle that shows or hides a chat element.

    Args:
        key: str: The widget key of the toggle
        toggle_text: str: The label displayed next to the toggle
        toggle_state: bool: The initial value of the toggle
        save_toggle: bool: When True, the toggle is also recorded in
            ``st.session_state.messages`` (the chat history)
    """
    st.toggle(toggle_text, value=toggle_state, help="Toggle to show/hide data", key=key)
    if not save_toggle:
        return
    toggle_record = {
        "type": "toggle",
        "content": toggle_text,
        "toggle_state": toggle_state,
        "key": key,
    }
    st.session_state.messages.append(toggle_record)
109
110
111
def render_plotly(
    df: pd.DataFrame,
    key: str,
    title: str,
    y_axis_label: str,
    x_axis_label: str,
    save_chart: bool = False,
):
    """
    Function to visualize the dataframe using Plotly.

    The chart is drawn only when the companion toggle ``toggle_plotly_<id>``
    (for ``key == "plotly_<id>"``) is on in ``st.session_state``. The
    dataframe is melted with ``Time`` as identifier column, so every other
    column becomes one line of the chart.

    Args:
        df: pd.DataFrame: The input dataframe; must contain a ``Time`` column
        key: str: The key for the plotly chart, of the form ``"plotly_<id>"``
        title: str: The title of the plotly chart
        y_axis_label: str: Shown in parentheses in the y axis title
        x_axis_label: str: Shown in parentheses in the x axis title
        save_chart: bool: Flag to save the chart to the chat history
    """
    # Split only at the first "plotly_" so ids that themselves contain
    # "plotly_" are not truncated.
    msg_id = key.split("plotly_", 1)[1]
    toggle_state = st.session_state[f"toggle_plotly_{msg_id}"]
    if toggle_state:
        df_simulation_results = df.melt(
            id_vars="Time", var_name="Species", value_name="Concentration"
        )
        fig = px.line(
            df_simulation_results,
            x="Time",
            y="Concentration",
            color="Species",
            title=title,
            height=500,
            width=600,
        )
        # Set y axis label
        fig.update_yaxes(title_text=f"Quantity ({y_axis_label})")
        # Set x axis label
        fig.update_xaxes(title_text=f"Time ({x_axis_label})")
        # Display the plotly chart
        st.plotly_chart(fig, use_container_width=True, key=key)
    if save_chart:
        # Add data to the chat history
        st.session_state.messages.append(
            {
                "type": "plotly",
                "content": df,
                "key": key,
                "title": title,
                "y_axis_label": y_axis_label,
                "x_axis_label": x_axis_label,
            }
        )
162
163
164
def render_table(df: pd.DataFrame, key: str, save_table: bool = False):
    """
    Function to render the table in the chat.

    The table is drawn only when the companion toggle ``toggle_table_<id>``
    (for ``key == "dataframe_<id>"``) is on in ``st.session_state``.

    Args:
        df: pd.DataFrame: The input dataframe
        key: str: The key for the table, of the form ``"dataframe_<id>"``
        save_table: bool: Flag to save the table to the chat history
    """
    # Split only at the first "dataframe_" so ids that themselves contain
    # "dataframe_" are not truncated.
    msg_id = key.split("dataframe_", 1)[1]
    toggle_state = st.session_state[f"toggle_table_{msg_id}"]
    if toggle_state:
        st.dataframe(df, use_container_width=True, key=key)
    if save_table:
        # Add data to the chat history
        st.session_state.messages.append(
            {
                "type": "dataframe",
                "content": df,
                "key": key,
            }
        )
188
189
190
def sample_questions():
    """
    Return the biomodel-related example prompts offered to the user.

    Returns:
        list[str]: A new list of sample questions on every call.
    """
    return [
        'Search for all biomodels on "Crohns Disease"',
        "Briefly describe biomodel 971 and simulate it for 50 days with an interval of 50.",
        (
            "Bring biomodel 27 to a steady state, and then "
            "determine the Mpp concentration at the steady state."
        ),
        (
            "How will the concentration of Mpp change in model 27, "
            "if the initial value of MAPKK were to be changed between 1 and 100 in steps of 10?"
        ),
        "Show annotations of all interleukins in model 537",
    ]
204
205
206
def sample_questions_t2s():
    """
    Return the example prompts offered to users of Talk2Scholars.

    Returns:
        list[str]: A new list of sample questions on every call.
    """
    t2s_prompts = (
        'Search articles on "Role of DNA damage response (DDR) in Cancer"',
        "Save these articles in my Zotero library under the collection 'Curiosity'",
        "Tell me more about the first article in the last search results",
        "Download the article 'Attention is All You Need'",
        "Describe the methods of the downloaded paper",
    )
    return list(t2s_prompts)
218
219
220
def sample_questions_t2aa4p():
    """
    Return the example prompts offered to users of Talk2AIAgents4Pharma.

    Returns:
        list[str]: A new list of sample questions on every call.
    """
    biomodel_prompts = [
        'Search for all the biomodels on "Crohns Disease"',
        "Briefly describe biomodel 537 and simulate it for 2016 hours with an interval of 100.",
    ]
    knowledge_graph_prompts = [
        "List the drugs that target Interleukin-6",
        "What genes are associated with Crohn's disease?",
    ]
    return biomodel_prompts + knowledge_graph_prompts
231
232
233
def stream_response(response):
    """
    Function to stream the response from the agent.

    Only ``AIMessageChunk`` items are considered, and only those whose
    ``langgraph_triggers`` metadata contains ``"branch:to:agent"`` are
    yielded. Chunks whose ``checkpoint_ns`` does not start with
    ``"supervisor"`` (agent chunks) are always eligible. Supervisor chunks
    are eligible only while no agent chunk has been seen yet, so the
    supervisor's text is shown only when no agent responded.

    Args:
        response: iterable of ``(message_chunk, metadata)`` items, as
            produced by ``app.stream(..., stream_mode="messages")``

    Yields:
        str: The chunk content; an empty chunk yields a newline followed by
        the empty content.
    """
    agent_responding = False
    for chunk in response:
        message_chunk = chunk[0]
        # Stream only the AIMessageChunk
        if not isinstance(message_chunk, AIMessageChunk):
            continue
        metadata = chunk[1]
        if metadata["checkpoint_ns"].startswith("supervisor"):
            # Supervisor output is displayed only if no agent has responded yet
            if agent_responding:
                continue
        else:
            agent_responding = True
        if "branch:to:agent" in metadata["langgraph_triggers"]:
            if message_chunk.content == "":
                yield "\n"
            yield message_chunk.content
276
277
278
def update_state_t2b(st):
    """
    Build the graph-state update required by the T2B agent.

    Args:
        st: The Streamlit module (or any object exposing ``session_state``);
            ``session_state.sbml_file_path`` and
            ``session_state.text_embedding_model`` are read.

    Returns:
        dict: ``sbml_file_path`` wrapped in a one-element list, and the text
        embedding model built from the selected model name.
    """
    session = st.session_state
    sbml_paths = [session.sbml_file_path]
    embedding_model = get_text_embedding_model(session.text_embedding_model)
    return {"sbml_file_path": sbml_paths, "text_embedding_model": embedding_model}
286
287
288
def update_state_t2kg(st):
    """
    Build the graph-state update required by the T2KG agent.

    Args:
        st: The Streamlit module (or any object exposing ``session_state``);
            the selected embedding model, uploaded files, top-k settings and
            the knowledge-graph entries of ``session_state.config`` are read.

    Returns:
        dict: The embedding model, uploaded files, ``topk_nodes``,
        ``topk_edges`` and a one-element ``dic_source_graph`` list describing
        the configured knowledge graph.
    """
    session = st.session_state
    embedding_model = get_text_embedding_model(session.text_embedding_model)
    uploaded_files = session.uploaded_files
    topk_nodes = session.topk_nodes
    topk_edges = session.topk_edges
    kg_config = session.config
    source_graph = {
        "name": kg_config["kg_name"],
        "kg_pyg_path": kg_config["kg_pyg_path"],
        "kg_text_path": kg_config["kg_text_path"],
    }
    return {
        "embedding_model": embedding_model,
        "uploaded_files": uploaded_files,
        "topk_nodes": topk_nodes,
        "topk_edges": topk_edges,
        "dic_source_graph": [source_graph],
    }
305
306
307
def get_ai_messages(current_state):
    """
    Collect the assistant text produced since the user's last message.

    The message history is walked backwards until the most recent
    ``HumanMessage``. The non-empty ``AIMessage`` contents found on the way
    are joined with newlines in chronological order.

    If the second-to-last message is a ``HumanMessage``, only the supervisor
    answered (no agent was called), so the supervisor's reply is kept.
    Otherwise non-empty messages named ``"supervisor"`` are skipped and only
    the agents' answers are returned.

    Args:
        current_state: Graph state snapshot (as returned by ``app.get_state``)
            whose ``values["messages"]`` holds the conversation.

    Returns:
        str: The joined assistant messages ("" if there are none).
    """
    messages = current_state.values["messages"]
    # Guard the [-2] lookup: a history with fewer than two messages cannot
    # be a "question + supervisor answer" pair.
    only_supervisor_answered = len(messages) >= 2 and isinstance(
        messages[-2], HumanMessage
    )
    assistant_content = []
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            break
        if not isinstance(msg, AIMessage) or msg.content == "":
            continue
        # Skip the supervisor's wrap-up when an agent has answered
        if msg.name == "supervisor" and not only_supervisor_answered:
            continue
        assistant_content.append(msg.content)
    # Messages were collected newest-first; restore chronological order
    return "\n".join(reversed(assistant_content))
341
342
343
def _select_tool_call_rows(records, tool_call_id):
    """
    Return the rows of a per-tool-call state list that belong to one tool call.

    The list of dicts is merged column-wise into a dataframe (one row per
    dict), which is then filtered on its ``tool_call_id`` column.

    Args:
        records: list of dicts read from the graph state
            (e.g. ``current_state.values["dic_simulated_data"]``).
        tool_call_id: str: The id of the tool call to select.

    Returns:
        pd.DataFrame: The matching rows (possibly empty).
    """
    columns = {}
    for record in records:
        for column_name in record:
            columns.setdefault(column_name, []).append(record[column_name])
    df_records = pd.DataFrame.from_dict(columns)
    return df_records[df_records["tool_call_id"] == tool_call_id]


def _render_tool_message(msg, uniq_msg_id, current_state, graphs_visuals, st):
    """
    Render the visuals produced by one tool call and record them in the chat.

    Tools without visuals, and tools whose artifact is empty, are ignored.

    Args:
        msg: The tool message; its ``name``, ``tool_call_id``, ``content`` and
            ``artifact`` are read.
        uniq_msg_id: str: Unique id of the tool call, used for widget keys.
        current_state: Graph state snapshot holding the tools' outputs.
        graphs_visuals: list: Extracted subgraphs are appended to it.
        st: The Streamlit module passed to ``get_response``.
    """
    if msg.name in ["simulate_model", "custom_plotter"]:
        # Both tools need the artifact for the axis labels
        if not msg.artifact:
            return
        if msg.name == "simulate_model":
            # Simulation results are stored in the graph state per tool call
            rows = _select_tool_call_rows(
                current_state.values["dic_simulated_data"], msg.tool_call_id
            )
            df_selected = pd.DataFrame(rows["data"].iloc[0])
        else:
            # custom_plotter ships its data in the artifact
            df_selected = pd.DataFrame.from_dict(msg.artifact["dic_data"])
        # Display the table and plotly chart
        render_table_plotly(
            uniq_msg_id,
            msg.content,
            df_selected,
            x_axis_label=msg.artifact["x_axis_label"],
            y_axis_label=msg.artifact["y_axis_label"],
        )
    elif msg.name == "steady_state":
        if not msg.artifact:
            return
        df_selected = pd.DataFrame.from_dict(msg.artifact["dic_data"])
        # Make column 'species_name' the index
        df_selected.set_index("species_name", inplace=True)
        # Display the toggle button to suppress the table
        render_toggle(
            key="toggle_table_" + uniq_msg_id,
            toggle_text="Show Table",
            toggle_state=True,
            save_toggle=True,
        )
        # Display the table
        render_table(df_selected, key="dataframe_" + uniq_msg_id, save_table=True)
    elif msg.name == "search_models":
        if not msg.artifact:
            return
        df_selected = pd.DataFrame.from_dict(msg.artifact["dic_data"])
        # Pick selected columns
        df_selected = df_selected[["url", "name", "format", "submissionDate"]]
        # Display the toggle button to suppress the table
        render_toggle(
            key="toggle_table_" + uniq_msg_id,
            toggle_text="Show Table",
            toggle_state=True,
            save_toggle=True,
        )
        # Display the table
        st.dataframe(
            df_selected,
            use_container_width=True,
            key="dataframe_" + uniq_msg_id,
            hide_index=True,
            column_config={
                "url": st.column_config.LinkColumn(
                    label="ID",
                    help="Click to open the link associated with the Id",
                    validate=r"^http://.*$",  # Ensure the link is valid
                    display_text=r"^https://www.ebi.ac.uk/biomodels/(.*?)$",
                ),
                "name": st.column_config.TextColumn("Name"),
                "format": st.column_config.TextColumn("Format"),
                "submissionDate": st.column_config.TextColumn("Submission Date"),
            },
        )
        # Add data to the chat history
        st.session_state.messages.append(
            {
                "type": "dataframe",
                "content": df_selected,
                "key": "dataframe_" + uniq_msg_id,
                "tool_name": msg.name,
            }
        )
    elif msg.name == "parameter_scan":
        # The artifact carries the axis labels
        if not msg.artifact:
            return
        # Each matching row is rendered as its own table/chart, titled by its name
        rows = _select_tool_call_rows(
            current_state.values["dic_scanned_data"], msg.tool_call_id
        )
        for count, (scan_name, scan_data) in enumerate(
            zip(rows["name"], rows["data"])
        ):
            render_table_plotly(
                uniq_msg_id + "_" + str(count),
                scan_name,
                pd.DataFrame(scan_data),
                x_axis_label=msg.artifact["x_axis_label"],
                y_axis_label=msg.artifact["y_axis_label"],
            )
    elif msg.name in ["get_annotation"]:
        if not msg.artifact:
            return
        # Get the annotated data for the current tool call
        rows = _select_tool_call_rows(
            current_state.values["dic_annotations_data"], msg.tool_call_id
        )
        df_selected = pd.DataFrame(rows["data"].iloc[0])
        # Expose the links under "Id" (rendered as a link column below).
        # A plain column copy also works on an empty frame, unlike a
        # row-wise apply.
        df_selected["Id"] = df_selected["Link"]
        df_selected = df_selected.drop(columns=["Link"])
        render_toggle(
            key="toggle_table_" + uniq_msg_id,
            toggle_text="Show Table",
            toggle_state=True,
            save_toggle=True,
        )
        st.dataframe(
            df_selected,
            use_container_width=True,
            key="dataframe_" + uniq_msg_id,
            hide_index=True,
            column_config={
                "Id": st.column_config.LinkColumn(
                    label="Id",
                    help="Click to open the link associated with the Id",
                    validate=r"^http://.*$",  # Ensure the link is valid
                    display_text=r"^http://identifiers\.org/(.*?)$",
                ),
                "Species Name": st.column_config.TextColumn("Species Name"),
                "Description": st.column_config.TextColumn("Description"),
                "Database": st.column_config.TextColumn("Database"),
            },
        )
        # Add data to the chat history
        st.session_state.messages.append(
            {
                "type": "dataframe",
                "content": df_selected,
                "key": "dataframe_" + uniq_msg_id,
                "tool_name": msg.name,
            }
        )
    elif msg.name in ["subgraph_extraction"]:
        extracted_graphs = current_state.values["dic_extracted_graph"]
        # Check for emptiness before taking the latest graph
        if extracted_graphs:
            graphs_visuals.append(
                {
                    "content": extracted_graphs[-1]["graph_dict"],
                    "key": "subgraph_" + uniq_msg_id,
                }
            )
    elif msg.name in ["display_results"]:
        # This is a tool of T2S agent's sub-agent S2
        dic_papers = msg.artifact
        if not dic_papers:
            return
        df_papers = pd.DataFrame.from_dict(dic_papers, orient="index")
        # The index holds the paper keys; they are not displayed
        df_papers.reset_index(drop=True, inplace=True)
        # Drop the columns that are not displayed, ignoring absent ones
        df_papers.drop(
            columns=["Abstract", "Key", "arxiv_id", "semantic_scholar_paper_id"],
            errors="ignore",
            inplace=True,
        )

        if "Year" in df_papers.columns:
            df_papers["Year"] = df_papers["Year"].apply(
                lambda x: (
                    str(int(x)) if pd.notna(x) and str(x).isdigit() else None
                )
            )

        if "Date" in df_papers.columns:
            df_papers["Date"] = df_papers["Date"].apply(
                lambda x: (
                    pd.to_datetime(x, errors="coerce").strftime("%Y-%m-%d")
                    if pd.notna(pd.to_datetime(x, errors="coerce"))
                    else None
                )
            )

        st.dataframe(
            df_papers,
            hide_index=True,
            column_config={
                "URL": st.column_config.LinkColumn(
                    display_text="Open",
                ),
            },
        )
        # Add data to the chat history
        st.session_state.messages.append(
            {
                "type": "dataframe",
                "content": df_papers,
                "key": "dataframe_" + uniq_msg_id,
                "tool_name": msg.name,
            }
        )
        st.empty()


def get_response(agent, graphs_visuals, app, st, prompt):
    """
    Send the user prompt to the agent graph and display the answer.

    The answer is also recorded in the chat history, along with any visuals
    produced by the tools called during this turn.

    Args:
        agent: str: Agent identifier. "T2B" and "T2KG" receive their own
            state updates; "T2AA4P" receives both.
        graphs_visuals: list: Graphs from ``subgraph_extraction`` are
            appended to it.
        app: The compiled agent graph (``get_state``, ``update_state``,
            ``invoke`` and ``stream`` are called on it).
        st: The Streamlit module.
        prompt: str: The user's message.

    Side effects:
        Sets ``st.session_state.run_id`` to the id of the traced run and
        appends the answer and tool visuals to ``st.session_state.messages``.
    """
    # Create config for the agent
    config = {"configurable": {"thread_id": st.session_state.unique_id}}
    # Update the agent state with the selected LLM model. The model object is
    # kept so the invoke/stream decision below uses the model just selected
    # rather than a state snapshot read before this update.
    llm_model = get_base_chat_model(st.session_state.llm_model)
    app.update_state(config, {"llm_model": llm_model})
    if agent == "T2AA4P":
        app.update_state(config, update_state_t2b(st) | update_state_t2kg(st))
    elif agent == "T2B":
        app.update_state(config, update_state_t2b(st))
    elif agent == "T2KG":
        app.update_state(config, update_state_t2kg(st))

    with collect_runs() as cb:
        # Add Langsmith tracer
        tracer = LangChainTracer(project_name=st.session_state.project_name)
        run_config = config | {"callbacks": [tracer]}
        inputs = {"messages": [HumanMessage(content=prompt)]}
        # NVIDIA models are invoked in one go and the answer is written once;
        # other models are streamed.
        if llm_model._llm_type == "chat-nvidia-ai-playground":
            app.invoke(inputs, config=run_config)
            st.write(get_ai_messages(app.get_state(config)))
        else:
            response = app.stream(inputs, config=run_config, stream_mode="messages")
            st.write_stream(stream_response(response))
        # Save the run id and use to save the feedback
        st.session_state.run_id = cb.traced_runs[-1].id

    # Get the current state of the graph and the last response's AI messages
    current_state = app.get_state(config)
    assistant_content = get_ai_messages(current_state)
    # Add response to chat history
    assistant_msg = ChatMessage(assistant_content, role="assistant")
    st.session_state.messages.append({"type": "message", "content": assistant_msg})
    st.empty()

    # Walk back through the messages of this turn (up to the user's last
    # message) and render the visuals attached to the tool calls.
    for msg in reversed(current_state.values["messages"]):
        # The last message from the user marks the start of this turn
        if isinstance(msg, HumanMessage):
            break
        # An agent may make several tool calls before its final answer;
        # only the tool messages carry visuals.
        if isinstance(msg, AIMessage):
            continue
        # Skip the Tool message if it is an error message
        if msg.status == "error":
            continue
        # Tool name + tool call id + run id identify the tool call uniquely
        uniq_msg_id = (
            msg.name + "_" + msg.tool_call_id + "_" + str(st.session_state.run_id)
        )
        _render_tool_message(msg, uniq_msg_id, current_state, graphs_visuals, st)
733
734
735
def render_graph(graph_dict: dict, key: str, save_graph: bool = False):
    """
    Draw a directed graph in the chat using gravis' D3 renderer.

    Args:
        graph_dict: The graph dictionary. ``"nodes"`` is iterated as
            ``(node, attrs)`` pairs and ``"edges"`` as
            ``(source, target, attrs)`` triples, where ``attrs`` are mappings
            of attribute name to value.
        key: The key stored alongside the graph in the chat history
        save_graph: Whether to append the graph to the chat history
    """
    digraph = nx.DiGraph()
    for node_id, node_attrs in graph_dict["nodes"]:
        digraph.add_node(node_id, **node_attrs)
    for src, dst, edge_attrs in graph_dict["edges"]:
        digraph.add_edge(src, dst, **edge_attrs)

    d3_settings = {
        "node_size_factor": 3.0,
        "show_edge_label": True,
        "edge_label_data_source": "label",
        "edge_curvature": 0.25,
        "zoom_factor": 1.0,
        "many_body_force_strength": -500,
        "many_body_force_theta": 0.3,
        "node_hover_neighborhood": True,
    }
    figure = gravis.d3(digraph, **d3_settings)
    components.html(figure.to_html(), height=475)

    if save_graph:
        graph_record = {"type": "graph", "content": graph_dict, "key": key}
        st.session_state.messages.append(graph_record)
779
780
781
def get_text_embedding_model(model_name) -> Embeddings:
    """
    Build the text embedding model for a UI model name.

    Args:
        model_name: str: Provider-prefixed UI name of the model
            (e.g. "NVIDIA/llama-3.2-nv-embedqa-1b-v2").

    Returns:
        Embeddings: ``NVIDIAEmbeddings`` for names starting with "NVIDIA",
        otherwise ``OpenAIEmbeddings``.

    Raises:
        KeyError: If ``model_name`` is not one of the supported names.
    """
    # UI name -> identifier expected by the provider's API.
    provider_ids = {
        "NVIDIA/llama-3.2-nv-embedqa-1b-v2": "nvidia/llama-3.2-nv-embedqa-1b-v2",
        "OpenAI/text-embedding-ada-002": "text-embedding-ada-002",
    }
    embeddings_cls = (
        NVIDIAEmbeddings if model_name.startswith("NVIDIA") else OpenAIEmbeddings
    )
    return embeddings_cls(model=provider_ids[model_name])
798
799
800
def get_base_chat_model(model_name) -> BaseChatModel:
    """
    Build the chat model (temperature 0) for a UI model name.

    The provider class is chosen from the name prefix: "Llama" -> ChatOllama,
    "NVIDIA" -> ChatNVIDIA, anything else -> ChatOpenAI.

    Args:
        model_name: str: Provider-prefixed UI name of the model
            (e.g. "OpenAI/gpt-4o-mini").

    Returns:
        BaseChatModel: The base chat model.

    Raises:
        KeyError: If ``model_name`` is not in the supported-model table; the
            message lists the supported names.
    """
    # UI name -> identifier expected by the provider's API.
    dic_llm_models = {
        "NVIDIA/llama-3.3-70b-instruct": "meta/llama-3.3-70b-instruct",
        "NVIDIA/llama-3.1-405b-instruct": "meta/llama-3.1-405b-instruct",
        "NVIDIA/llama-3.1-70b-instruct": "meta/llama-3.1-70b-instruct",
        "OpenAI/gpt-4o-mini": "gpt-4o-mini",
    }
    if model_name not in dic_llm_models:
        # Fail with an actionable message instead of a bare KeyError.
        raise KeyError(
            f"Unsupported LLM model {model_name!r}; "
            f"supported models: {', '.join(dic_llm_models)}"
        )
    model_id = dic_llm_models[model_name]
    if model_name.startswith("Llama"):
        # NOTE(review): the table above has no "Llama..." entry, so this
        # branch cannot be reached today. Add the Ollama model names to the
        # table before offering them in the UI.
        return ChatOllama(model=model_id, temperature=0)
    if model_name.startswith("NVIDIA"):
        return ChatNVIDIA(model=model_id, temperature=0)
    return ChatOpenAI(model=model_id, temperature=0)
821
822
823
@st.dialog("Warning ⚠️")
def update_llm_model():
    """
    Ask the user to confirm an LLM switch, then reset the chat.

    Shows a warning naming the newly selected model
    (``st.session_state.llm_model``). If the user clicks "Continue", the chat
    history (``messages``) and the agent (``app``) are removed from the
    session state and the script is rerun. Presumably the main page rebuilds
    the app with the new model on that rerun; confirm against the caller.
    """
    llm_model = st.session_state.llm_model
    st.warning(
        f"Clicking 'Continue' will reset all agents, \
            set the selected LLM to {llm_model}. \
            This action will reset the entire app, \
            and agents will lose access to the \
            conversation history. Are you sure \
            you want to proceed?"
    )
    if st.button("Continue"):
        # Only the chat history and the agent are reset; every other session
        # key (including the selected llm_model) is kept. Delete the keys
        # explicitly rather than mutating st.session_state while iterating
        # over its keys.
        for key in ("messages", "app"):
            if key in st.session_state:
                del st.session_state[key]
        st.rerun()
845
846
847
def update_text_embedding_model(app):
    """
    Replace the text embedding model stored in the agent's state.

    Builds the model named by ``st.session_state.text_embedding_model`` and
    writes it to the ``text_embedding_model`` field of the app state for the
    current session thread (``st.session_state.unique_id``).

    Args:
        app: The LangGraph app
    """
    thread_config = {"configurable": {"thread_id": st.session_state.unique_id}}
    selected_name = st.session_state.text_embedding_model
    state_patch = {"text_embedding_model": get_text_embedding_model(selected_name)}
    app.update_state(thread_config, state_patch)
863
864
865
@st.dialog("Get started with Talk2Biomodels 🚀")
def help_button():
    """
    Show the "Get started" dialog for Talk2Biomodels.

    Renders a Markdown overview of the agent's capabilities (model search,
    model description, simulation, questions on results, custom plots,
    steady state, parameter scans), each with an example prompt, followed by
    links to the use cases and FAQs.
    """
    help_markdown = """I am an AI agent designed to assist you with biological
modeling and simulations. I can assist with tasks such as:
1. Search specific models in the BioModels database.

```
Search models on Crohns disease
```

2. Extract information about models, including species, parameters, units,
name and descriptions.

```
Briefly describe model 537 and 
its parameters related to drug dosage
```

3. Simulate models:
    - Run simulations of models to see how they behave over time.
    - Set the duration and the interval.
    - Specify which species/parameters you want to include and their starting concentrations/values.
    - Include recurring events.

```
Simulate the model 537 for 2016 hours and
intervals 300 with an initial value
of `DoseQ2W` set to 300 and `Dose` set to 0.
```

4. Answer questions about simulation results.

```
What is the concentration of species IL6 in serum
at the end of simulation?
```

5. Create custom plots to visualize the simulation results.

```
Plot the concentration of all
the interleukins over time.
```

6. Bring a model to a steady state and determine the concentration of a species at the steady state.

```
Bring BioModel 27 to a steady state,
and then determine the Mpp concentration
at the steady state.
```

7. Perform parameter scans to determine the effect of changing parameters on the model behavior.

```
How does the value of Pyruvate change in
model 64 if the concentration of Extracellular Glucose
is changed from 10 to 100 with a step size of 10?
The simulation should run for 5 time units with an
interval of 10.
```

8. Check out the [Use Cases](https://virtualpatientengine.github.io/AIAgents4Pharma/talk2biomodels/cases/Case_1/)
for more examples, and the [FAQs](https://virtualpatientengine.github.io/AIAgents4Pharma/talk2biomodels/faq/)
for common questions.

9. Provide feedback to the developers by clicking on the feedback button.
                
"""
    st.markdown(help_markdown)
939
940
941
def apply_css():
    """
    Inject custom CSS into the Streamlit page.

    Hides every element matching ``.stFileUploaderFile``,
    ``#stFileUploaderPagination`` and ``.st-emotion-cache-wbtvu4``. Judging by
    the names, these are the uploader's built-in file list and its pagination.
    """
    # NOTE(review): ".st-emotion-cache-wbtvu4" looks like an auto-generated
    # class name that may change between Streamlit versions; confirm it still
    # matches the intended element after upgrades.
    css_block = """<style>
        .stFileUploaderFile { display: none;}
        #stFileUploaderPagination { display: none;}
        .st-emotion-cache-wbtvu4 { display: none;}
        </style>
        """
    st.markdown(css_block, unsafe_allow_html=True)
955
956
957
def get_file_type_icon(file_type: str, default: str = "📄") -> str:
    """
    Function to get the icon for the file type.

    Args:
        file_type (str): The file type: "drug_data", "endotype" or
            "sbml_file".
        default (str): Icon returned for any other file type. Always a
            string, so callers can safely concatenate the result with a
            file name.

    Returns:
        str: The icon for the file type, or ``default`` if the type is
        unknown.
    """
    icons = {"drug_data": "💊", "endotype": "🧬", "sbml_file": "📜"}
    return icons.get(file_type, default)
968
969
970
@st.fragment
def get_t2b_uploaded_files(app):
    """
    Upload files for the T2B agent.

    Renders two uploaders: one for an XML/SBML model and one for a PDF
    article. When an article is present, a copy of it is saved to a
    temporary ``.pdf`` file. Its path is then written to the agent state as
    ``pdf_file_name`` for the current thread (``st.session_state.unique_id``).

    Args:
        app: The LangGraph app whose state receives ``pdf_file_name``.

    Returns:
        The uploaded XML/SBML file, or ``None`` if nothing was uploaded.
    """
    # Upload the XML/SBML file
    uploaded_sbml_file = st.file_uploader(
        "Upload an XML/SBML file",
        accept_multiple_files=False,
        type=["xml", "sbml"],
        help="Upload a QSP as an XML/SBML file",
    )

    # Upload the article
    article = st.file_uploader(
        "Upload an article",
        help="Upload a PDF article to ask questions.",
        accept_multiple_files=False,
        type=["pdf"],
        key="article",
    )
    if article:
        # The fragment reruns on every interaction. Reuse the temp copy made
        # for this same upload instead of writing a new one each rerun
        # (NamedTemporaryFile(delete=False) is never cleaned up, so every
        # rerun used to leak a file).
        signature = (getattr(article, "file_id", None), article.name, article.size)
        cached = st.session_state.get("t2b_article_tmp")
        if (
            cached is None
            or cached["signature"] != signature
            or not os.path.isfile(cached["path"])
        ):
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as f:
                # getvalue() returns the whole upload regardless of the
                # buffer's current read position (read() does not).
                f.write(article.getvalue())
            cached = {"signature": signature, "path": f.name}
            st.session_state["t2b_article_tmp"] = cached
        # Point the agent state at the saved PDF on every run, so an app that
        # was re-created since the upload still receives it.
        config = {"configurable": {"thread_id": st.session_state.unique_id}}
        app.update_state(config, {"pdf_file_name": cached["path"]})
    # Return the uploaded file
    return uploaded_sbml_file
1002
1003
1004
@st.fragment
def get_uploaded_files(cfg: hydra.core.config_store.ConfigStore) -> None:
    """
    Upload files to a directory set in cfg.upload_data_dir, and display them in the UI.

    Files from the drug-data and endotype uploaders are written to disk once,
    deduplicated by file name. Each is recorded in
    ``st.session_state.uploaded_files`` as a dict with ``file_name``,
    ``file_path``, ``file_type``, ``uploaded_by`` and ``uploaded_timestamp``.
    Every stored file is then listed with a 🗑️ button. The button deletes the
    file from disk and from the session state, then reruns this fragment with
    fresh (empty) uploader widgets.

    Args:
        cfg: The configuration object. Reads ``upload_data_dir``,
            ``data_package_allowed_file_types`` and
            ``endotype_allowed_file_types``.
    """
    upload_dir = cfg.upload_data_dir

    data_package_files = st.file_uploader(
        "💊 Upload pre-clinical drug data",
        help="Free-form text. Must contain at least drug targets and kinetic parameters",
        accept_multiple_files=True,
        type=cfg.data_package_allowed_file_types,
        key=f"uploader_{st.session_state.data_package_key}",
    )

    endotype_files = st.file_uploader(
        "🧬 Upload endotype data",
        help="Free-form text. List of differentially expressed genes",
        accept_multiple_files=True,
        type=cfg.endotype_allowed_file_types,
        key=f"uploader_endotype_{st.session_state.endotype_key}",
    )

    # Tag each upload with its type. If the same name is in both uploaders,
    # the drug-data copy is stored first and the later duplicate is skipped.
    # ("sbml_file" is also known to get_file_type_icon, but no SBML uploader
    # is rendered here.)
    tagged_files = [(uf, "drug_data") for uf in data_package_files or []]
    tagged_files += [(uf, "endotype") for uf in endotype_files or []]

    with st.spinner("Storing uploaded file(s) ..."):
        if tagged_files:
            os.makedirs(upload_dir, exist_ok=True)
        known_names = {uf["file_name"] for uf in st.session_state.uploaded_files}
        for uploaded_file, file_type in tagged_files:
            # The name comes from the client; keep only the base name so a
            # crafted name such as "../x" cannot write outside upload_dir.
            file_name = os.path.basename(uploaded_file.name)
            if not file_name or file_name in known_names:
                continue
            file_path = os.path.join(upload_dir, file_name)
            # Write first, so the session state never lists a file that
            # failed to save.
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            known_names.add(file_name)
            st.session_state.uploaded_files.append(
                {
                    "file_name": file_name,
                    "file_path": file_path,
                    "file_type": file_type,
                    "uploaded_by": st.session_state.current_user,
                    "uploaded_timestamp": datetime.datetime.now().strftime(
                        "%Y-%m-%d %H:%M:%S"
                    ),
                }
            )

    # Display uploaded files and provide a remove button
    for uploaded_file in st.session_state.uploaded_files:
        col1, col2 = st.columns([4, 1])
        with col1:
            icon = get_file_type_icon(uploaded_file["file_type"]) or ""
            st.write(icon + uploaded_file["file_name"])
        with col2:
            if st.button("🗑️", key=uploaded_file["file_name"]):
                with st.spinner("Removing uploaded file ..."):
                    file_path = os.path.join(upload_dir, uploaded_file["file_name"])
                    if os.path.isfile(file_path):
                        os.remove(file_path)
                    # Mutating the list mid-iteration is safe only because
                    # st.rerun() below ends this run immediately.
                    st.session_state.uploaded_files.remove(uploaded_file)
                    # NOTE(review): presumably drops cached results derived
                    # from the removed file — confirm.
                    st.cache_data.clear()
                    # New widget keys recreate the uploaders empty; otherwise
                    # the removed file would be stored again on the next run.
                    st.session_state.data_package_key += 1
                    st.session_state.endotype_key += 1
                    st.rerun(scope="fragment")