[3af7d7] aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py

"""
Tool for performing Graph RAG reasoning.
"""
import logging
from typing import Type, Annotated

import hydra
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
from langchain_core.tools import BaseTool
from langchain_core.vectorstores import InMemoryVectorStore
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langgraph.types import Command
from langgraph.prebuilt import InjectedState

# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GraphRAGReasoningInput(BaseModel):
    """
    GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.

    Args:
        state: Injected state.
        prompt: Prompt to interact with the backend.
        extraction_name: Name assigned to the subgraph extraction process.
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
        description="Tool call ID."
    )
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
        when the subgraph_extraction tool is invoked."""
    )
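
# Note: tool_call_id and state above are annotated with InjectedToolCallId and
# InjectedState, so LangGraph injects them at runtime; they are not exposed to
# the LLM as tool arguments.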


class GraphRAGReasoningTool(BaseTool):
    """
    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG)
    approach over the user's request by considering the textualized subgraph context
    and the document context.
    """

    name: str = "graphrag_reasoning"
    description: str = """A tool to perform reasoning using a Graph RAG approach
    by considering textualized subgraph context and document context."""
    args_schema: Type[BaseModel] = GraphRAGReasoningInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Run the Graph RAG reasoning tool.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state.
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.
        """
        logger.log(
            logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name
        )

        # Load Hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/graphrag_reasoning=default"]
            )
            cfg = cfg.tools.graphrag_reasoning
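        # The composed config is expected to provide the chunking, retrieval,
        # and prompt settings used below (splitter_chunk_size,
        # splitter_chunk_overlap, retriever_search_type, retriever_k,
        # retriever_fetch_k, retriever_lambda_mult, prompt_graphrag_w_docs).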

        # Prepare documents
        all_docs = []
        if len(state["uploaded_files"]) != 0:
            for uploaded_file in state["uploaded_files"]:
                if uploaded_file["file_type"] == "drug_data":
                    # Load documents
                    raw_documents = PyPDFLoader(
                        file_path=uploaded_file["file_path"]
                    ).load()

                    # Split documents
                    # May need to find an optimal chunk size and overlap configuration
                    documents = RecursiveCharacterTextSplitter(
                        chunk_size=cfg.splitter_chunk_size,
                        chunk_overlap=cfg.splitter_chunk_overlap,
                    ).split_documents(raw_documents)

                    # Add documents to the list
                    all_docs.extend(documents)
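        # Note: only uploaded files tagged as "drug_data" are indexed; if no
        # such files were uploaded, the vector store below is built over an
        # empty document list and retrieval contributes no document context.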

        # Load the extracted graph
        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)

        # Set up the prompt template
        prompt_template = ChatPromptTemplate.from_messages(
            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
        )
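        # The system prompt is expected to contain a {context} placeholder
        # (filled by the stuff-documents chain with the retrieved chunks) and
        # a {subgraph_summary} placeholder (supplied at invocation below).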

        # Prepare chain with retrieved documents
        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
        rag_chain = create_retrieval_chain(
            InMemoryVectorStore.from_documents(
                documents=all_docs, embedding=state["embedding_model"]
            ).as_retriever(
                search_type=cfg.retriever_search_type,
                search_kwargs={
                    "k": cfg.retriever_k,
                    "fetch_k": cfg.retriever_fetch_k,
                    "lambda_mult": cfg.retriever_lambda_mult,
                },
            ),
            qa_chain,
        )
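        # fetch_k and lambda_mult take effect when retriever_search_type is
        # "mmr": fetch_k candidates are fetched first, then k results are
        # selected, with lambda_mult (0..1) trading relevance for diversity.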

        # Invoke the chain
        response = rag_chain.invoke(
            {
                "input": prompt,
                "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
            }
        )

        return Command(
            update={
                # Update the message history with the chain's answer;
                # create_retrieval_chain returns a dict with an "answer" key
                "messages": [
                    ToolMessage(
                        content=response["answer"], tool_call_id=tool_call_id
                    )
                ]
            }
        )
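

# Example usage (a minimal sketch for illustration only; the models, state
# contents, and file path below are assumptions, not values defined by this
# module):
#
#     from langchain_openai import ChatOpenAI, OpenAIEmbeddings
#
#     state = {
#         "llm_model": ChatOpenAI(model="gpt-4o-mini"),
#         "embedding_model": OpenAIEmbeddings(),
#         "uploaded_files": [
#             {"file_type": "drug_data", "file_path": "docs/drug_report.pdf"}
#         ],
#         "dic_extracted_graph": [
#             {"name": "subgraph_1", "graph_summary": "Textualized subgraph ..."}
#         ],
#     }
#     command = GraphRAGReasoningTool()._run(
#         tool_call_id="call_001",
#         state=state,
#         prompt="Which drugs in the subgraph are discussed in the uploaded report?",
#         extraction_name="subgraph_1",
#     )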