# app/frontend/streamlit_app_talk2scholars.py
1
#!/usr/bin/env python3
2
3
"""
4
Talk2Scholars: A Streamlit app for the Talk2Scholars graph.
5
"""
6
7
import os
8
import random
9
import sys
10
11
import hydra
12
import pandas as pd
13
import streamlit as st
14
from langchain.callbacks.tracers import LangChainTracer
15
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage, SystemMessage
16
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
17
from langchain_core.tracers.context import collect_runs
18
from langchain_openai import ChatOpenAI
19
from langsmith import Client
20
from langchain.callbacks.tracers import LangChainTracer
21
from streamlit_feedback import streamlit_feedback
22
from utils import streamlit_utils
23
24
sys.path.append("./")
25
# import get_app from main_agent
26
from aiagents4pharma.talk2scholars.agents.main_agent import get_app
27
28
# Hydra configuration bootstrap.
# GlobalHydra is a process-wide singleton; clear it first so Streamlit
# reruns do not fail with an "already initialized" error.
hydra.core.global_hydra.GlobalHydra.instance().clear()
if "config" in st.session_state:
    # Subsequent reruns: reuse the config cached in session state.
    cfg = st.session_state.config
else:
    # First run: compose the frontend config and cache it for later reruns.
    with hydra.initialize(
        version_base=None,
        config_path="../../aiagents4pharma/talk2scholars/configs",
    ):
        composed = hydra.compose(
            config_name="config", overrides=["app/frontend=default"]
        )
        cfg = composed.app.frontend
        st.session_state.config = cfg
# Configure the browser tab (title/icon/layout) from the frontend config.
page_cfg = cfg.page
st.set_page_config(
    page_title=page_cfg.title,
    page_icon=page_cfg.icon,
    layout=page_cfg.layout,
)

# Show the project logo, linked to the organization's GitHub page.
st.logo(
    image="docs/assets/VPE.png",
    size="large",
    link="https://github.com/VirtualPatientEngine",
)
# Stop the app early when the required API key is absent.
# NOTE(review): cfg.api_keys.openai_key appears to hold the *name* of the
# environment variable (e.g. "OPENAI_API_KEY") — confirm against the config.
key_name = cfg.api_keys.openai_key
if key_name not in os.environ:
    st.error(
        "Please set the OPENAI_API_KEY "
        "environment variables in the terminal where you run "
        "the app. For more information, please refer to our "
        "[documentation](https://virtualpatientengine.github.io/AIAgents4Pharma/#option-2-git)."
    )
    st.stop()
# Chat prompt template seeding the agent conversation.
# Renamed from ``prompt``: that name was immediately shadowed by the value
# returned from ``st.chat_input`` further down, which made this template
# unreachable under its original name and the code confusing to read.
chat_prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", "Welcome to Talk2Scholars!"),
        MessagesPlaceholder(variable_name="chat_history", optional=True),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)
# ---- Session-state bootstrap -------------------------------------------

# Chat transcript; replayed in the history panel on every rerun.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# LangSmith project name, made unique-ish per session with a random suffix.
if "project_name" not in st.session_state:
    st.session_state["project_name"] = "Talk2Scholars-" + str(
        random.randint(1000, 9999)
    )

# Id of the most recent LangSmith traced run (None until one is traced).
if "run_id" not in st.session_state:
    st.session_state["run_id"] = None
# ---- Agent graph initialization ----------------------------------------

# Per-session id used as the LangGraph ``thread_id``.
if "unique_id" not in st.session_state:
    st.session_state.unique_id = random.randint(1, 1000)

if "app" not in st.session_state:
    if "llm_model" in st.session_state:
        # A model was already selected in the sidebar: honor that choice.
        # (Removed a leftover debug ``print`` of the selected model.)
        st.session_state.app = get_app(
            st.session_state.unique_id,
            llm_model=streamlit_utils.get_base_chat_model(st.session_state.llm_model),
        )
    else:
        # No selection yet: fall back to the default OpenAI model.
        st.session_state.app = get_app(
            st.session_state.unique_id,
            llm_model=ChatOpenAI(model="gpt-4o-mini", temperature=0),
        )

