# Source: aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py
# (127 lines, 4.4 kB) — viewer banner and line-number gutter removed.
"""
Tool for performing subgraph summarization.
"""
import logging
from typing import Type, Annotated
from pydantic import BaseModel, Field
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
from langchain_core.tools import BaseTool
from langgraph.types import Command
from langgraph.prebuilt import InjectedState
import hydra
# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SubgraphSummarizationInput(BaseModel):
    """
    Input schema for the subgraph summarization tool.

    Fields:
        tool_call_id: identifier of the originating tool call (injected).
        state: the agent state dictionary (injected).
        prompt: the user's prompt to summarize the subgraph against.
        extraction_name: name of the previously extracted subgraph to summarize.
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
        description="Tool call ID."
    )
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
        when the subgraph_extraction tool is invoked."""
    )
class SubgraphSummarizationTool(BaseTool):
    """
    Tool that summarizes a previously extracted, textualized subgraph,
    highlighting the information most relevant to the user's prompt.

    The summary is produced by an LLM chain (system prompt loaded from the
    hydra config, user prompt as the human message), stored back onto the
    matching entry of ``state["dic_extracted_graph"]`` under
    ``graph_summary``, and also returned to the conversation as a
    ToolMessage.
    """
    name: str = "subgraph_summarization"
    description: str = """A tool to perform subgraph summarization over textualized graph
    for responding to user's follow-up prompt(s)."""
    args_schema: Type[BaseModel] = SubgraphSummarizationInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ) -> Command:
        """
        Run the subgraph summarization tool.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state. Must contain ``dic_extracted_graph``
                (a list of dicts, each with at least ``name`` and
                ``graph_text`` keys) and ``llm_model`` (a LangChain-compatible
                chat model).
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction
                process whose textualized graph should be summarized.

        Returns:
            Command: a state update carrying the refreshed
            ``dic_extracted_graph`` (now including ``graph_summary`` on the
            matching entry) and a ToolMessage with the summary text.

        Raises:
            KeyError: if ``extraction_name`` does not match any entry in
                ``state["dic_extracted_graph"]``.
        """
        # logger.info is the idiomatic form of logger.log(logging.INFO, ...)
        logger.info("Invoking subgraph_summarization tool for %s", extraction_name)

        # Load the tool-specific hydra configuration.
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/subgraph_summarization=default"]
            )
            cfg = cfg.tools.subgraph_summarization

        # Index the extracted graphs by name for direct lookup.
        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}

        # Prepare prompt template
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", cfg.prompt_subgraph_summarization),
                ("human", "{input}"),
            ]
        )

        # Prepare chain
        chain = prompt_template | state["llm_model"] | StrOutputParser()

        # Summarize the textualized subgraph (raises KeyError if
        # extraction_name is unknown, same as before).
        response = chain.invoke(
            {
                "input": prompt,
                "textualized_subgraph": extracted_graph[extraction_name]["graph_text"],
            }
        )

        # Store the response as graph_summary directly on the matching entry;
        # dict keys are unique, so no scan over all entries is needed.
        extracted_graph[extraction_name]["graph_summary"] = response

        # Return the updated state of the tool. dic_extracted_graph is
        # necessarily non-empty here (the lookup above would have raised),
        # so no emptiness filtering is required.
        return Command(
            update={
                "dic_extracted_graph": list(extracted_graph.values()),
                # update the message history
                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)],
            }
        )