|
a |
|
b/aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py |
|
|
1 |
''' |
|
|
2 |
This is the agent file for the Talk2KnowledgeGraphs agent. |
|
|
3 |
''' |
|
|
4 |
|
|
|
5 |
import logging |
|
|
6 |
from typing import Annotated |
|
|
7 |
import hydra |
|
|
8 |
from langchain_core.language_models.chat_models import BaseChatModel |
|
|
9 |
from langgraph.checkpoint.memory import MemorySaver |
|
|
10 |
from langgraph.graph import START, StateGraph |
|
|
11 |
from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState |
|
|
12 |
from ..tools.subgraph_extraction import SubgraphExtractionTool |
|
|
13 |
from ..tools.subgraph_summarization import SubgraphSummarizationTool |
|
|
14 |
from ..tools.graphrag_reasoning import GraphRAGReasoningTool |
|
|
15 |
from ..states.state_talk2knowledgegraphs import Talk2KnowledgeGraphs |
|
|
16 |
|
|
|
17 |
# Initialize logger |
|
|
18 |
logging.basicConfig(level=logging.INFO) |
|
|
19 |
logger = logging.getLogger(__name__) |
|
|
20 |
|
|
|
21 |
def get_app(uniq_id, llm_model: BaseChatModel):
    '''
    Build and return the compiled LangGraph app for the Talk2KnowledgeGraphs agent.

    Args:
        uniq_id: Unique thread identifier, used as the checkpointer thread_id
            and in log messages.
        llm_model (BaseChatModel): Chat model that backs the ReAct agent.

    Returns:
        The compiled LangGraph runnable, named "T2KG_Agent".
    '''
    def agent_t2kg_node(state: Annotated[dict, InjectedState]):
        '''
        Invoke the ReAct agent model with the current graph state.
        '''
        logger.info("Calling t2kg_agent node with thread_id %s", uniq_id)
        response = model.invoke(state, {"configurable": {"thread_id": uniq_id}})
        return response

    # Load hydra configuration (supplies the agent's state_modifier prompt)
    logger.info("Load Hydra configuration for Talk2KnowledgeGraphs agent.")
    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg = hydra.compose(config_name='config',
                            overrides=['agents/t2kg_agent=default'])
        cfg = cfg.agents.t2kg_agent

    # Define the tools available to the agent
    subgraph_extraction = SubgraphExtractionTool()
    subgraph_summarization = SubgraphSummarizationTool()
    graphrag_reasoning = GraphRAGReasoningTool()
    tools = ToolNode([
        subgraph_extraction,
        subgraph_summarization,
        graphrag_reasoning,
    ])

    # Create the ReAct agent.
    # NOTE(review): this inner agent receives its own MemorySaver even though
    # the outer workflow below is compiled with a separate checkpointer; the
    # inner one looks redundant — confirm it is intentional before removing.
    model = create_react_agent(
        llm_model,
        tools=tools,
        state_schema=Talk2KnowledgeGraphs,
        prompt=cfg.state_modifier,
        version='v2',
        checkpointer=MemorySaver()
    )

    # Define a new graph with a single agent node, entered from START
    workflow = StateGraph(Talk2KnowledgeGraphs)
    workflow.add_node("agent_t2kg", agent_t2kg_node)

    # Set the entrypoint: this node is the first one called
    workflow.add_edge(START, "agent_t2kg")

    # Initialize memory to persist state between graph runs
    checkpointer = MemorySaver()

    # Compile into a LangChain Runnable, (optionally) passing the memory
    # so it can be used like any other runnable
    app = workflow.compile(checkpointer=checkpointer,
                           name="T2KG_Agent")
    logger.info("Compiled the graph with thread_id %s and llm_model %s",
                uniq_id,
                llm_model)

    return app