--- a
+++ b/test/layers/test_pool.py
@@ -0,0 +1,127 @@
+import unittest
+
+import torch
+from torch.nn import functional as F
+
+from torchdrug import data, layers
+
+
+@unittest.skipUnless(torch.cuda.is_available(), "GraphPoolTest requires a CUDA device")
+class GraphPoolTest(unittest.TestCase):
+    """Check DiffPool and MinCutPool against dense einsum reference computations."""
+
+    def setUp(self):
+        """Create a packed batch of random graphs and two GraphConv layers on the GPU."""
+        self.num_node = 10
+        self.num_graph = 5
+        self.input_dim = 5
+        self.output_dim = 8
+        self.output_node = 6
+        self.graphs = []
+        for _ in range(self.num_graph):
+            adjacency = torch.rand(self.num_node, self.num_node)
+            # kthvalue returns the k-th smallest entry, so masking with "> threshold"
+            # keeps only the 3 * num_node largest weights and zeroes the rest.
+            threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0]
+            adjacency = adjacency * (adjacency > threshold)
+            node_feature = torch.rand(self.num_node, self.input_dim)
+            graph = data.Graph.from_dense(adjacency, node_feature).cuda()
+            self.graphs.append(graph)
+        self.graph = data.Graph.pack(self.graphs)
+        self.feature_layer = layers.GraphConv(self.input_dim, self.output_dim).cuda()
+        self.pool_layer = layers.GraphConv(self.input_dim, self.output_dim).cuda()
+
+    def test_pool(self):
+        """Run the DiffPool checks (dense, then sparse) followed by the MinCutPool check."""
+        for sparse in [False, True]:
+            # The assertion messages do not say which mode failed; subTest attaches
+            # the `sparse` value to the failure report.
+            with self.subTest(sparse=sparse):
+                self._check_diff_pool(sparse)
+        self._check_min_cut_pool()
+
+    def _check_diff_pool(self, sparse):
+        """Compare DiffPool on the packed batch and on one graph with an einsum reference."""
+        pool = layers.DiffPool(self.input_dim, self.output_node, self.feature_layer, self.pool_layer,
+                               zero_diagonal=True, sparse=sparse).cuda()
+
+        # Packed batch: tensors are viewed per graph so one batched einsum covers all graphs.
+        pooled, result, assignment = pool(self.graph, self.graph.node_feature)
+        result = result.view(self.num_graph, self.output_node, -1)
+        result_adj = torch.stack([g.adjacency.to_dense() for g in pooled])
+        feature = pool.feature_layer(self.graph, self.graph.node_feature).view(self.num_graph, self.num_node, -1)
+        x = pool.linear(pool.pool_layer(self.graph, self.graph.node_feature))
+        if not sparse:
+            # Dense mode recomputes the soft assignment independently of the layer output;
+            # sparse mode reuses the assignment returned by the layer as-is.
+            assignment = F.softmax(x, dim=-1)
+        assignment = assignment.view(self.num_graph, self.num_node, -1)
+        adjacency = torch.stack([g.adjacency.to_dense() for g in self.graph])
+        truth = torch.einsum("bna, bnd -> bad", assignment, feature)
+        truth_adj = torch.einsum("bna, bnm, bmc -> bac", assignment, adjacency, assignment)
+        index = torch.arange(self.output_node, device=truth.device)
+        truth_adj[:, index, index] = 0
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect diffpool node feature")
+        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect diffpool adjacency")
+
+        # Single graph: the RNG state is saved before the layer call and restored before
+        # the reference pass (presumably to replay any random draws made by the layers).
+        graph = self.graph[0]
+        rng_state = torch.get_rng_state()
+        pooled, result, assignment = pool(graph, graph.node_feature)
+        result_adj = pooled.adjacency.to_dense()
+        torch.set_rng_state(rng_state)
+        feature = pool.feature_layer(graph, graph.node_feature)
+        x = pool.linear(pool.pool_layer(graph, graph.node_feature))
+        if not sparse:
+            assignment = F.softmax(x, dim=-1)
+        adjacency = graph.adjacency.to_dense()
+        truth = torch.einsum("na, nd -> ad", assignment, feature)
+        truth_adj = torch.einsum("na, nm, mc -> ac", assignment, adjacency, assignment)
+        index = torch.arange(self.output_node, device=truth.device)
+        truth_adj[index, index] = 0
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect diffpool node feature")
+        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect diffpool adjacency")
+
+    def _check_min_cut_pool(self):
+        """Compare MinCutPool features, adjacency and recorded metrics with a reference."""
+        pool = layers.MinCutPool(self.input_dim, self.output_node, self.feature_layer, self.pool_layer).cuda()
+        all_loss = torch.tensor(0, dtype=torch.float32, device="cuda")
+        result_metric = {}
+        pooled, result, assignment = pool(self.graph, self.graph.node_feature, all_loss, result_metric)
+        result = result.view(self.num_graph, self.output_node, -1)
+        result_adj = torch.stack([g.adjacency.to_dense() for g in pooled])
+        feature = pool.feature_layer(self.graph, self.graph.node_feature).view(self.num_graph, self.num_node, -1)
+        x = pool.linear(pool.pool_layer(self.graph, self.graph.node_feature))
+        assignment = F.softmax(x, dim=-1)
+        assignment = assignment.view(self.num_graph, self.num_node, -1)
+        adjacency = torch.stack([g.adjacency.to_dense() for g in self.graph])
+        truth = torch.einsum("bna, bnd -> bad", assignment, feature)
+        adjacency = torch.einsum("bna, bnm, bmc -> bac", assignment, adjacency, assignment)
+        truth_adj = adjacency.clone()
+        index = torch.arange(self.output_node, device=truth.device)
+        truth_adj[:, index, index] = 0
+        # Reference cut loss: 1 - trace(S^T A S) / trace(S^T D S), averaged over graphs.
+        num_intra = torch.einsum("baa -> b", adjacency)
+        degree = self.graph.degree_in.view(self.num_graph, self.num_node)
+        degree = torch.einsum("bna, bn, bnc -> bac", assignment, degree, assignment)
+        num_all = torch.einsum("baa -> b", degree)
+        cut_loss = (1 - num_intra / num_all).mean()
+        # Reference regularization: norm of (S^T S / ||S^T S|| - I / sqrt(output_node)).
+        x = torch.einsum("bna, bnc -> bac", assignment, assignment)
+        x = x / x.flatten(-2).norm(dim=-1, keepdim=True).unsqueeze(-1)
+        x = x - torch.eye(self.output_node, device=x.device) / (self.output_node ** 0.5)
+        regularization = x.flatten(-2).norm(dim=-1).mean()
+        truth_metric = {"normalized cut loss": cut_loss, "orthogonal regularization": regularization}
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect min cut pool feature")
+        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect min cut pool adjacency")
+        # Iterate over the expected metrics rather than the recorded ones, so the check
+        # cannot pass vacuously when the layer records nothing.
+        for key, expected in truth_metric.items():
+            self.assertIn(key, result_metric, "Missing min cut pool metric")
+            self.assertTrue(torch.allclose(result_metric[key], expected, rtol=1e-3, atol=1e-4),
+                            "Incorrect min cut pool metric")
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file