a b/app/frontend/streamlit_app_talk2biomodels.py
1
#!/usr/bin/env python3
2
3
'''
4
Talk2Biomodels: A Streamlit app for the Talk2Biomodels graph.
5
'''
6
7
import os
8
import sys
9
import random
10
import streamlit as st
11
from streamlit_feedback import streamlit_feedback
12
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
13
from langchain_core.messages import ChatMessage
14
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
15
from langchain_openai import ChatOpenAI
16
from utils import streamlit_utils
17
18
st.set_page_config(page_title="Talk2Biomodels", page_icon="🤖", layout="wide")
19
# Set the logo
20
st.logo(
21
    image='docs/assets/VPE.png',
22
    size='large',
23
    link='https://github.com/VirtualPatientEngine'
24
)
25
26
# Check if env variables OPENAI_API_KEY and/or
27
# NVIDIA_API_KEY exist
28
if "OPENAI_API_KEY" not in os.environ or "NVIDIA_API_KEY" not in os.environ:
29
    st.error("Please set the OPENAI_API_KEY and NVIDIA_API_KEY "
30
             "environment variables in the terminal where you run "
31
             "the app. For more information, please refer to our "
32
             "[documentation](https://virtualpatientengine.github.io/AIAgents4Pharma/#option-2-git).")
33
    st.stop()
34
35
# Import the agent
36
sys.path.append('./')
37
from aiagents4pharma.talk2biomodels.agents.t2b_agent import get_app
38
39
########################################################################################
40
# Streamlit app
41
########################################################################################
42
# Create a chat prompt template
43
prompt = ChatPromptTemplate.from_messages([
44
        ("system", "Welcome to Talk2Biomodels!"),
45
        MessagesPlaceholder(variable_name='chat_history', optional=True),
46
        ("human", "{input}"),
47
        ("placeholder", "{agent_scratchpad}"),
48
])
49
50
# Initialize chat history
51
if "messages" not in st.session_state:
52
    st.session_state.messages = []
53
54
# Initialize sbml_file_path
55
if "sbml_file_path" not in st.session_state:
56
    st.session_state.sbml_file_path = None
57
58
# Initialize project_name for Langsmith
59
if "project_name" not in st.session_state:
60
    # st.session_state.project_name = str(st.session_state.user_name) + '@' + str(uuid.uuid4())
61
    st.session_state.project_name = 'T2B-' + str(random.randint(1000, 9999))
62
63
# Initialize run_id for Langsmith
64
if "run_id" not in st.session_state:
65
    st.session_state.run_id = None
66
67
# Initialize graph
68
if "unique_id" not in st.session_state:
69
    st.session_state.unique_id = random.randint(1, 1000)
70
if "app" not in st.session_state:
71
    if "llm_model" not in st.session_state:
72
        st.session_state.app = get_app(st.session_state.unique_id,
73
                                llm_model=ChatOpenAI(model='gpt-4o-mini', temperature=0))
74
    else:
75
        print (st.session_state.llm_model)
76
        st.session_state.app = get_app(st.session_state.unique_id,
77
                                llm_model=streamlit_utils.get_base_chat_model(
78
                                st.session_state.llm_model))
79
80
# Get the app
81
app = st.session_state.app
82
83
@st.fragment
84
def get_uploaded_files():
85
    """
86
    Upload files.
87
    """
88
    # Upload the XML/SBML file
89
    uploaded_sbml_file = st.file_uploader(
90
        "Upload an XML/SBML file",
91
        accept_multiple_files=False,
92
        type=["xml", "sbml"],
93
        help='Upload a QSP as an XML/SBML file'
94
        )
95
96
    # Upload the article
97
    article = st.file_uploader(
98
        "Upload an article",
99
        help="Upload a PDF article to ask questions.",
100
        accept_multiple_files=False,
101
        type=["pdf"],
102
        key="article"
103
    )
104
    # print (article)
105
    # Update the agent state with the uploaded article
106
    if article:
107
        import tempfile
108
        print (article.name)
109
        with tempfile.NamedTemporaryFile(delete=False) as f:
110
            f.write(article.read())
111
            # print (f.name)
112
        # Create config for the agent
113
        config = {"configurable": {"thread_id": st.session_state.unique_id}}
114
        # Update the agent state with the selected LLM model
115
        app.update_state(
116
            config,
117
            {"pdf_file_name": f.name}
118
        )
119
    # Return the uploaded file
120
    return uploaded_sbml_file
121
122
# Main layout of the app split into two columns
123
main_col1, main_col2 = st.columns([3, 7])
124
# First column
125
with main_col1:
126
    with st.container(border=True):
127
        # Title
