a b/app/frontend/streamlit_app.py
1
#!/usr/bin/env python3
2
3
'''
4
Talk2BioModels: Interactive BioModel Simulation Tool
5
'''
6
7
import os
8
import sys
9
import random
10
import streamlit as st
11
import pandas as pd
12
import plotly.express as px
13
sys.path.append('./')
14
from aiagents4pharma.talk2biomodels.tools.ask_question import AskQuestionTool
15
from aiagents4pharma.talk2biomodels.tools.simulate_model import SimulateModelTool
16
from aiagents4pharma.talk2biomodels.tools.model_description import ModelDescriptionTool
17
from aiagents4pharma.talk2biomodels.tools.search_models import SearchModelsTool
18
from aiagents4pharma.talk2biomodels.tools.custom_plotter import CustomPlotterTool
19
from aiagents4pharma.talk2biomodels.tools.fetch_parameters import FetchParametersTool
20
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
21
from aiagents4pharma.talk2biomodels.tools.get_annotation import GetAnnotationTool
22
from langchain.agents import create_tool_calling_agent, AgentExecutor
23
from langchain_openai import ChatOpenAI
24
from langchain_core.messages import ChatMessage
25
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
26
27
# Streamlit session-state keys shared with the tools: the most recently
# simulated model object and the most recent annotations dataframe.
ST_SYS_BIOMODEL_KEY = "last_model_object"
ST_SESSION_DF = "last_annotations_df"

st.set_page_config(page_title="Talk2BioModels", page_icon="🤖", layout="wide")
st.logo(image='./app/frontend/VPE.png', link="https://www.github.com/virtualpatientengine")

def _load_prompt(path: str) -> str:
    """Read and return the contents of a UTF-8 prompt file."""
    with open(path, 'r', encoding='utf-8') as file:
        return file.read()

# Define tools and their metadata
simulate_model = SimulateModelTool(st_session_key=ST_SYS_BIOMODEL_KEY)
ask_question = AskQuestionTool(st_session_key=ST_SYS_BIOMODEL_KEY)
ask_question.metadata = {
    "prompt": _load_prompt('./app/frontend/prompts/prompt_ask_question.txt')
}
# plot_figure = PlotImageTool(st_session_key=ST_SYS_BIOMODEL_KEY)
model_description = ModelDescriptionTool(st_session_key=ST_SYS_BIOMODEL_KEY)
model_description.metadata = {
    "prompt": _load_prompt('./app/frontend/prompts/prompt_model_description.txt')
}
search_models = SearchModelsTool()
custom_plotter = CustomPlotterTool(st_session_key=ST_SYS_BIOMODEL_KEY)
fetch_parameters = FetchParametersTool(st_session_key=ST_SYS_BIOMODEL_KEY)
get_annotation = GetAnnotationTool(st_session_key=ST_SYS_BIOMODEL_KEY,
                                   st_session_df=ST_SESSION_DF)

tools = [simulate_model,
        ask_question,
        #  plot_figure,
        custom_plotter,
        fetch_parameters,
        model_description,
        search_models,
        get_annotation]

# Create the chat prompt template for the main agent; the system prompt
# is loaded from disk, and chat_history is optional (first turn has none).
prompt = ChatPromptTemplate.from_messages([
        ("system", _load_prompt('./app/frontend/prompts/prompt_general.txt')),
        MessagesPlaceholder(variable_name='chat_history', optional=True),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
])

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Initialize the OpenAI model
llm = ChatOpenAI(temperature=0.0,
                model="gpt-4o-mini",
                streaming=True,
                api_key=os.getenv("OPENAI_API_KEY"))

# Create an agent
agent = create_tool_calling_agent(llm, tools, prompt)

# Create an agent executor; intermediate steps are returned so the UI can
# react to which tool was invoked (see the main column logic below).
agent_executor = AgentExecutor(agent=agent,
                               tools=tools,
                               verbose=True,
                               return_intermediate_steps=True)
