Diff of /torchdrug/layers/pool.py [000000] .. [36b44b]

--- a
+++ b/torchdrug/layers/pool.py
@@ -0,0 +1,235 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch_scatter import scatter_add, scatter_mean
+
+from torchdrug import data
+
+
+class DiffPool(nn.Module):
+    """
+    Differentiable pooling operator from `Hierarchical Graph Representation Learning with Differentiable Pooling`_
+
+    .. _Hierarchical Graph Representation Learning with Differentiable Pooling:
+        https://papers.nips.cc/paper/7729-hierarchical-graph-representation-learning-with-differentiable-pooling.pdf
+
+    Parameters:
+        input_dim (int): input dimension
+        output_node (int): number of nodes after pooling
+        feature_layer (Module, optional): graph convolution layer for embedding
+        pool_layer (Module, optional): graph convolution layer for pooling assignment
+        loss_weight (float, optional): weight of entropy regularization
+        zero_diagonal (bool, optional): whether to remove self loops in the pooled graph
+        sparse (bool, optional): whether to use sparse assignment
+    """
+
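+    # A minimal usage sketch (illustrative only; hidden_dim, graph and
+    # node_feature are placeholders, and any torchdrug conv layer exposing
+    # output_dim, e.g. layers.GraphConv, can be plugged in):
+    #   feature_layer = layers.GraphConv(input_dim, hidden_dim)
+    #   pool_layer = layers.GraphConv(input_dim, hidden_dim)
+    #   pool = DiffPool(input_dim, output_node=10, feature_layer=feature_layer, pool_layer=pool_layer)
+    #   pooled_graph, node_repr, assignment = pool(graph, node_feature)
+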
+    tau = 1
+    eps = 1e-10
+
+    def __init__(self, input_dim, output_node, feature_layer=None, pool_layer=None, loss_weight=1, zero_diagonal=False,
+                 sparse=False):
+        super(DiffPool, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = feature_layer.output_dim if feature_layer else input_dim
+        self.output_node = output_node
+        self.feature_layer = feature_layer
+        self.pool_layer = pool_layer
+        self.loss_weight = loss_weight
+        self.zero_diagonal = zero_diagonal
+        self.sparse = sparse
+
+        if pool_layer is not None:
+            self.linear = nn.Linear(pool_layer.output_dim, output_node)
+        else:
+            self.linear = nn.Linear(input_dim, output_node)
+
+    def forward(self, graph, input, all_loss=None, metric=None):
+        """
+        Compute the node cluster assignment and pool the nodes.
+
+        Parameters:
+            graph (Graph): graph(s)
+            input (Tensor): input node representations
+            all_loss (Tensor, optional): if specified, add loss to this tensor
+            metric (dict, optional): if specified, output metrics to this dict
+
+        Returns:
+            (PackedGraph, Tensor, Tensor):
+                pooled graph, output node representations, node-to-cluster assignment
+        """
+        feature = input
+        if self.feature_layer:
+            feature = self.feature_layer(graph, feature)
+
+        x = input
+        if self.pool_layer:
+            x = self.pool_layer(graph, x)
+        x = self.linear(x)
+        if self.sparse:
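+            # straight-through Gumbel-softmax: one-hot assignments in the forward
+            # pass, differentiable soft relaxation in the backward pass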
+            assignment = F.gumbel_softmax(x, hard=True, tau=self.tau, dim=-1)
+            new_graph, output = self.sparse_pool(graph, feature, assignment)
+        else:
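+            # soft assignment matrix S of shape (|V|, output_node)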
+            assignment = F.softmax(x, dim=-1)
+            new_graph, output = self.dense_pool(graph, feature, assignment)
+
+        if all_loss is not None:
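+            # regularization: maximize the entropy of each graph's mean cluster
+            # distribution so that all output clusters are used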
+            prob = scatter_mean(assignment, graph.node2graph, dim=0, dim_size=graph.batch_size)
+            entropy = -(prob * (prob + self.eps).log()).sum(dim=-1)
+            entropy = entropy.mean()
+            if metric is not None:
+                metric["assignment entropy"] = entropy
+            if self.loss_weight > 0:
+                all_loss -= entropy * self.loss_weight
+
+        if self.zero_diagonal:
+            edge_list = new_graph.edge_list[:, :2]
+            is_diagonal = edge_list[:, 0] == edge_list[:, 1]
+            new_graph = new_graph.edge_mask(~is_diagonal)
+
+        return new_graph, output, assignment
+
+    def dense_pool(self, graph, input, assignment):
+        node_in, node_out = graph.edge_list.t()[:2]
+        # S^T A S, O(|V|k^2 + |E|k)
+        x = graph.edge_weight.unsqueeze(-1) * assignment[node_out]
+        x = scatter_add(x, node_in, dim=0, dim_size=graph.num_node)
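+        # x now holds (A S); outer product with S summed per graph yields S^T A S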
+        x = torch.einsum("np, nq -> npq", assignment, x)
+        adjacency = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size)
+        # S^T X
+        x = torch.einsum("na, nd -> nad", assignment, input)
+        output = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size).flatten(0, 1)
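+        # output is flattened to shape (batch_size * output_node, feature_dim)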
+
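+        # the pooled graph is fully connected: every cluster pair (p, q) becomes
+        # an edge weighted by the corresponding entry of S^T A S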
+        index = torch.arange(self.output_node, device=graph.device).expand(len(graph), self.output_node, -1)
+        edge_list = torch.stack([index.transpose(-1, -2), index], dim=-1).flatten(0, -2)
+        edge_weight = adjacency.flatten()
+        if isinstance(graph, data.PackedGraph):
+            num_nodes = torch.ones(len(graph), dtype=torch.long, device=input.device) * self.output_node
+            num_edges = torch.ones(len(graph), dtype=torch.long, device=input.device) * self.output_node ** 2
+            graph = data.PackedGraph(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges)
+        else:
+            graph = data.Graph(edge_list, edge_weight=edge_weight, num_node=self.output_node)
+        return graph, output
+
+    def sparse_pool(self, graph, input, assignment):
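+        # hard assignment: each node is mapped to its most likely cluster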
+        assignment = assignment.argmax(dim=-1)
+        edge_list = graph.edge_list[:, :2]
+        edge_list = assignment[edge_list]
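+        # relabel nodes by cluster with a per-graph offset, then sum the features
+        # of all nodes pooled into the same cluster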
+        pooled_node = graph.node2graph * self.output_node + assignment
+        output = scatter_add(input, pooled_node, dim=0, dim_size=graph.batch_size * self.output_node)
+
+        edge_weight = graph.edge_weight
+        if isinstance(graph, data.PackedGraph):
+            num_nodes = torch.ones(len(graph), dtype=torch.long, device=input.device) * self.output_node
+            num_edges = graph.num_edges
+            graph = data.PackedGraph(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges)
+        else:
+            graph = data.Graph(edge_list, edge_weight=edge_weight, num_node=self.output_node)
+        return graph, output
+
+
+class MinCutPool(DiffPool):
+    """
+    Min cut pooling operator from `Spectral Clustering with Graph Neural Networks for Graph Pooling`_
+
+    .. _Spectral Clustering with Graph Neural Networks for Graph Pooling:
+        http://proceedings.mlr.press/v119/bianchi20a/bianchi20a.pdf
+
+    Parameters:
+        input_dim (int): input dimension
+        output_node (int): number of nodes after pooling
+        feature_layer (Module, optional): graph convolution layer for embedding
+        pool_layer (Module, optional): graph convolution layer for pooling assignment
+        loss_weight (float, optional): weight of the cut and orthogonal regularization
+        zero_diagonal (bool, optional): whether to remove self loops in the pooled graph
+        sparse (bool, optional): whether to use sparse assignment
+    """
+
+    eps = 1e-10
+
+    def __init__(self, input_dim, output_node, feature_layer=None, pool_layer=None, loss_weight=1, zero_diagonal=True,
+                 sparse=False):
+        super(MinCutPool, self).__init__(input_dim, output_node, feature_layer, pool_layer, loss_weight, zero_diagonal,
+                                         sparse)
+
+    def forward(self, graph, input, all_loss=None, metric=None):
+        """
+        Compute the node cluster assignment and pool the nodes.
+
+        Parameters:
+            graph (Graph): graph(s)
+            input (Tensor): input node representations
+            all_loss (Tensor, optional): if specified, add loss to this tensor
+            metric (dict, optional): if specified, output metrics to this dict
+
+        Returns:
+            (PackedGraph, Tensor, Tensor):
+                pooled graph, output node representations, node-to-cluster assignment
+        """
+        feature = input
+        if self.feature_layer:
+            feature = self.feature_layer(graph, feature)
+
+        x = input
+        if self.pool_layer:
+            x = self.pool_layer(graph, x)
+        x = self.linear(x)
+        if self.sparse:
+            assignment = F.gumbel_softmax(x, hard=True, tau=self.tau, dim=-1)
+            new_graph, output = self.sparse_pool(graph, feature, assignment)
+        else:
+            assignment = F.softmax(x, dim=-1)
+            new_graph, output = self.dense_pool(graph, feature, assignment)
+
+        if all_loss is not None:
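+            # Tr(S^T A S): total intra-cluster edge weight, read off the diagonal
+            # of the pooled adjacency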
+            edge_list = new_graph.edge_list
+            is_diagonal = edge_list[:, 0] == edge_list[:, 1]
+            num_intra = scatter_add(new_graph.edge_weight[is_diagonal], new_graph.edge2graph[is_diagonal],
+                                    dim=0, dim_size=new_graph.batch_size)
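+            # Tr(S^T D S): degree-weighted assignment mass per graph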
+            x = torch.einsum("na, n, nc -> nac", assignment, graph.degree_in, assignment)
+            x = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size)
+            num_all = torch.einsum("baa -> b", x)
+            cut_loss = (1 - num_intra / (num_all + self.eps)).mean()
+            if metric is not None:
+                metric["normalized cut loss"] = cut_loss
+
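+            # orthogonality regularization: || S^T S / ||S^T S||_F - I_k / sqrt(k) ||_F,
+            # pushing clusters toward orthogonal, equally sized assignments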
+            x = torch.einsum("na, nc -> nac", assignment, assignment)
+            x = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size)
+            x = x / x.flatten(-2).norm(dim=-1, keepdim=True).unsqueeze(-1)
+            x = x - torch.eye(self.output_node, device=x.device) / (self.output_node ** 0.5)
+            regularization = x.flatten(-2).norm(dim=-1).mean()
+            if metric is not None:
+                metric["orthogonal regularization"] = regularization
+            if self.loss_weight > 0:
+                all_loss += (cut_loss + regularization) * self.loss_weight
+
+        if self.zero_diagonal:
+            edge_list = new_graph.edge_list[:, :2]
+            is_diagonal = edge_list[:, 0] == edge_list[:, 1]
+            new_graph = new_graph.edge_mask(~is_diagonal)
+
+        return new_graph, output, assignment
\ No newline at end of file