# The compiled LangGraph app for this session.
app = st.session_state.app
def _submit_feedback(user_response):
    """
    Forward the user's thumbs up/down feedback on the last traced run
    to LangSmith, then acknowledge it in the UI.
    """
    # Map the emoji rating to a binary score expected by LangSmith.
    score = 1 if user_response["score"] == "👍" else 0
    Client().create_feedback(
        st.session_state.run_id,
        key="feedback",
        score=score,
        comment=user_response["text"],
    )
    st.info("Your feedback is on its way to the developers. Thank you!", icon="🚀")
@st.fragment
def process_pdf_upload():
    """
    Render the PDF uploader and, when a file is provided, spill it to a
    temporary file and push its path into the agent state as ``pdf_data``.
    """
    pdf_file = st.file_uploader(
        "Upload an article",
        help="Upload an article in PDF format.",
        type=["pdf"],
        key="pdf_upload",
    )

    if not pdf_file:
        return

    import tempfile

    # delete=False keeps the temp file on disk after the handle closes,
    # so the agent can read it later via its path.
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(pdf_file.read())

    st.session_state.pdf_data = {
        "pdf_object": tmp.name,  # path of the temp copy of the uploaded PDF
        "pdf_url": tmp.name,  # placeholder for URL if needed later
        "arxiv_id": None,  # placeholder for an arXiv id if applicable
    }
    # Push the uploaded document into the agent's thread state.
    config = {"configurable": {"thread_id": st.session_state.unique_id}}
    app.update_state(config, {"pdf_data": st.session_state.pdf_data})
# Main layout of the app split into two columns
main_col1, main_col2 = st.columns([3, 7])
# First column: title, model pickers, PDF upload, and the chat input box.
with main_col1:
    with st.container(border=True):
        # Title
        st.write(
            """
            <h3 style='margin: 0px; padding-bottom: 10px; font-weight: bold;'>
            🤖 Talk2Scholars
            </h3>
            """,
            unsafe_allow_html=True,
        )

        # LLM model panel
        st.selectbox(
            "Pick an LLM to power the agent",
            list(cfg.llms.available_models),
            index=0,
            key="llm_model",
            on_change=streamlit_utils.update_llm_model,
            help="Used for tool calling and generating responses.",
        )

        # Text embedding model panel
        text_models = [
            "OpenAI/text-embedding-ada-002",
            "NVIDIA/llama-3.2-nv-embedqa-1b-v2",
        ]
        st.selectbox(
            "Pick a text embedding model",
            text_models,
            index=0,
            key="text_embedding_model",
            on_change=streamlit_utils.update_text_embedding_model,
            kwargs={"app": app},
            # Fixed user-facing typo: "Retrival" -> "Retrieval".
            help="Used for Retrieval Augmented Generation (RAG)",
        )

        # PDF upload widget (runs as a Streamlit fragment).
        process_pdf_upload()

    with st.container(border=False, height=500):
        # ``prompt`` is the user's typed question (None if nothing entered).
        prompt = st.chat_input("Say something ...", key="st_chat_input")
