Switch to side-by-side view

--- a
+++ b/aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py
@@ -0,0 +1,225 @@
+"""
+Extraction of subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
+"""
+
+from typing import Tuple, NamedTuple
+import numpy as np
+import torch
+import pcst_fast
+from torch_geometric.data.data import Data
+
class PCSTPruning(NamedTuple):
    """
    Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
    (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and
    Question Answering', NeurIPS 2024) paper.
    https://arxiv.org/abs/2402.07630
    https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py

    Args:
        topk: The number of top nodes to consider.
        topk_e: The number of top edges to consider.
        cost_e: The cost of the edges.
        c_const: The constant value for the cost of the edges computation.
        root: The root node of the subgraph, -1 for unrooted.
        num_clusters: The number of clusters.
        pruning: The pruning strategy to use.
        verbosity_level: The verbosity level.
    """
    topk: int = 3
    topk_e: int = 3
    cost_e: float = 0.5
    c_const: float = 0.01
    root: int = -1
    num_clusters: int = 1
    pruning: str = "gw"
    verbosity_level: int = 0

    def compute_prizes(self, graph: "Data", query_emb: torch.Tensor) -> dict:
        """
        Compute the node prizes based on the cosine similarity between the query and nodes,
        as well as the edge prizes based on the cosine similarity between the query and edges.
        Note that the node and edge embeddings shall use the same embedding model and dimensions
        with the query.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            query_emb: The query embedding in PyTorch Tensor format.

        Returns:
            A dict with keys "nodes" and "edges" holding the node and edge prize tensors.
        """
        # Node prizes: the top-k most query-similar nodes receive descending
        # integer prizes (topk, topk - 1, ..., 1); all other nodes get 0.
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(query_emb, graph.x)
        topk = min(self.topk, graph.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()

        # Edge prizes: vectorized form of the G-Retriever reference loop.
        # Edges sharing the k-th highest unique similarity split the prize
        # (topk_e - k) among themselves; the cap via last_topk_e_value and the
        # (1 - c_const) factor keep prizes strictly decreasing across levels.
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(query_emb, graph.edge_attr)
        unique_prizes, inverse_indices = e_prizes.unique(return_inverse=True)
        topk_e = min(self.topk_e, unique_prizes.size(0))
        topk_e_values, _ = torch.topk(unique_prizes, topk_e, largest=True)
        # Zero out every edge below the top-k-th similarity level.
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            # Boolean mask of the edges whose similarity equals the k-th level.
            indices = inverse_indices == (
                unique_prizes == topk_e_values[k]
            ).nonzero(as_tuple=True)[0]
            value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - self.c_const)

        return {"nodes": n_prizes, "edges": e_prizes}

    def compute_subgraph_costs(
        self, graph: "Data", prizes: dict
    ) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
        """
        Compute the costs in constructing the subgraph proposed by G-Retriever paper.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            prizes: The prizes of the nodes and the edges.

        Returns:
            edges_dict: The edges of the subgraph, consisting of edges ("edges") and the
                number of edges without virtual edges ("num_prior_edges").
            prizes: The prizes of the subgraph (node prizes followed by virtual-node prizes).
            costs: The costs of the subgraph edges (real edges followed by virtual edges).
            mapping: The mapping of virtual node ids and reduced edge ids back to the
                original edge indices.
        """
        # Reduce the edge cost if needed so that at least one edge is selectable.
        updated_cost_e = min(
            self.cost_e,
            prizes["edges"].max().item() * (1 - self.c_const / 2),
        )

        # Initialize variables
        edges = []
        costs = []
        virtual = {
            "n_prizes": [],
            "edges": [],
            "costs": [],
        }
        mapping = {"nodes": {}, "edges": {}}

        # Cheap edges become real PCST edges with cost (updated_cost_e - prize);
        # expensive edges are modeled G-Retriever-style as a zero-cost virtual
        # node carrying the surplus prize, connected to both endpoints.
        for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
            prize_e = prizes["edges"][i]
            if prize_e <= updated_cost_e:
                mapping["edges"][len(edges)] = i
                edges.append((src, dst))
                costs.append(updated_cost_e - prize_e)
            else:
                virtual_node_id = graph.num_nodes + len(virtual["n_prizes"])
                mapping["nodes"][virtual_node_id] = i
                virtual["edges"].append((src, virtual_node_id))
                virtual["edges"].append((virtual_node_id, dst))
                virtual["costs"].append(0)
                virtual["costs"].append(0)
                virtual["n_prizes"].append(prize_e - updated_cost_e)
        prizes = np.concatenate([prizes["nodes"], np.array(virtual["n_prizes"])])
        edges_dict = {}
        edges_dict["edges"] = edges
        edges_dict["num_prior_edges"] = len(edges)
        # Append the virtual edges/costs after the real ones so that indices
        # below num_prior_edges always refer to real edges.
        if len(virtual["costs"]) > 0:
            costs = np.array(costs + virtual["costs"])
            edges = np.array(edges + virtual["edges"])
            edges_dict["edges"] = edges

        return edges_dict, prizes, costs, mapping

    def get_subgraph_nodes_edges(
        self, graph: "Data", vertices: np.ndarray, edges_dict: dict, mapping: dict,
    ) -> dict:
        """
        Get the selected nodes and edges of the subgraph based on the vertices and edges computed
        by the PCST algorithm.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            vertices: The vertices of the subgraph computed by the PCST algorithm.
            edges_dict: The dictionary of edges of the subgraph computed by the PCST algorithm,
                and the number of prior edges (without virtual edges).
            mapping: The mapping dictionary of the nodes and edges.

        Returns:
            The selected nodes and edges of the extracted subgraph.
        """
        # Get edges information
        edges = edges_dict["edges"]
        num_prior_edges = edges_dict["num_prior_edges"]
        # Real nodes/edges selected by PCST, mapped back to original edge ids.
        subgraph_nodes = vertices[vertices < graph.num_nodes]
        subgraph_edges = [mapping["edges"][e] for e in edges if e < num_prior_edges]
        # A selected virtual node stands for its original (expensive) edge.
        virtual_vertices = vertices[vertices >= graph.num_nodes]
        if len(virtual_vertices) > 0:
            virtual_edges = [mapping["nodes"][i] for i in virtual_vertices]
            subgraph_edges = np.array(subgraph_edges + virtual_edges)
        edge_index = graph.edge_index[:, subgraph_edges]
        # Ensure both endpoints of every selected edge are in the node set.
        subgraph_nodes = np.unique(
            np.concatenate(
                [subgraph_nodes, edge_index[0].numpy(), edge_index[1].numpy()]
            )
        )

        return {"nodes": subgraph_nodes, "edges": subgraph_edges}

    def extract_subgraph(self, graph: "Data", query_emb: torch.Tensor) -> dict:
        """
        Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            query_emb: The query embedding.

        Returns:
            The selected nodes and edges of the subgraph.
        """
        # Assert the topk and topk_e values for subgraph retrieval
        assert self.topk > 0, "topk must be greater than 0"
        assert self.topk_e > 0, "topk_e must be greater than 0"

        # Retrieve the top-k nodes and edges based on the query embedding
        prizes = self.compute_prizes(graph, query_emb)

        # Compute costs in constructing the subgraph
        edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
            graph, prizes
        )

        # Retrieve the subgraph using the PCST algorithm
        result_vertices, result_edges = pcst_fast.pcst_fast(
            edges_dict["edges"],
            prizes,
            costs,
            self.root,
            self.num_clusters,
            self.pruning,
            self.verbosity_level,
        )

        subgraph = self.get_subgraph_nodes_edges(
            graph,
            result_vertices,
            {"edges": result_edges, "num_prior_edges": edges_dict["num_prior_edges"]},
            mapping)

        return subgraph