#!/usr/bin/env python3

'''
A Streamlit app for the Talk2AIAgents4Pharma graph.
'''

import os
import sys
import random
import hydra
import streamlit as st
from streamlit_feedback import streamlit_feedback
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.messages import ChatMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from utils import streamlit_utils
# Configure the Streamlit page (must run before any other st.* call)
st.set_page_config(page_title="Talk2AIAgents4Pharma",
                   page_icon="🤖",
                   layout="wide",
                   initial_sidebar_state="collapsed",)

# Set the logo
st.logo(
    image='docs/assets/VPE.png',
    size='large',
    link='https://github.com/VirtualPatientEngine'
)
# Check if env variables OPENAI_API_KEY and/or
# NVIDIA_API_KEY exist; stop the app with guidance if either is missing.
if "OPENAI_API_KEY" not in os.environ or "NVIDIA_API_KEY" not in os.environ:
    st.error("Please set the OPENAI_API_KEY and NVIDIA_API_KEY "
             "environment variables in the terminal where you run "
             "the app. For more information, please refer to our "
             "[documentation](https://virtualpatientengine.github.io/AIAgents4Pharma/#option-2-git).")
    st.stop()

# Import the agent (repo root added to the path so the package resolves
# when the app is launched from the project directory)
sys.path.append('./')
from aiagents4pharma.talk2aiagents4pharma.agents.main_agent import get_app
# Initialize configuration (clear any prior Hydra global state so
# re-runs of the Streamlit script do not raise "already initialized")
hydra.core.global_hydra.GlobalHydra.instance().clear()
if "config" not in st.session_state:
    # Load Hydra configuration and cache the frontend section in session state
    with hydra.initialize(
        version_base=None,
        config_path="../../aiagents4pharma/talk2knowledgegraphs/configs",
    ):
        cfg_t2kg = hydra.compose(config_name="config", overrides=["app/frontend=default"])
        cfg_t2kg = cfg_t2kg.app.frontend
        st.session_state.config = cfg_t2kg
else:
    cfg_t2kg = st.session_state.config
########################################################################################
# Streamlit app
########################################################################################
# Create a chat prompt template
# NOTE(review): this template is assigned to `prompt`, which is later
# reassigned by `st.chat_input(...)` below — confirm the template is
# actually consumed anywhere before relying on it.
prompt = ChatPromptTemplate.from_messages([
        ("system", "Welcome to Talk2AIAgents4Pharma!"),
        MessagesPlaceholder(variable_name='chat_history', optional=True),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
])
# Initialize current user
if "current_user" not in st.session_state:
    st.session_state.current_user = cfg_t2kg.default_user

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

## T2B

# Initialize sbml_file_path
if "sbml_file_path" not in st.session_state:
    st.session_state.sbml_file_path = None

## T2KG

# Initialize session state for pre-clinical data package uploader
if "data_package_key" not in st.session_state:
    st.session_state.data_package_key = 0

# Initialize session state for patient gene expression data uploader
if "endotype_key" not in st.session_state:
    st.session_state.endotype_key = 0

# Initialize session state for uploaded files
if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = []

    # Make directories if not exists (first run only)
    os.makedirs(cfg_t2kg.upload_data_dir, exist_ok=True)

# Initialize project_name for Langsmith
if "project_name" not in st.session_state:
    st.session_state.project_name = 'T2AA4P-' + str(random.randint(1000, 9999))

# Initialize run_id for Langsmith
if "run_id" not in st.session_state:
    st.session_state.run_id = None

# Initialize graph thread id (used as the LangGraph thread identifier)
if "unique_id" not in st.session_state:
    st.session_state.unique_id = random.randint(1, 1000)
# Build (or reuse) the LangGraph app for this session
if "app" not in st.session_state:
    if "llm_model" not in st.session_state:
        # No model has been picked yet: fall back to the default OpenAI model
        st.session_state.app = get_app(st.session_state.unique_id,
                                       llm_model=ChatOpenAI(model='gpt-4o-mini',
                                                            temperature=0))
    else:
        print(st.session_state.llm_model)
        st.session_state.app = get_app(st.session_state.unique_id,
                                       llm_model=streamlit_utils.get_base_chat_model(
                                           st.session_state.llm_model))

if "topk_nodes" not in st.session_state:
    # Subgraph extraction settings (defaults come from the Hydra config)
    st.session_state.topk_nodes = cfg_t2kg.reasoning_subgraph_topk_nodes
    st.session_state.topk_edges = cfg_t2kg.reasoning_subgraph_topk_edges

# Get the app
app = st.session_state.app
# Apply custom CSS
streamlit_utils.apply_css()

