Switch to side-by-side view

--- a
+++ b/aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py
@@ -0,0 +1,143 @@
+"""
+Tool for performing Graph RAG reasoning.
+"""
+
+import logging
+from typing import Type, Annotated
+from pydantic import BaseModel, Field
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.tools import BaseTool
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import hydra
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class GraphRAGReasoningInput(BaseModel):
+    """
+    GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.
+
+    Args:
+        state: Injected state.
+        prompt: Prompt to interact with the backend.
+        extraction_name: Name assigned to the subgraph extraction process
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    extraction_name: str = Field(
+        description="""Name assigned to the subgraph extraction process
+                    when the subgraph_extraction tool is invoked."""
+    )
+
+
+class GraphRAGReasoningTool(BaseTool):
+    """
+    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG) approach
+    over user's request by considering textualized subgraph context and document context.
+    """
+
+    name: str = "graphrag_reasoning"
+    description: str = """A tool to perform reasoning using a Graph RAG approach
+                        by considering textualized subgraph context and document context."""
+    args_schema: Type[BaseModel] = GraphRAGReasoningInput
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        extraction_name: str,
+    ):
+        """
+        Run the Graph RAG reasoning tool.
+
+        Args:
+            tool_call_id: The tool call ID.
+            state: The injected state.
+            prompt: The prompt to interact with the backend.
+            extraction_name: The name assigned to the subgraph extraction process.
+        """
+        logger.log(
+            logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name
+        )
+
+        # Load Hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/graphrag_reasoning=default"]
+            )
+            cfg = cfg.tools.graphrag_reasoning
+
+        # Prepare documents
+        all_docs = []
+        if len(state["uploaded_files"]) != 0:
+            for uploaded_file in state["uploaded_files"]:
+                if uploaded_file["file_type"] == "drug_data":
+                    # Load documents
+                    raw_documents = PyPDFLoader(
+                        file_path=uploaded_file["file_path"]
+                    ).load()
+
+                    # Split documents
+                    # May need to find an optimal chunk size and overlap configuration
+                    documents = RecursiveCharacterTextSplitter(
+                        chunk_size=cfg.splitter_chunk_size,
+                        chunk_overlap=cfg.splitter_chunk_overlap,
+                    ).split_documents(raw_documents)
+
+                    # Add documents to the list
+                    all_docs.extend(documents)
+
+        # Load the extracted graph
+        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
+        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)
+
+        # Set another prompt template
+        prompt_template = ChatPromptTemplate.from_messages(
+            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
+        )
+
+        # Prepare chain with retrieved documents
+        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
+        rag_chain = create_retrieval_chain(
+            InMemoryVectorStore.from_documents(
+                documents=all_docs, embedding=state["embedding_model"]
+            ).as_retriever(
+                search_type=cfg.retriever_search_type,
+                search_kwargs={
+                    "k": cfg.retriever_k,
+                    "fetch_k": cfg.retriever_fetch_k,
+                    "lambda_mult": cfg.retriever_lambda_mult,
+                },
+            ),
+            qa_chain,
+        )
+
+        # Invoke the chain
+        response = rag_chain.invoke(
+            {
+                "input": prompt,
+                "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
+            }
+        )
+
+        return Command(
+            update={
+                # update the message history
+                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
+            }
+        )