--- a
+++ b/test/layers/test_sampler.py
@@ -0,0 +1,68 @@
+import unittest
+
+import torch
+
+from torchdrug import data, layers
+
+
+class GraphSamplerTest(unittest.TestCase):
+    """Check that the graph samplers give unbiased estimates of a readout.
+
+    Averaging ``readout(conv(g))`` over many subgraphs ``g`` drawn by a sampler
+    should match the same readout computed on the full graph, up to Monte-Carlo
+    noise.
+    """
+
+    # Number of sampled graphs averaged per sampler.
+    num_trial = 2000
+
+    def setUp(self):
+        # Seed the RNG so the stochastic comparison is reproducible instead of
+        # flaking whenever an unlucky draw lands just outside the tolerance.
+        torch.manual_seed(0)
+        # Use the GPU when available; otherwise fall back to CPU rather than
+        # failing on machines without CUDA.
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.num_node = 10
+        self.input_dim = 5
+        self.output_dim = 7
+        adjacency = torch.rand(self.num_node, self.num_node)
+        # Keep only the 3 * num_node largest weights, i.e. ~3 edges per node.
+        threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0]
+        adjacency = adjacency * (adjacency > threshold)
+        self.graph = data.Graph.from_dense(adjacency)
+        self.input = torch.rand(self.num_node, self.input_dim)
+        if self.device.type == "cuda":
+            self.graph = self.graph.cuda()
+            self.input = self.input.cuda()
+
+    def _mean_sampled_readout(self, sampler, conv, readout):
+        """Return ``readout(conv(g))`` averaged over ``num_trial`` graphs ``g`` drawn by ``sampler``."""
+        results = []
+        for _ in range(self.num_trial):
+            graph = sampler(self.graph)
+            node_feature = conv(graph, self.input)
+            results.append(readout(graph, node_feature))
+        return torch.stack(results).mean(dim=0)
+
+    def test_sampler(self):
+        conv = layers.GraphConv(self.input_dim, self.output_dim, activation=None).to(self.device)
+        readout = layers.SumReadout().to(self.device)
+        samplers = {
+            "node": layers.NodeSampler(ratio=0.8).to(self.device),
+            "edge": layers.EdgeSampler(ratio=0.8).to(self.device),
+        }
+
+        # Gradients are not needed; without no_grad each of the num_trial
+        # results would keep its autograd graph alive until it is stacked.
+        with torch.no_grad():
+            truth = readout(self.graph, conv(self.graph, self.input))
+            for name, sampler in samplers.items():
+                with self.subTest(sampler=name):
+                    result = self._mean_sampled_readout(sampler, conv, readout)
+                    self.assertTrue(torch.allclose(result, truth, rtol=5e-2, atol=5e-2),
+                                    f"Found bias in {name} sampler")
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file