94
95
def render_plotly(df_simulation_results: pd.DataFrame,
                  *,
                  title: str = "Concentration of parameters over time",
                  height: int = 500,
                  width: int = 600) -> px.line:
    """
    Visualize a wide-format simulation dataframe as a Plotly line chart.

    The dataframe must contain a 'Time' column; every other column is
    treated as a concentration series and plotted as its own line.

    Args:
        df_simulation_results: Wide-format dataframe of simulation results
            with a 'Time' column plus one column per parameter/species.
        title: Chart title (keyword-only; default matches historical output).
        height: Figure height in pixels.
        width: Figure width in pixels.

    Returns:
        The Plotly figure object produced by ``px.line``.
    """
    # Melt to long format so px.line can color one trace per parameter.
    long_df = df_simulation_results.melt(id_vars='Time',
                                         var_name='Parameters',
                                         value_name='Concentration')
    fig = px.line(long_df,
                  x='Time',
                  y='Concentration',
                  color='Parameters',
                  title=title,
                  height=height,
                  width=width)
    return fig
114
115
def get_random_spinner_text():
    """Return one spinner status message chosen uniformly at random."""
    options = (
        "Your request is being carefully prepared. one moment, please.",
        "Working on that for you now—thanks for your patience.",
        "Hold tight! I’m getting that ready for you.",
        "I’m on it! Just a moment, please.",
        "Running algorithms... your answer is on its way.",
        "Processing your request. Please hold on...",
        "One moment while I work on that for you...",
        "Fetching the details for you. This won’t take long.",
        "Sit back while I take care of this for you.",
    )
    return random.choice(options)
130
131
# Main layout of the app split into two columns:
# left = controls (LLM picker, file upload, chat input), right = chat history.
main_col1, main_col2 = st.columns([3, 7])
# First column
with main_col1:
    with st.container(border=True):
        # Title
        st.write("""
            <h3 style='margin: 0px; padding-bottom: 10px; font-weight: bold;'>
            🤖 Talk2BioModels
            </h3>
            """,
            unsafe_allow_html=True)

        # LLM panel
        llms = ["gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"]
        llm_option = st.selectbox(
            "Pick an LLM to power the agent",
            llms,
            index=0,
            key="st_selectbox_llm"
        )

        # Upload files
        uploaded_file = st.file_uploader(
            "Upload an XML/SBML file",
            accept_multiple_files=False,
            type=["xml", "sbml"],
            help='''Upload an XML/SBML file to simulate a biological model, \
                and ask questions about the simulation results.'''
            )

    with st.container(border=False, height=500):
        # Renamed from `prompt` to avoid shadowing the module-level
        # ChatPromptTemplate that is also named `prompt`.
        user_prompt = st.chat_input("Say something ...", key="st_chat_input")