# Sidebar: sliders controlling subgraph extraction size
with st.sidebar:
    st.markdown("**⚙️ Subgraph Extraction Settings**")
    topk_nodes = st.slider(
        "Top-K (Nodes)",
        cfg_t2kg.reasoning_subgraph_topk_nodes_min,
        cfg_t2kg.reasoning_subgraph_topk_nodes_max,
        st.session_state.topk_nodes,
        key="st_slider_topk_nodes",
    )
    st.session_state.topk_nodes = topk_nodes
    # NOTE(review): the edges slider reuses the *nodes* min/max config keys —
    # confirm whether dedicated edge bounds exist in the config before changing.
    topk_edges = st.slider(
        "Top-K (Edges)",
        cfg_t2kg.reasoning_subgraph_topk_nodes_min,
        cfg_t2kg.reasoning_subgraph_topk_nodes_max,
        st.session_state.topk_edges,
        key="st_slider_topk_edges",
    )
    st.session_state.topk_edges = topk_edges
# Main layout of the app split into two columns
main_col1, main_col2 = st.columns([3, 7])
# First column: title, model pickers, uploaders, help button, chat input
with main_col1:
    with st.container(border=True):
        # Title
        st.write("""
            <h3 style='margin: 0px; padding-bottom: 10px; font-weight: bold;'>
            Talk2AIAgents4Pharma
            </h3>
            """,
            unsafe_allow_html=True)

        # LLM model panel
        llms = ["OpenAI/gpt-4o-mini",
                "NVIDIA/llama-3.3-70b-instruct",
                "NVIDIA/llama-3.1-70b-instruct",
                "NVIDIA/llama-3.1-405b-instruct"]
        st.selectbox(
            "Pick an LLM to power the agent",
            llms,
            index=0,
            key="llm_model",
            on_change=streamlit_utils.update_llm_model,
            help="Used for tool calling and generating responses."
        )

        # Text embedding model panel
        text_models = ["NVIDIA/llama-3.2-nv-embedqa-1b-v2",
                       "OpenAI/text-embedding-ada-002"]
        st.selectbox(
            "Pick a text embedding model",
            text_models,
            index=0,
            key="text_embedding_model",
            on_change=streamlit_utils.update_text_embedding_model,
            kwargs={"app": app},
            help="Used for Retrieval Augmented Generation (RAG) and other tasks."
        )

        # T2B Upload files
        uploaded_sbml_file = streamlit_utils.get_t2b_uploaded_files(app)

        # T2KG Upload files
        streamlit_utils.get_uploaded_files(cfg_t2kg)

        # Help text
        st.button("Know more ↗",
                  on_click=streamlit_utils.help_button,
                  use_container_width=False)

    with st.container(border=False, height=500):
        prompt = st.chat_input("Say something ...", key="st_chat_input")