# Second column: chat history, agent responses, and feedback widget.
with main_col2:
    # Chat history panel
    with st.container(border=True, height=775):
        st.write("#### 💬 Chat History")

        # Replay the stored transcript: plain messages, starter-question
        # buttons, and dataframes produced by earlier tool calls.
        for count, message in enumerate(st.session_state.messages):
            if message["type"] == "message":
                with st.chat_message(
                    message["content"].role,
                    avatar="🤖" if message["content"].role != "user" else "👩🏻‍💻",
                ):
                    st.markdown(message["content"].content)
                    st.empty()
            elif message["type"] == "button":
                if st.button(message["content"], key=message["key"]):
                    # Re-trigger the canned question as if the user typed it.
                    prompt = message["question"]
                    st.empty()
            elif message["type"] == "dataframe":
                if "tool_name" in message:
                    if message["tool_name"] in [
                        "display_results",
                    ]:
                        df_papers = message["content"]
                        st.dataframe(
                            df_papers,
                            use_container_width=True,
                            key=message["key"],
                            hide_index=True,
                            column_config={
                                "URL": st.column_config.LinkColumn(
                                    display_text="Open",
                                ),
                            },
                        )
                st.empty()

        # Display intro message only the first time,
        # i.e. when there are no messages in the chat yet.
        if not st.session_state.messages:
            with st.chat_message("assistant", avatar="🤖"):
                with st.spinner("Initializing the agent ..."):
                    config = {"configurable": {"thread_id": st.session_state.unique_id}}
                    # Update the agent state with the selected LLM model.
                    app.update_state(
                        config,
                        {
                            "llm_model": streamlit_utils.get_base_chat_model(
                                st.session_state.llm_model
                            )
                        },
                    )
                    # Intro instructions for the agent.
                    # (Fixed typos from the original: "ther" -> "their",
                    # "outisde" -> "outside".)
                    intro_prompt = "Greet and tell your name and about yourself."
                    intro_prompt += " Also, tell about the agents you can access and their short description."
                    intro_prompt += " We have provided starter questions (separately) outside your response."
                    intro_prompt += " Do not provide any questions by yourself. Let the users know that they can"
                    intro_prompt += " simply click on the questions to execute them."
                    response = app.stream(
                        {"messages": [HumanMessage(content=intro_prompt)]},
                        config=config,
                        stream_mode="messages",
                    )
                    st.write_stream(streamlit_utils.stream_response(response))
                    current_state = app.get_state(config)
                    # Add the greeting to the chat history.
                    assistant_msg = ChatMessage(
                        current_state.values["messages"][-1].content, role="assistant"
                    )
                    st.session_state.messages.append(
                        {"type": "message", "content": assistant_msg}
                    )
                    st.empty()

        # Right after the intro (only the greeting stored so far), offer
        # clickable starter questions and record them in the transcript.
        if len(st.session_state.messages) <= 1:
            for count, question in enumerate(streamlit_utils.sample_questions_t2s()):
                if st.button(
                    f"Q{count+1}. {question}", key=f"sample_question_{count+1}"
                ):
                    # Trigger the question
                    prompt = question
                # Add button click to chat history
                st.session_state.messages.append(
                    {
                        "type": "button",
                        "question": question,
                        "content": f"Q{count+1}. {question}",
                        "key": f"sample_question_{count+1}",
                    }
                )

        # When the user asks a question (typed or via a starter button).
        if prompt:
            # Display user prompt
            prompt_msg = ChatMessage(prompt, role="user")
            st.session_state.messages.append({"type": "message", "content": prompt_msg})
            with st.chat_message("user", avatar="👩🏻‍💻"):
                st.markdown(prompt)
                st.empty()

            with st.chat_message("assistant", avatar="🤖"):
                with st.spinner():
                    # Create config for the agent.
                    config = {"configurable": {"thread_id": st.session_state.unique_id}}
                    # Push the currently selected LLM and embedding models
                    # into the agent state before answering.
                    app.update_state(
                        config,
                        {
                            "llm_model": streamlit_utils.get_base_chat_model(
                                st.session_state.llm_model
                            ),
                            "text_embedding_model": streamlit_utils.get_text_embedding_model(
                                st.session_state.text_embedding_model
                            ),
                        },
                    )
                    current_state = app.get_state(config)
                    # Bug fix: "pdf_data" is absent from the state until a PDF
                    # has been uploaded; the original unguarded indexing raised
                    # a KeyError on the first question without an upload.
                    pdf_data = current_state.values.get("pdf_data") or {}
                    print("PDF_DATA", len(pdf_data))

                    # Delegate streaming, rendering, and history updates to
                    # the shared helper.
                    streamlit_utils.get_response("T2S", None, app, st, prompt)

        # Collect feedback and display the thumbs feedback widget once a
        # traced run id exists.
        if st.session_state.get("run_id"):
            feedback = streamlit_feedback(
                feedback_type="thumbs",
                optional_text_label="[Optional] Please provide an explanation",
                on_submit=_submit_feedback,
                key=f"feedback_{st.session_state.run_id}",
            )