Switch to unified view

a b/aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py
1
'''
2
This is the agent file for the Talk2KnowledgeGraphs agent.
3
'''
4
5
import logging
6
from typing import Annotated
7
import hydra
8
from langchain_core.language_models.chat_models import BaseChatModel
9
from langgraph.checkpoint.memory import MemorySaver
10
from langgraph.graph import START, StateGraph
11
from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
12
from ..tools.subgraph_extraction import SubgraphExtractionTool
13
from ..tools.subgraph_summarization import SubgraphSummarizationTool
14
from ..tools.graphrag_reasoning import GraphRAGReasoningTool
15
from ..states.state_talk2knowledgegraphs import Talk2KnowledgeGraphs
16
17
# Initialize logger
18
logging.basicConfig(level=logging.INFO)
19
logger = logging.getLogger(__name__)
20
21
def get_app(uniq_id, llm_model: BaseChatModel):
22
    '''
23
    This function returns the langraph app.
24
    '''
25
    def agent_t2kg_node(state: Annotated[dict, InjectedState]):
26
        '''
27
        This function calls the model.
28
        '''
29
        logger.log(logging.INFO, "Calling t2kg_agent node with thread_id %s", uniq_id)
30
        response = model.invoke(state, {"configurable": {"thread_id": uniq_id}})
31
32
        return response
33
34
    # Load hydra configuration
35
    logger.log(logging.INFO, "Load Hydra configuration for Talk2KnowledgeGraphs agent.")
36
    with hydra.initialize(version_base=None, config_path="../configs"):
37
        cfg = hydra.compose(config_name='config',
38
                            overrides=['agents/t2kg_agent=default'])
39
        cfg = cfg.agents.t2kg_agent
40
41
    # Define the tools
42
    subgraph_extraction = SubgraphExtractionTool()
43
    subgraph_summarization = SubgraphSummarizationTool()
44
    graphrag_reasoning = GraphRAGReasoningTool()
45
    tools = ToolNode([
46
                    subgraph_extraction,
47
                    subgraph_summarization,
48
                    graphrag_reasoning,
49
                    ])
50
51
    # Create the agent
52
    model = create_react_agent(
53
                llm_model,
54
                tools=tools,
55
                state_schema=Talk2KnowledgeGraphs,
56
                prompt=cfg.state_modifier,
57
                version='v2',
58
                checkpointer=MemorySaver()
59
            )
60
61
    # Define a new graph
62
    workflow = StateGraph(Talk2KnowledgeGraphs)
63
64
    # Define the two nodes we will cycle between
65
    workflow.add_node("agent_t2kg", agent_t2kg_node)
66
67
    # Set the entrypoint as the first node
68
    # This means that this node is the first one called
69
    workflow.add_edge(START, "agent_t2kg")
70
71
    # Initialize memory to persist state between graph runs
72
    checkpointer = MemorySaver()
73
74
    # Finally, we compile it!
75
    # This compiles it into a LangChain Runnable,
76
    # meaning you can use it as you would any other runnable.
77
    # Note that we're (optionally) passing the memory
78
    # when compiling the graph
79
    app = workflow.compile(checkpointer=checkpointer,
80
                           name="T2KG_Agent")
81
    logger.log(logging.INFO,
82
               "Compiled the graph with thread_id %s and llm_model %s",
83
               uniq_id,
84
               llm_model)
85
86
    return app