128
        st.write("""
129
            <h3 style='margin: 0px; padding-bottom: 10px; font-weight: bold;'>
130
            🤖 Talk2Biomodels
131
            </h3>
132
            """,
133
            unsafe_allow_html=True)
134
135
        # LLM model panel
136
        llms = ["OpenAI/gpt-4o-mini",
137
                "NVIDIA/llama-3.3-70b-instruct",
138
                "NVIDIA/llama-3.1-70b-instruct",
139
                "NVIDIA/llama-3.1-405b-instruct"]
140
        st.selectbox(
141
            "Pick an LLM to power the agent",
142
            llms,
143
            index=0,
144
            key="llm_model",
145
            on_change=streamlit_utils.update_llm_model,
146
            help="Used for tool calling and generating responses."
147
        )
148
149
        # Text embedding model panel
150
        text_models = ["NVIDIA/llama-3.2-nv-embedqa-1b-v2",
151
                       "OpenAI/text-embedding-ada-002"]
152
        st.selectbox(
153
            "Pick a text embedding model",
154
            text_models,
155
            index=0,
156
            key="text_embedding_model",
157
            on_change=streamlit_utils.update_text_embedding_model,
158
            kwargs={"app": app},
159
            help="Used for Retrival Augmented Generation (RAG) and other tasks."
160
        )
161
162
        # Upload files
163
        uploaded_sbml_file = get_uploaded_files()
164
165
        # Help text
166
        st.button("Know more ↗",
167
                #   icon="â„šī¸",
168
                  on_click=streamlit_utils.help_button,
169
                  use_container_width=False)
170
171
    with st.container(border=False, height=500):
172
        prompt = st.chat_input("Say something ...", key="st_chat_input")
173
174
# Second column
175
with main_col2:
176
    # Chat history panel
177
    with st.container(border=True, height=600):
178
        st.write("#### đŸ’Ŧ Chat History")
179
180
        # Display history of messages
181
        for count, message in enumerate(st.session_state.messages):
182
            if message["type"] == "message":
183
                with st.chat_message(message["content"].role,
184
                                     avatar="🤖"
185
                                     if message["content"].role != 'user'
186
                                     else "👩đŸģ‍đŸ’ģ"):
187
                    st.markdown(message["content"].content)
188
                    st.empty()
189
            elif message["type"] == "button":
190
                if st.button(message["content"],
191
                             key=message["key"]):
192
                    # Trigger the question
193
                    prompt = message["question"]
194
                    st.empty()
195
            elif message["type"] == "plotly":
196
                streamlit_utils.render_plotly(message["content"],
197
                              key=message["key"],
198
                              title=message["title"],
199
                              y_axis_label=message["y_axis_label"],
200
                              x_axis_label=message["x_axis_label"],
201
                            #   tool_name=message["tool_name"],
202
                              save_chart=False)
203
                st.empty()
204
            elif message["type"] == "toggle":
205
                streamlit_utils.render_toggle(key=message["key"],
206
                                    toggle_text=message["content"],
207
                                    toggle_state=message["toggle_state"],
208
                                    save_toggle=False)
209
                st.empty()
210
            elif message["type"] == "dataframe":
211
                if 'tool_name' in message:
212
                    if message['tool_name'] == 'get_annotation':
213
                        df_selected = message["content"]
214
                        st.dataframe(df_selected,
215
                                    use_container_width=True,
216
                                    key=message["key"],
217
                                    hide_index=True,
218
                                    column_config={
219
                                        "Id": st.column_config.LinkColumn(
220
                                            label="Id",
221
                                            help="Click to open the link associated with the Id",
222
                                            validate=r"^http://.*$",  # Ensure the link is valid
223
                                            display_text=r"^http://identifiers\.org/(.*?)$"
224
                                        ),
225
                                        "Species Name": st.column_config.TextColumn("Species Name"),
226
                                        "Description": st.column_config.TextColumn("Description"),
227
                                        "Database": st.column_config.TextColumn("Database"),
228
                                    }
229
                                )
230
                    elif message['tool_name'] == 'search_models':
231
                        df_selected = message["content"]
232
                        st.dataframe(df_selected,
233
                            use_container_width=True,
234
                            key=message["key"],
235
                            hide_index=True,
236
                            column_config={
237
                                "url": st.column_config.LinkColumn(
238
                                    label="ID",
239
                                    help="Click to open the link associated with the Id",
240
                                    validate=r"^http://.*$",  # Ensure the link is valid
241
                                    display_text=r"^https://www.ebi.ac.uk/biomodels/(.*?)$"
242
                                ),
243
                                "name": st.column_config.TextColumn("Name"),
244
                                "format": st.column_config.TextColumn("Format"),
245
                                "submissionDate": st.column_config.TextColumn("Submission Date"),
246
                                }
247
                            )
248
                else:
