Switch to side-by-side view

--- a
+++ b/aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py
@@ -0,0 +1,305 @@
+"""
+Tool for performing subgraph extraction.
+"""
+
+from typing import Type, Annotated
+import logging
+import pickle
+import numpy as np
+import pandas as pd
+import hydra
+import networkx as nx
+from pydantic import BaseModel, Field
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_core.tools import BaseTool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import torch
+from torch_geometric.data import Data
+from ..utils.extractions.pcst import PCSTPruning
+from ..utils.embeddings.ollama import EmbeddingWithOllama
+from .load_arguments import ArgumentData
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
class SubgraphExtractionInput(BaseModel):
    """
    SubgraphExtractionInput is a Pydantic model representing an input for extracting a subgraph.

    Args:
        prompt: Prompt to interact with the backend.
        tool_call_id: Tool call ID.
        state: Injected state.
        arg_data: Argument for analytical process over graph data.
    """

    # Injected by LangGraph at dispatch time; not produced by the LLM.
    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
        description="Tool call ID."
    )
    # Agent state injected at runtime (uploaded files, llm/embedding models,
    # top-k settings, source graphs) — see SubgraphExtractionTool._run.
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    # NOTE(review): default is None but the annotation is non-optional, and
    # _run dereferences arg_data.extraction_name unconditionally — confirm the
    # caller always supplies this field.
    arg_data: ArgumentData = Field(
        description="Experiment over graph data.", default=None
    )