# Second column: chat history panel
with main_col2:
    # Chat history panel
    with st.container(border=True, height=600):
        st.write("#### 💬 Chat History")

        # Display history of messages; each entry is rendered according
        # to its "type" tag (message, button, plotly, toggle, graph, dataframe)
        for message in st.session_state.messages:
            if message["type"] == "message":
                with st.chat_message(message["content"].role,
                                     avatar="🤖"
                                     if message["content"].role != 'user'
                                     else "👩🏻‍💻"):
                    st.markdown(message["content"].content)
                    st.empty()
            elif message["type"] == "button":
                if st.button(message["content"],
                             key=message["key"]):
                    # Trigger the question
                    prompt = message["question"]
                    st.empty()
            elif message["type"] == "plotly":
                streamlit_utils.render_plotly(message["content"],
                                              key=message["key"],
                                              title=message["title"],
                                              y_axis_label=message["y_axis_label"],
                                              x_axis_label=message["x_axis_label"],
                                              save_chart=False)
                st.empty()
            elif message["type"] == "toggle":
                streamlit_utils.render_toggle(key=message["key"],
                                              toggle_text=message["content"],
                                              toggle_state=message["toggle_state"],
                                              save_toggle=False)
                st.empty()
            elif message["type"] == "graph":
                streamlit_utils.render_graph(
                    message["content"], key=message["key"], save_graph=False
                )
                st.empty()
            elif message["type"] == "dataframe":
                if 'tool_name' in message:
                    if message['tool_name'] == 'get_annotation':
                        df_selected = message["content"]
                        st.dataframe(df_selected,
                                     use_container_width=True,
                                     key=message["key"],
                                     hide_index=True,
                                     column_config={
                                         "Id": st.column_config.LinkColumn(
                                             label="Id",
                                             help="Click to open the link associated with the Id",
                                             validate=r"^http://.*$",  # Ensure the link is valid
                                             display_text=r"^http://identifiers\.org/(.*?)$"
                                         ),
                                         "Species Name": st.column_config.TextColumn("Species Name"),
                                         "Description": st.column_config.TextColumn("Description"),
                                         "Database": st.column_config.TextColumn("Database"),
                                     }
                                     )
                    elif message['tool_name'] == 'search_models':
                        df_selected = message["content"]
                        st.dataframe(df_selected,
                                     use_container_width=True,
                                     key=message["key"],
                                     hide_index=True,
                                     column_config={
                                         "url": st.column_config.LinkColumn(
                                             label="ID",
                                             help="Click to open the link associated with the Id",
                                             validate=r"^http://.*$",  # Ensure the link is valid
                                             display_text=r"^https://www.ebi.ac.uk/biomodels/(.*?)$"
                                         ),
                                         "name": st.column_config.TextColumn("Name"),
                                         "format": st.column_config.TextColumn("Format"),
                                         "submissionDate": st.column_config.TextColumn("Submission Date"),
                                     }
                                     )
                else:
                    streamlit_utils.render_table(message["content"],
                                                 key=message["key"],
                                                 save_table=False)
                st.empty()
        # Display intro message only the first time
        # i.e. when there are no messages in the chat
        if not st.session_state.messages:
            with st.chat_message("assistant", avatar="🤖"):
                with st.spinner("Initializing the agent ..."):
                    config = {"configurable":
                                {"thread_id": st.session_state.unique_id}
                                }
                    # Update the agent state with the selected LLM model
                    current_state = app.get_state(config)
                    app.update_state(
                        config,
                        {"llm_model": streamlit_utils.get_base_chat_model(
                            st.session_state.llm_model),
                        "text_embedding_model": streamlit_utils.get_text_embedding_model(
                            st.session_state.text_embedding_model),
                        "embedding_model": streamlit_utils.get_text_embedding_model(
                            st.session_state.text_embedding_model),
                        "uploaded_files": st.session_state.uploaded_files,
                        "topk_nodes": st.session_state.topk_nodes,
                        "topk_edges": st.session_state.topk_edges,
                        "dic_source_graph": [
                            {
                                "name": st.session_state.config["kg_name"],
                                "kg_pyg_path": st.session_state.config["kg_pyg_path"],
                                "kg_text_path": st.session_state.config["kg_text_path"],
                            }
                        ]}
                    )
                    # Ask the agent to introduce itself; starter questions are
                    # rendered separately as buttons below, so the agent must
                    # not invent its own.
                    intro_prompt = "Tell your name and about yourself. Always start with a greeting."
                    intro_prompt += " and tell about the tools you can run to perform analysis with short description."
                    intro_prompt += " We have provided starter questions (separately) outside your response."
                    intro_prompt += " Do not provide any questions by yourself. Let the users know that they can"
                    intro_prompt += " simply click on the questions to execute them."
                    response = app.stream(
                                    {"messages": [HumanMessage(content=intro_prompt)]},
                                    config=config,
                                    stream_mode="messages"
                                )
                    st.write_stream(streamlit_utils.stream_response(response))
                    current_state = app.get_state(config)
                    # Add response to chat history
                    assistant_msg = ChatMessage(
                                        current_state.values["messages"][-1].content,
                                        role="assistant")
                    st.session_state.messages.append({
                                    "type": "message",
                                    "content": assistant_msg
                                })
                    st.empty()
        # Show starter questions only while the chat is fresh
        # (at most the intro message is present)
        if len(st.session_state.messages) <= 1:
            for count, question in enumerate(streamlit_utils.sample_questions_t2aa4p()):
                if st.button(f'Q{count+1}. {question}',
                             key=f'sample_question_{count+1}'):
                    # Trigger the question
                    prompt = question
                # Add button click to chat history
                st.session_state.messages.append({
                                "type": "button",
                                "question": question,
                                "content": f'Q{count+1}. {question}',
                                "key": f'sample_question_{count+1}'
                            })
        # When the user asks a question
        if prompt:
            # Create a key 'uploaded_file' to read the uploaded file
            if uploaded_sbml_file:
                st.session_state.sbml_file_path = uploaded_sbml_file.read().decode("utf-8")

            # Display user prompt
            prompt_msg = ChatMessage(prompt, role="user")
            st.session_state.messages.append(
                {
                    "type": "message",
                    "content": prompt_msg
                }
            )
            with st.chat_message("user", avatar="👩🏻‍💻"):
                st.markdown(prompt)
                st.empty()

            # Auxiliary visualization-related variables
            graphs_visuals = []
            with st.chat_message("assistant", avatar="🤖"):
                with st.spinner():
                    # NOTE(review): the original code also built a
                    # `chat_history` list of Lang Chain message objects here,
                    # but never passed it anywhere (get_response does not take
                    # it); the pure, side-effect-free computation was removed.
                    streamlit_utils.get_response('T2AA4P',
                                                 graphs_visuals,
                                                 app,
                                                 st,
                                                 prompt)
            # Visualize the graph(s) collected while generating the response
            if len(graphs_visuals) > 0:
                for graph in graphs_visuals:
                    streamlit_utils.render_graph(
                        graph_dict=graph["content"], key=graph["key"], save_graph=True
                    )
        # Collect thumbs-up/down feedback once a Langsmith run id exists
        if st.session_state.get("run_id"):
            feedback = streamlit_feedback(
                feedback_type="thumbs",
                optional_text_label="[Optional] Please provide an explanation",
                on_submit=streamlit_utils.submit_feedback,
                key=f"feedback_{st.session_state.run_id}"
            )