249
                    streamlit_utils.render_table(message["content"],
250
                                    key=message["key"],
251
                                    # tool_name=message["tool_name"],
252
                                    save_table=False)
253
                st.empty()
254
        # Display intro message only the first time
255
        # i.e. when there are no messages in the chat
256
        if not st.session_state.messages:
257
            with st.chat_message("assistant", avatar="🤖"):
258
                with st.spinner("Initializing the agent ..."):
259
                    config = {"configurable":
260
                                {"thread_id": st.session_state.unique_id}
261
                                }
262
                    # Update the agent state with the selected LLM model
263
                    current_state = app.get_state(config)
264
                    app.update_state(
265
                        config,
266
                        {"llm_model": streamlit_utils.get_base_chat_model(
267
                            st.session_state.llm_model),
268
                        "text_embedding_model": streamlit_utils.get_text_embedding_model(
269
                            st.session_state.text_embedding_model)}
270
                    )
271
                    intro_prompt = "Tell your name and about yourself. Always start with a greeting."
272
                    intro_prompt += " and tell about the tools you can run to perform analysis with short description."
273
                    intro_prompt += " We have provided starter questions (separately) outisde your response."
274
                    intro_prompt += " Do not provide any questions by yourself. Let the users know that they can"
275
                    intro_prompt += " simply click on the questions to execute them."
276
                    intro_prompt += " Let them know that they can check out the use cases"
277
                    intro_prompt += " and FAQs described in the link below. Be friendly and helpful."
278
                    intro_prompt += "\n"
279
                    intro_prompt += "Here is the link to the use cases: [Use Cases](https://virtualpatientengine.github.io/AIAgents4Pharma/talk2biomodels/cases/Case_1/)"
280
                    intro_prompt += "\n"
281
                    intro_prompt += "Here is the link to the FAQs: [FAQs](https://virtualpatientengine.github.io/AIAgents4Pharma/talk2biomodels/faq/)"
282
                    response = app.stream(
283
                                    {"messages": [HumanMessage(content=intro_prompt)]},
284
                                    config=config,
285
                                    stream_mode="messages"
286
                                )
287
                    st.write_stream(streamlit_utils.stream_response(response))
288
                    current_state = app.get_state(config)
289
                    # Add response to chat history
290
                    assistant_msg = ChatMessage(
291
                                        current_state.values["messages"][-1].content,
292
                                        role="assistant")
293
                    st.session_state.messages.append({
294
                                    "type": "message",
295
                                    "content": assistant_msg
296
                                })
297
                    st.empty()
298
        if len(st.session_state.messages) <= 1:
299
            for count, question in enumerate(streamlit_utils.sample_questions()):
300
                if st.button(f'Q{count+1}. {question}',
301
                             key=f'sample_question_{count+1}'):
302
                    # Trigger the question
303
                    prompt = question
304
                # Add button click to chat history
305
                st.session_state.messages.append({
306
                                "type": "button",
307
                                "question": question,
308
                                "content": f'Q{count+1}. {question}',
309
                                "key": f'sample_question_{count+1}'
310
                            })
311
312
        # When the user asks a question
313
        if prompt:
314
            # Create a key 'uploaded_file' to read the uploaded file
315
            if uploaded_sbml_file:
316
                st.session_state.sbml_file_path = uploaded_sbml_file.read().decode("utf-8")
317
318
            # Display user prompt
319
            prompt_msg = ChatMessage(prompt, role="user")
320
            st.session_state.messages.append(
321
                {
322
                    "type": "message",
323
                    "content": prompt_msg
324
                }
325
            )
326
            with st.chat_message("user", avatar="👩đŸģ‍đŸ’ģ"):
327
                st.markdown(prompt)
328
                st.empty()
329
330
            with st.chat_message("assistant", avatar="🤖"):
331
                # with st.spinner("Fetching response ..."):
332
                with st.spinner():
333
                    # Get chat history
334
                    history = [(m["content"].role, m["content"].content)
335
                                            for m in st.session_state.messages
336
                                            if m["type"] == "message"]
337
                    # Convert chat history to ChatMessage objects
338
                    chat_history = [
339
                        SystemMessage(content=m[1]) if m[0] == "system" else
340
                        HumanMessage(content=m[1]) if m[0] == "human" else
341
                        AIMessage(content=m[1])
342
                        for m in history
343
                    ]
344
345
                    streamlit_utils.get_response('T2B', None, app, st, prompt)
346
347
        if st.session_state.get("run_id"):
348
            feedback = streamlit_feedback(
349
                feedback_type="thumbs",
350
                optional_text_label="[Optional] Please provide an explanation",
351
                on_submit=streamlit_utils.submit_feedback,
352
                key=f"feedback_{st.session_state.run_id}"
353
            )