Switch to side-by-side view

--- a
+++ b/aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py
@@ -0,0 +1,126 @@
+"""
+Tool for performing subgraph summarization.
+"""
+
+import logging
+from typing import Type, Annotated
+from pydantic import BaseModel, Field
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.tools import BaseTool
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import hydra
+
+# Initialize logger
+# NOTE: basicConfig is called at import time, so importing this module requests
+# INFO-level configuration of the root logger as a side effect.
+logging.basicConfig(level=logging.INFO)
+# Module-scoped logger used by the tool below.
+logger = logging.getLogger(__name__)
+
+
+class SubgraphSummarizationInput(BaseModel):
+    """
+    SubgraphSummarizationInput is a Pydantic model representing an input for
+    summarizing a given textualized subgraph.
+
+    Args:
+        tool_call_id: Tool call ID.
+        state: Injected state.
+        prompt: Prompt to interact with the backend.
+        extraction_name: Name assigned to the subgraph extraction process
+            when the subgraph_extraction tool is invoked.
+    """
+
+    # Annotated with InjectedToolCallId: marked as an injected argument.
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    # Annotated with InjectedState: marked as the injected graph state.
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    # The user's prompt that the summary should respond to.
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    # Used by the tool to select which extracted subgraph to summarize.
+    extraction_name: str = Field(
+        description="""Name assigned to the subgraph extraction process
+                                 when the subgraph_extraction tool is invoked."""
+    )
+
+
+class SubgraphSummarizationTool(BaseTool):
+    """
+    This tool performs subgraph summarization over textualized graph to highlight the most
+    important information in responding to user's prompt.
+    """
+
+    name: str = "subgraph_summarization"
+    description: str = """A tool to perform subgraph summarization over textualized graph
+                        for responding to user's follow-up prompt(s)."""
+    args_schema: Type[BaseModel] = SubgraphSummarizationInput
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        extraction_name: str,
+    ):
+        """
+        Run the subgraph summarization tool.
+
+        Args:
+            tool_call_id: The tool call ID.
+            state: The injected state.
+            prompt: The prompt to interact with the backend.
+            extraction_name: The name assigned to the subgraph extraction process.
+        """
+        logger.log(
+            logging.INFO, "Invoking subgraph_summarization tool for %s", extraction_name
+        )
+
+        # Load hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/subgraph_summarization=default"]
+            )
+            cfg = cfg.tools.subgraph_summarization
+
+        # Load the extracted graph
+        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
+        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)
+
+        # Prepare prompt template
+        prompt_template = ChatPromptTemplate.from_messages(
+            [
+                ("system", cfg.prompt_subgraph_summarization),
+                ("human", "{input}"),
+            ]
+        )
+
+        # Prepare chain
+        chain = prompt_template | state["llm_model"] | StrOutputParser()
+
+        # Return the subgraph and textualized graph as JSON response
+        response = chain.invoke(
+            {
+                "input": prompt,
+                "textualized_subgraph": extracted_graph[extraction_name]["graph_text"],
+            }
+        )
+
+        # Store the response as graph_summary in the extracted graph
+        for key, value in extracted_graph.items():
+            if key == extraction_name:
+                value["graph_summary"] = response
+
+        # Prepare the dictionary of updated state
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "dic_extracted_graph": list(extracted_graph.values()),
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        # Return the updated state of the tool
+        return Command(
+            update=dic_updated_state_for_model
+            | {
+                # update the message history
+                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
+            }
+        )