# app/frontend/streamlit_app_talk2knowledgegraphs.py
#!/usr/bin/env python3

"""
Talk2KnowledgeGraphs: A Streamlit app for the Talk2KnowledgeGraphs graph.
"""

import os
import random
import sys

import hydra
import pandas as pd
import streamlit as st
from langchain.callbacks.tracers import LangChainTracer
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage, SystemMessage
from langchain_core.tracers.context import collect_runs
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from streamlit_feedback import streamlit_feedback

from utils import streamlit_utils

# The project package is imported from the repository root, so make sure
# the current working directory is on sys.path before importing it.
sys.path.append("./")
from aiagents4pharma.talk2knowledgegraphs.agents.t2kg_agent import get_app
# from talk2knowledgegraphs.agents.t2kg_agent import get_app
# Configure the Streamlit page; must run before any other Streamlit call.
st.set_page_config(
    page_title="Talk2KnowledgeGraphs",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# Initialize configuration: load the Hydra config once per session and cache
# it in st.session_state so Streamlit reruns reuse the same settings.
hydra.core.global_hydra.GlobalHydra.instance().clear()
if "config" not in st.session_state:
    # Load Hydra configuration relative to this file's location.
    with hydra.initialize(
        version_base=None,
        config_path="../../aiagents4pharma/talk2knowledgegraphs/configs",
    ):
        cfg = hydra.compose(config_name="config", overrides=["app/frontend=default"])
        # Only the frontend section of the composed config is used here.
        cfg = cfg.app.frontend
        st.session_state.config = cfg
else:
    cfg = st.session_state.config


# st.logo(
#     image='docs/VPE.png',
#     size='large',
#     link='https://github.com/VirtualPatientEngine'
# )
# Abort early with a visible error if the OpenAI API key is not configured;
# the OpenAI-backed models below cannot run without it.
if "OPENAI_API_KEY" not in os.environ:
    st.error(
        "Please set the OPENAI_API_KEY environment \
        variable in the terminal where you run the app."
    )
    st.stop()
# Initialize current user
if "current_user" not in st.session_state:
    st.session_state.current_user = cfg.default_user

# Initialize chat history (list of {"type": ..., "content": ...} dicts)
if "messages" not in st.session_state:
    st.session_state.messages = []

# Initialize session state for SBML file uploader
# if "sbml_key" not in st.session_state:
#     st.session_state.sbml_key = 0

# Initialize session state for pre-clinical data package uploader
if "data_package_key" not in st.session_state:
    st.session_state.data_package_key = 0

# Initialize session state for patient gene expression data uploader
if "endotype_key" not in st.session_state:
    st.session_state.endotype_key = 0

# Initialize session state for uploaded files
if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = []

    # Make directories if not exists
    os.makedirs(cfg.upload_data_dir, exist_ok=True)

# Initialize project_name for Langsmith.
# NOTE(review): random 4-digit ids can collide across sessions; a uuid
# would be safer if per-session uniqueness matters — confirm.
if "project_name" not in st.session_state:
    # st.session_state.project_name = str(st.session_state.user_name) + '@' + str(uuid.uuid4())
    st.session_state.project_name = "T2KG-" + str(random.randint(1000, 9999))

# Initialize run_id for Langsmith (set after each agent invocation)
if "run_id" not in st.session_state:
    st.session_state.run_id = None

# Initialize the per-session thread id used by the graph agent
if "unique_id" not in st.session_state:
    st.session_state.unique_id = random.randint(1, 1000)

# Initialize the LLM model: default to the first configured model
# (OpenAI models are listed before Ollama models).
if "llm_model" not in st.session_state:
    st.session_state.llm_model = tuple(cfg.openai_llms + cfg.ollama_llms)[0]
# Initialize the agent app with the default LLM model for the first time.
# Whether the model name appears in cfg.openai_llms decides the backend.
if "app" not in st.session_state:
    if st.session_state.llm_model in cfg.openai_llms:
        print("Using OpenAI LLM model")
        st.session_state.app = get_app(
            st.session_state.unique_id,
            llm_model=ChatOpenAI(
                model=st.session_state.llm_model, temperature=cfg.temperature
            ),
        )
    else:
        print("Using Ollama LLM model")
        st.session_state.app = get_app(
            st.session_state.unique_id,
            llm_model=ChatOllama(
                model=st.session_state.llm_model, temperature=cfg.temperature
            ),
        )
# Subgraph extraction settings: seed the Top-K sliders from the config.
if "topk_nodes" not in st.session_state:
    st.session_state.topk_nodes = cfg.reasoning_subgraph_topk_nodes
    st.session_state.topk_edges = cfg.reasoning_subgraph_topk_edges

# Get the app
app = st.session_state.app

# Apply custom CSS
streamlit_utils.apply_css()
# Sidebar: sliders controlling how many nodes/edges the reasoning
# subgraph extraction keeps.
with st.sidebar:
    st.markdown("**⚙️ Subgraph Extraction Settings**")
    topk_nodes = st.slider(
        "Top-K (Nodes)",
        cfg.reasoning_subgraph_topk_nodes_min,
        cfg.reasoning_subgraph_topk_nodes_max,
        st.session_state.topk_nodes,
        key="st_slider_topk_nodes",
    )
    st.session_state.topk_nodes = topk_nodes
    # NOTE(review): this slider reuses the *nodes* min/max config keys;
    # confirm whether dedicated edge bounds exist in the config and
    # switch to them if so.
    topk_edges = st.slider(
        "Top-K (Edges)",
        cfg.reasoning_subgraph_topk_nodes_min,
        cfg.reasoning_subgraph_topk_nodes_max,
        st.session_state.topk_edges,
        key="st_slider_topk_edges",
    )
    st.session_state.topk_edges = topk_edges
# Main layout of the app split into two columns
main_col1, main_col2 = st.columns([3, 7])
# First column: title, LLM picker, file uploads, and the chat input box.
with main_col1:
    with st.container(border=True):
        # Title
        st.write(
            """
            <h3 style='margin: 0px; padding-bottom: 10px; font-weight: bold;'>
            🤖 Talk2KnowledgeGraphs
            </h3>
            """,
            unsafe_allow_html=True,
        )

        # LLM panel (Only at the front-end for now)
        llms = tuple(cfg.openai_llms + cfg.ollama_llms)
        st.selectbox(
            "Pick an LLM to power the agent",
            llms,
            index=0,
            key="llm_model",
            on_change=streamlit_utils.update_llm_model,
        )

        # Upload files
        streamlit_utils.get_uploaded_files(cfg)

        # Help text
        # st.button("Know more ↗",
        #           on_click=streamlit_utils.help_button,
        #           use_container_width=False)

    with st.container(border=False, height=500):
        # prompt is None until the user submits a message.
        prompt = st.chat_input("Say something ...", key="st_chat_input")
# Second column: the chat history panel.
with main_col2:
    # Chat history panel
    with st.container(border=True, height=575):
        st.write("#### 💬 Chat History")
        # Replay the persisted chat history. Each entry is a dict whose
        # "type" tag selects how it is rendered.
        for message in st.session_state.messages:
            if message["type"] == "message":
                with st.chat_message(
                    message["content"].role,
                    avatar="🤖" if message["content"].role != "user" else "👩🏻‍💻",
                ):
                    st.markdown(message["content"].content)
                    st.empty()
            elif message["type"] == "plotly":
                streamlit_utils.render_plotly(
                    message["content"],
                    key=message["key"],
                    title=message["title"],
                    save_chart=False,
                )
                st.empty()
            elif message["type"] == "toggle":
                streamlit_utils.render_toggle(
                    key=message["key"],
                    toggle_text=message["content"],
                    toggle_state=message["toggle_state"],
                    save_toggle=False,
                )
                st.empty()
            elif message["type"] == "dataframe":
                streamlit_utils.render_table(
                    message["content"],
                    key=message["key"],
                    save_table=False,
                )
                st.empty()
            elif message["type"] == "graph":
                streamlit_utils.render_graph(
                    message["content"], key=message["key"], save_graph=False
                )
                st.empty()
        # When the user asks a question
        if prompt:
            # Persist and display the user prompt
            prompt_msg = ChatMessage(prompt, role="user")
            st.session_state.messages.append({"type": "message", "content": prompt_msg})
            with st.chat_message("user", avatar="👩🏻‍💻"):
                st.markdown(prompt)
                st.empty()
            # Auxiliary visualization-related variables
            graphs_visuals = []
            with st.chat_message("assistant", avatar="🤖"):
                with st.spinner():
                    # Get chat history as (role, content) tuples
                    history = [
                        (m["content"].role, m["content"].content)
                        for m in st.session_state.messages
                        if m["type"] == "message"
                    ]
                    # Convert chat history to LangChain message objects.
                    # NOTE(review): chat_history is built but never used
                    # below — history presumably reaches the agent through
                    # its checkpointer; confirm and remove if dead.
                    chat_history = [
                        SystemMessage(content=m[1])
                        if m[0] == "system"
                        else HumanMessage(content=m[1])
                        if m[0] == "human"
                        else AIMessage(content=m[1])
                        for m in history
                    ]

                    # Prepare LLM and embedding model for updating the agent
                    if st.session_state.llm_model in cfg.openai_llms:
                        llm_model = ChatOpenAI(
                            model=st.session_state.llm_model,
                            temperature=cfg.temperature,
                        )
                        emb_model = OpenAIEmbeddings(model=cfg.openai_embeddings[0])
                    else:
                        llm_model = ChatOllama(
                            model=st.session_state.llm_model,
                            temperature=cfg.temperature,
                        )
                        emb_model = OllamaEmbeddings(model=cfg.ollama_embeddings[0])

                    # Create the agent config and push the current UI
                    # settings into the agent's state before invoking it.
                    config = {"configurable": {"thread_id": st.session_state.unique_id}}
                    app.update_state(
                        config,
                        {
                            "llm_model": llm_model,
                            "embedding_model": emb_model,
                            "uploaded_files": st.session_state.uploaded_files,
                            "topk_nodes": st.session_state.topk_nodes,
                            "topk_edges": st.session_state.topk_edges,
                            "dic_source_graph": [
                                {
                                    "name": st.session_state.config["kg_name"],
                                    "kg_pyg_path": st.session_state.config["kg_pyg_path"],
                                    "kg_text_path": st.session_state.config["kg_text_path"],
                                }
                            ],
                        },
                    )
                    # Invoke the agent inside collect_runs so the LangSmith
                    # run id is available for the feedback widget.
                    # (Removed a dead ERROR_FLAG local and two get_state
                    # calls whose results were never read.)
                    with collect_runs() as cb:
                        # Add Langsmith tracer
                        tracer = LangChainTracer(
                            project_name=st.session_state.project_name
                        )
                        # Get response from the agent
                        response = app.invoke(
                            {"messages": [HumanMessage(content=prompt)]},
                            config=config | {"callbacks": [tracer]},
                        )
                        # Remember the run id so feedback can be attached
                        # to this specific run.
                        st.session_state.run_id = cb.traced_runs[-1].id

                    # Add response to chat history
                    assistant_msg = ChatMessage(
                        response["messages"][-1].content, role="assistant"
                    )
                    st.session_state.messages.append(
                        {"type": "message", "content": assistant_msg}
                    )
                    # Display the response in the chat
                    st.markdown(response["messages"][-1].content)
                    st.empty()
                    # Get the current state of the graph to inspect the
                    # tool messages produced since the last user turn.
                    current_state = app.get_state(config)

                    # Walk the messages newest-first until the last
                    # HumanMessage, i.e. the last message from the user.
                    # Everything in between is the agent's work for this
                    # turn (tool calls and AI messages).
                    reversed_messages = current_state.values["messages"][::-1]
                    for msg in reversed_messages:
                        # Stop at the last message from the user
                        if isinstance(msg, HumanMessage):
                            break
                        # Skip agent messages; an agent may make multiple
                        # tool calls before the final response to the user.
                        if isinstance(msg, AIMessage):
                            continue
                        # Remaining messages are ToolMessages, which may
                        # carry visuals to display. Skip errored tool calls.
                        if msg.status == "error":
                            continue

                        # Unique id identifying this tool call:
                        # <tool name>_<tool call id>_<run id>
                        uniq_msg_id = (
                            msg.name
                            + "_"
                            + msg.tool_call_id
                            + "_"
                            + str(st.session_state.run_id)
                        )
                        if msg.name in ["subgraph_extraction"]:
                            print(
                                "-",
                                len(current_state.values["dic_extracted_graph"]),
                                "subgraph_extraction",
                            )
                            # Add the graph into the visuals list.
                            # Check for emptiness BEFORE indexing [-1]; the
                            # original indexed first, which raises
                            # IndexError on an empty list.
                            if current_state.values["dic_extracted_graph"]:
                                latest_graph = current_state.values[
                                    "dic_extracted_graph"
                                ][-1]
                                graphs_visuals.append(
                                    {
                                        "content": latest_graph["graph_dict"],
                                        "key": "subgraph_" + uniq_msg_id,
                                    }
                                )
            # Visualize (and persist) any graphs extracted in this turn;
            # an empty list simply skips the loop.
            for graph in graphs_visuals:
                streamlit_utils.render_graph(
                    graph_dict=graph["content"], key=graph["key"], save_graph=True
                )
        # Collect thumbs up/down feedback for the latest run; the submit
        # callback forwards it (keyed by run_id) to LangSmith.
        if st.session_state.get("run_id"):
            feedback = streamlit_feedback(
                feedback_type="thumbs",
                optional_text_label="[Optional] Please provide an explanation",
                on_submit=streamlit_utils.submit_feedback,
                key=f"feedback_{st.session_state.run_id}",
            )