# Second column
with main_col2:
    # Chat history panel
    with st.container(border=True, height=575):
        st.write("#### 💬 Chat History")

        # Replay the stored conversation. Each entry is a dict with a
        # "type" ("message" | "plotly" | "dataframe") and its "content".
        for count, message in enumerate(st.session_state.messages):
            if message["type"] == "message":
                with st.chat_message(message["content"].role,
                                     avatar="🤖"
                                     if message["content"].role != 'user'
                                     else "👩🏻‍💻"):
                    st.markdown(message["content"].content)
                    st.empty()
            elif message["type"] == "plotly":
                st.plotly_chart(render_plotly(message["content"]),
                                use_container_width = True,
                                key=f"plotly_{count}")
            elif message["type"] == "dataframe":
                st.dataframe(message["content"],
                            use_container_width = True,
                            key=f"dataframe_{count}")
        if user_prompt:
            # Make sure the session keys the tools read/write exist.
            if ST_SYS_BIOMODEL_KEY not in st.session_state:
                st.session_state[ST_SYS_BIOMODEL_KEY] = None

            if ST_SESSION_DF not in st.session_state:
                st.session_state[ST_SESSION_DF] = None

            # Create a key 'uploaded_file' to read the uploaded file
            # NOTE(review): despite the name, this stores the decoded file
            # *content*, not a filesystem path — confirm the tools expect that.
            if uploaded_file:
                st.session_state.sbml_file_path = uploaded_file.read().decode("utf-8")

            # Display user prompt
            prompt_msg = ChatMessage(user_prompt, role="user")
            st.session_state.messages.append(
                {
                    "type": "message",
                    "content": prompt_msg
                }
            )
            with st.chat_message("user", avatar="👩🏻‍💻"):
                st.markdown(user_prompt)
                st.empty()

            with st.chat_message("assistant", avatar="🤖"):
                with st.spinner(get_random_spinner_text()):
                    history = [(m["content"].role, m["content"].content)
                                        for m in st.session_state.messages
                                        if m["type"] == "message"]
                    # Map stored roles onto LangChain message types.
                    # BUG FIX: stored roles are "user"/"assistant", so the
                    # original check for "human" never matched and every user
                    # turn was replayed to the agent as an AIMessage.
                    chat_history = [
                        SystemMessage(content=content) if role == "system" else
                        HumanMessage(content=content) if role in ("human", "user") else
                        AIMessage(content=content)
                        for role, content in history
                    ]
                    # Call the agent
                    response = agent_executor.invoke({
                        "input": user_prompt,
                        "chat_history": chat_history
                    })

                    # Ensure response["output"] is a valid string
                    output_content = response.get("output", "")

                    # If output is a dictionary (like an error message), handle it properly
                    if isinstance(output_content, dict):
                        # Extract error message or default message
                        output_content = str(output_content.get('error', 'Unknown error occurred'))

                    # Add assistant response to chat history
                    assistant_msg = ChatMessage(content=output_content, role="assistant")
                    st.session_state.messages.append({
                        "type": "message",
                        "content": assistant_msg
                    })

                    # Display the response
                    st.markdown(output_content)
                    st.empty()
                    # Inspect the tool calls the agent made and surface their
                    # artifacts (dataframes/charts) in the chat history.
                    if "intermediate_steps" in response and len(response["intermediate_steps"]) > 0:
                        for r in response["intermediate_steps"]:
                            if r[0].tool == 'get_annotation':
                                # The tool stored its dataframe in session state.
                                annotations_df = st.session_state[ST_SESSION_DF]
                                # Display the DataFrame in Streamlit frontend
                                st.dataframe(annotations_df, use_container_width=True)
                                # Append the DataFrame to chat history (if necessary)
                                st.session_state.messages.append({
                                    "type": "dataframe",
                                    "content": annotations_df
                                })

                            elif r[0].tool == 'simulate_model':
                                model_obj = st.session_state[ST_SYS_BIOMODEL_KEY]
                                df_sim_results = model_obj.simulation_results
                                # Add data to the chat history
                                st.session_state.messages.append({
                                    "type": "dataframe",
                                    "content": df_sim_results
                                })
                                st.dataframe(df_sim_results, use_container_width=True)
                                # Add the plotly chart to the chat history
                                st.session_state.messages.append({
                                    "type": "plotly",
                                    "content": df_sim_results
                                })
                                # Display the plotly chart
                                st.plotly_chart(render_plotly(df_sim_results), use_container_width=True)

                            elif r[0].tool == 'custom_plotter':
                                model_obj = st.session_state[ST_SYS_BIOMODEL_KEY]
                                # Prepare df_subset for custom_simulation_results
                                df_subset = pd.DataFrame()
                                # Guard with .get(): attribute access raised if
                                # the tool never populated this key.
                                custom_headers = st.session_state.get(
                                    'custom_simulation_results', [])
                                if len(custom_headers) > 0:
                                    custom_headers = list(custom_headers)
                                    # Add Time column to the custom headers
                                    if 'Time' not in custom_headers:
                                        custom_headers = ['Time'] + custom_headers

                                    # Make df_subset with only the custom headers
                                    df_subset = model_obj.simulation_results[custom_headers]
                                    # Add data to the chat history
                                    st.session_state.messages.append({
                                        "type": "dataframe",
                                        "content": df_subset
                                    })
                                    st.dataframe(df_subset, use_container_width=True)
                                    # Add the plotly chart to the chat history
                                    st.session_state.messages.append({
                                        "type": "plotly",
                                        "content": df_subset
                                    })
                                    # Display the plotly chart
                                    st.plotly_chart(render_plotly(df_subset), use_container_width=True)
                    else:
                        # If intermediate_steps is empty, show a message
                        st.warning("No intermediate steps were found in the response.")
307