+
+
class SubgraphExtractionTool(BaseTool):
    """
    Tool that extracts a subgraph from a loaded knowledge graph based on the
    user's prompt, constrained by the top-k nodes and edges held in the agent
    state, using PCST (Prize-Collecting Steiner Tree) pruning.
    """

    name: str = "subgraph_extraction"
    description: str = "A tool for subgraph extraction based on user's prompt."
    args_schema: Type[BaseModel] = SubgraphExtractionInput

    def perform_endotype_filtering(
        self,
        prompt: str,
        state: Annotated[dict, InjectedState],
        cfg: hydra.core.config_store.ConfigStore,
    ) -> str:
        """
        Augment the user's prompt with gene lists extracted from uploaded
        endotype PDF documents.

        Each uploaded file of type "endotype" is loaded, split into chunks,
        indexed into an in-memory vector store, and queried through a RAG
        chain; the chain's answer is collected and appended to the prompt.

        Args:
            prompt: The prompt to interact with the backend.
            state: Injected state for the tool.
            cfg: Hydra configuration object (at runtime this is the composed
                OmegaConf node for this tool, despite the annotation).

        Returns:
            The prompt, extended with `cfg.prompt_endotype_addition` and the
            extracted gene lists when any endotype file was processed;
            otherwise the prompt unchanged.
        """
        # The chat prompt template does not depend on the file being
        # processed — build it once instead of once per uploaded file.
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", cfg.prompt_endotype_filtering),
                ("human", "{input}"),
            ]
        )

        all_genes = []
        for uploaded_file in state["uploaded_files"]:
            if uploaded_file["file_type"] != "endotype":
                continue

            # Load the PDF and split its text into overlapping chunks.
            docs = PyPDFLoader(file_path=uploaded_file["file_path"]).load()
            splits = RecursiveCharacterTextSplitter(
                chunk_size=cfg.splitter_chunk_size,
                chunk_overlap=cfg.splitter_chunk_overlap,
            ).split_documents(docs)

            # Build a retrieval-augmented chain over the chunks and query it
            # with the user's prompt.
            qa_chain = create_stuff_documents_chain(
                state["llm_model"], prompt_template
            )
            rag_chain = create_retrieval_chain(
                InMemoryVectorStore.from_documents(
                    documents=splits, embedding=state["embedding_model"]
                ).as_retriever(
                    search_type=cfg.retriever_search_type,
                    search_kwargs={
                        "k": cfg.retriever_k,
                        "fetch_k": cfg.retriever_fetch_k,
                        "lambda_mult": cfg.retriever_lambda_mult,
                    },
                ),
                qa_chain,
            )
            results = rag_chain.invoke({"input": prompt})
            all_genes.append(results["answer"])

        # Append the extracted gene lists to the prompt, if any were found.
        if all_genes:
            prompt = " ".join(
                [prompt, cfg.prompt_endotype_addition, ", ".join(all_genes)]
            )

        return prompt

    def prepare_final_subgraph(self,
                               subgraph: dict,
                               pyg_graph: Data,
                               textualized_graph: pd.DataFrame) -> dict:
        """
        Assemble the extracted subgraph in three representations.

        Args:
            subgraph: The extracted subgraph — a mapping with "nodes" and
                "edges" index arrays selecting rows of `pyg_graph`.
            pyg_graph: The full PyTorch Geometric graph.
            textualized_graph: The textualized graph — a mapping with "nodes"
                and "edges" tables aligned with `pyg_graph`.

        Returns:
            A dictionary with keys "graph_pyg" (PyG Data), "graph_nx"
            (networkx.DiGraph for frontend visualization), and "graph_text"
            (CSV text of the node table followed by the edge table).
        """
        node_idx = subgraph["nodes"]
        edge_idx = subgraph["edges"]

        # Map original node indices to contiguous indices in the subgraph.
        mapping = {n: i for i, n in enumerate(node_idx.tolist())}

        # Edge endpoints of the retained edges, in original node numbering.
        kept_edge_index = pyg_graph.edge_index[:, edge_idx]

        sub_pyg = Data(
            # Node features
            x=pyg_graph.x[node_idx],
            node_id=np.array(pyg_graph.node_id)[node_idx].tolist(),
            # NOTE(review): node_name is populated from node_id — presumably
            # the source graph carries identifiers only; confirm there is no
            # separate name attribute that should be used here.
            node_name=np.array(pyg_graph.node_id)[node_idx].tolist(),
            enriched_node=np.array(pyg_graph.enriched_node)[node_idx].tolist(),
            num_nodes=len(node_idx),
            # Edge features, re-indexed into the subgraph's node numbering.
            edge_index=torch.LongTensor(
                [
                    [mapping[src] for src in kept_edge_index[0].tolist()],
                    [mapping[dst] for dst in kept_edge_index[1].tolist()],
                ]
            ),
            edge_attr=pyg_graph.edge_attr[edge_idx],
            edge_type=np.array(pyg_graph.edge_type)[edge_idx].tolist(),
            relation=np.array(pyg_graph.edge_type)[edge_idx].tolist(),
            label=np.array(pyg_graph.edge_type)[edge_idx].tolist(),
            enriched_edge=np.array(pyg_graph.enriched_edge)[edge_idx].tolist(),
        )

        # NetworkX DiGraph construction to be visualized in the frontend.
        # (Renamed loop variables: the original shadowed `i` between the
        # enumerate index and the comprehension index.)
        nx_graph = nx.DiGraph()
        for node in sub_pyg.node_name:
            nx_graph.add_node(node)
        for k, (src, dst) in enumerate(sub_pyg.edge_index.transpose(1, 0)):
            nx_graph.add_edge(
                sub_pyg.node_name[src],
                sub_pyg.node_name[dst],
                relation=sub_pyg.edge_type[k],
                label=sub_pyg.edge_type[k],
            )

        # Textualized subgraph: node table then edge table, both as CSV text.
        graph_text = (
            textualized_graph["nodes"].iloc[node_idx].to_csv(index=False)
            + "\n"
            + textualized_graph["edges"].iloc[edge_idx].to_csv(index=False)
        )

        return {
            "graph_pyg": sub_pyg,
            "graph_nx": nx_graph,
            "graph_text": graph_text,
        }

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        arg_data: ArgumentData = None,
    ) -> Command:
        """
        Run the subgraph extraction tool.

        Args:
            tool_call_id: The tool call ID for the tool.
            state: Injected state for the tool.
            prompt: The prompt to interact with the backend.
            arg_data (ArgumentData): The argument data. NOTE(review): defaults
                to None, yet `arg_data.extraction_name` is dereferenced below —
                the model is expected to always supply it; confirm upstream.

        Returns:
            Command: The command updating the agent state and message history.
        """
        logger.info("Invoking subgraph_extraction tool")

        # Load hydra configuration for this tool.
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/subgraph_extraction=default"]
            )
            cfg = cfg.tools.subgraph_extraction

        # Retrieve the source graph from the state (the last one as of now).
        initial_graph = {"source": state["dic_source_graph"][-1]}

        # Load the knowledge graph.
        # NOTE(review): pickle.load executes arbitrary code if the files are
        # untrusted — ensure these paths are application-controlled.
        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
            initial_graph["pyg"] = pickle.load(f)
        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
            initial_graph["text"] = pickle.load(f)

        # Augment the prompt with endotype genes when endotype PDFs were
        # uploaded (any() is False on an empty upload list, matching the
        # original explicit length check).
        if any(f["file_type"] == "endotype" for f in state["uploaded_files"]):
            prompt = self.perform_endotype_filtering(prompt, state, cfg)

        # Embed the (possibly augmented) user prompt as the query vector.
        query_emb = torch.tensor(
            EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)
        ).float()

        # Extract the subgraph via PCST pruning; parameters come from the
        # Hydra configuration.
        subgraph = PCSTPruning(
            state["topk_nodes"],
            state["topk_edges"],
            cfg.cost_e,
            cfg.c_const,
            cfg.root,
            cfg.num_clusters,
            cfg.pruning,
            cfg.verbosity_level,
        ).extract_subgraph(initial_graph["pyg"], query_emb)

        # Build the PyG / NetworkX / text representations of the subgraph.
        final_subgraph = self.prepare_final_subgraph(
            subgraph, initial_graph["pyg"], initial_graph["text"]
        )

        # Record the extraction result destined for the agent state.
        dic_extracted_graph = {
            "name": arg_data.extraction_name,
            "tool_call_id": tool_call_id,
            "graph_source": initial_graph["source"]["name"],
            "topk_nodes": state["topk_nodes"],
            "topk_edges": state["topk_edges"],
            "graph_dict": {
                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
            },
            "graph_text": final_subgraph["graph_text"],
            "graph_summary": None,
        }

        # Return the updated state of the tool. (The original filtered a
        # single hard-coded, always-truthy entry through a loop; the direct
        # dict is behaviorally identical.)
        return Command(
            update={
                "dic_extracted_graph": [dic_extracted_graph],
                # update the message history
                "messages": [
                    ToolMessage(
                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
                        tool_call_id=tool_call_id,
                    )
                ],
            }
        )