--- /dev/null
+++ b/test/layers/test_spmm.py
@@ -0,0 +1,107 @@
+import unittest
+
+from itertools import product
+
+import torch
+import torch_scatter
+
+from torchdrug import data
+from torchdrug.layers import functional
+
+
+class SPMMTest(unittest.TestCase):
+
+    def setUp(self):
+        self.num_node = 50
+        self.num_relation = 10
+        self.dim = 20
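+        # sparsify a random dense adjacency by zeroing every entry at or below a kth-smallest threshold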
+        adjacency = torch.rand(self.num_node, self.num_node)
+        threshold = adjacency.flatten().kthvalue((self.num_node - 10) * self.num_node)[0]
+        adjacency = adjacency * (adjacency > threshold)
+        self.graph = data.Graph.from_dense(adjacency)
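+        # build a relational (knowledge) graph the same way, with one adjacency slice per relation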
+        rel_adjacency = torch.rand(self.num_node, self.num_node, self.num_relation)
+        threshold = rel_adjacency.flatten().kthvalue((self.num_node - 10) * self.num_node * self.num_relation)[0]
+        rel_adjacency = rel_adjacency * (rel_adjacency > threshold)
+        self.knowledge_graph = data.Graph.from_dense(rel_adjacency)
+        self.relation = torch.rand(self.num_relation, self.dim)
+        self.input = torch.rand(self.num_node, self.dim)
+        self.output_grad = torch.rand(self.num_node, self.dim)
+        self.operators = [("add", "mul"), ("min", "mul"), ("max", "mul"), ("min", "add"), ("max", "add")]
+        self.devices = ["CPU", "CUDA"] if torch.cuda.is_available() else ["CPU"]
+
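+    # check generalized_spmm against a scatter-based reference, in both forward and backward passes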
+    def test_spmm(self):
+        for device, (sum_op, mul_op) in product(self.devices, self.operators):
+            if device == "CUDA":
+                self.graph = self.graph.cuda()
+                self.input = self.input.cuda()
+                self.output_grad = self.output_grad.cuda()
+            self.graph.edge_weight.requires_grad_()
+            self.input.requires_grad_()
+
+            node_in, node_out = self.graph.edge_list.t()
+            result = functional.generalized_spmm(self.graph.adjacency.t(), self.input, sum=sum_op, mul=mul_op)
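+            # reference: materialize per-edge messages mul(edge_weight, input[node_in]) and scatter-reduce them by target node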
+            sum_func = getattr(torch_scatter, "scatter_%s" % sum_op)
+            mul_func = getattr(torch, mul_op)
+            edge_weight = self.graph.edge_weight.unsqueeze(-1)
+            message = mul_func(edge_weight, self.input[node_in])
+            truth = sum_func(message, node_out, dim=0, dim_size=self.num_node)
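+            # scatter_min / scatter_max return (values, argindices); keep only the values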
+            if isinstance(truth, tuple):
+                truth = truth[0]
+            self.assertTrue(torch.allclose(result, truth),
+                            "Incorrect generalized spmm forward (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))
+
+            result_edge, result_input = torch.autograd.grad(
+                result, (self.graph.edge_weight, self.input), self.output_grad)
+            truth_edge, truth_input = torch.autograd.grad(
+                truth, (self.graph.edge_weight, self.input), self.output_grad)
+            self.assertTrue(torch.allclose(result_edge, truth_edge),
+                            "Incorrect generalized spmm backward w.r.t. edge weight (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))
+            self.assertTrue(torch.allclose(result_input, truth_input),
+                            "Incorrect generalized spmm backward w.r.t. input (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))
+
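+    # same check for generalized_rspmm, which additionally combines per-relation embeddings into each message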
+    def test_rspmm(self):
+        for device, (sum_op, mul_op) in product(self.devices, self.operators):
+            if device == "CUDA":
+                self.knowledge_graph = self.knowledge_graph.cuda()
+                self.relation = self.relation.cuda()
+                self.input = self.input.cuda()
+                self.output_grad = self.output_grad.cuda()
+            self.knowledge_graph.edge_weight.requires_grad_()
+            self.relation.requires_grad_()
+            self.input.requires_grad_()
+
+            result = functional.generalized_rspmm(self.knowledge_graph.adjacency.transpose(0, 1),
+                                                  self.relation, self.input, sum=sum_op, mul=mul_op)
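+            # reference: combine relation and input per edge, scale by edge weight, then scatter-reduce by target node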
+            sum_func = getattr(torch_scatter, "scatter_%s" % sum_op)
+            mul_func = getattr(torch, mul_op)
+            node_in, node_out, relation = self.knowledge_graph.edge_list.t()
+            edge_weight = self.knowledge_graph.edge_weight.unsqueeze(-1)
+            message = mul_func(self.relation[relation], self.input[node_in])
+            truth = sum_func(edge_weight * message, node_out, dim=0, dim_size=self.num_node)
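+            # as above, keep only the values from scatter_min / scatter_max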
+            if isinstance(truth, tuple):
+                truth = truth[0]
+            self.assertTrue(torch.allclose(result, truth),
+                            "Incorrect generalized rspmm forward (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))
+
+            result_edge, result_relation, result_input = torch.autograd.grad(
+                result, (self.knowledge_graph.edge_weight, self.relation, self.input), self.output_grad)
+            truth_edge, truth_relation, truth_input = torch.autograd.grad(
+                truth, (self.knowledge_graph.edge_weight, self.relation, self.input), self.output_grad)
+            self.assertTrue(torch.allclose(result_edge, truth_edge),
+                            "Incorrect generalized rspmm backward w.r.t. edge weight (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))
+            self.assertTrue(torch.allclose(result_relation, truth_relation),
+                            "Incorrect generalized rspmm backward w.r.t. relation (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))
+            self.assertTrue(torch.allclose(result_input, truth_input),
+                            "Incorrect generalized rspmm backward w.r.t. input (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))
+
+
+if __name__ == "__main__":
+    unittest.main()