--- a
+++ b/test/data/test_graph.py
@@ -0,0 +1,389 @@
+import unittest
+
+import torch
+
+from torchdrug import data
+
+
+class GraphTest(unittest.TestCase):
+
+    def setUp(self):
+        self.num_node = 10
+        self.num_feature = 3
+        adjacency = torch.rand(self.num_node, self.num_node)
+        threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0]
+        self.adjacency = adjacency * (adjacency > threshold)
+        self.edge_list = self.adjacency.nonzero()
+        self.edge_weight = self.adjacency[self.adjacency > 0]
+        self.node_feature = torch.rand(self.num_node, self.num_feature)
+        self.edge_feature = torch.rand(len(self.edge_list), self.num_feature)
+        self.graph_feature = torch.rand(self.num_feature)
+
+    def block_diag(self, tensors):
+        total_row = 0
+        total_col = 0
+        for tensor in tensors:
+            num_row, num_col = tensor.shape
+            total_row += num_row
+            total_col += num_col
+        result = torch.zeros(total_row, total_col)
+        x = 0
+        y = 0
+        for tensor in tensors:
+            num_row, num_col = tensor.shape
+            result[x: x + num_row, y: y + num_col] = tensor
+            x += num_row
+            y += num_col
+        return result
+
+    def assert_equal(self, graph1, graph2, prompt):
+        self.assertTrue(torch.equal(graph1.adjacency.to_dense(), graph2.adjacency.to_dense()),
+                        "Incorrect edge list in %s" % prompt)
+        if hasattr(graph1, "node_feature") and hasattr(graph2, "node_feature"):
+            self.assertTrue(torch.equal(graph1.node_feature, graph2.node_feature), "Incorrect feature in %s" % prompt)
+        if hasattr(graph1, "edge_feature") and hasattr(graph2, "edge_feature"):
+            self.assertTrue(torch.equal(graph1.edge_feature, graph2.edge_feature), "Incorrect feature in %s" % prompt)
+        if hasattr(graph1, "graph_feature") and hasattr(graph2, "graph_feature"):
+            self.assertTrue(torch.equal(graph1.graph_feature, graph2.graph_feature), "Incorrect feature in %s" % prompt)
+
+    def test_type_cast(self):
+        dense_edge_feature = torch.zeros(self.num_node, self.num_node, self.num_feature)
+        dense_edge_feature[tuple(self.edge_list.t())] = self.edge_feature
+        graph = data.Graph.from_dense(self.adjacency, self.node_feature, dense_edge_feature)
+        graph1 = data.Graph(self.edge_list.tolist(), self.edge_weight.tolist(), self.num_node,
+                            node_feature=self.node_feature.tolist(), edge_feature=self.edge_feature.tolist())
+        graph2 = data.Graph(self.edge_list.numpy(), self.edge_weight.numpy(), self.num_node,
+                            node_feature=self.node_feature.numpy(), edge_feature=self.edge_feature.numpy())
+        self.assert_equal(graph, graph1, "type cast")
+        self.assert_equal(graph, graph2, "type cast")
+
+    def test_index(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
+                           node_feature=self.node_feature, edge_feature=self.edge_feature)
+
+        index = tuple(torch.randint(self.num_node, (2,)).tolist())
+        result = graph[index]
+        truth = self.adjacency[index]
+        self.assertTrue(torch.equal(result, truth), "Incorrect edge in single item")
+
+        h_index = torch.randperm(self.num_node)[:self.num_node // 2]
+        t_index = torch.randperm(self.num_node)[:self.num_node // 2]
+        new_graph = graph[h_index, t_index]
+        adj_result = new_graph.adjacency.to_dense()
+        feat_result = new_graph.node_feature
+        not_h_index = list(set(range(self.num_node)) - set(h_index.tolist()))
+        not_t_index = list(set(range(self.num_node)) - set(t_index.tolist()))
+        adj_truth = self.adjacency.clone()
+        adj_truth[not_h_index, :] = 0
+        adj_truth[:, not_t_index] = 0
+        feat_truth = self.node_feature
+        self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect edge list in node mask")
+        self.assertTrue(torch.equal(feat_result, feat_truth), "Incorrect feature in node mask")
+
+        new_graph = graph[:, 1: -1]
+        adj_result = new_graph.adjacency.to_dense()
+        feat_result = new_graph.node_feature
+        adj_truth = torch.zeros_like(self.adjacency)
+        adj_truth[:, 1: -1] = self.adjacency[:, 1: -1]
+        feat_truth = self.node_feature
+        self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect edge list in slice")
+        self.assertTrue(torch.equal(feat_result, feat_truth), "Incorrect feature in slice")
+
+        index = torch.randperm(self.num_node)[:self.num_node // 2]
+        new_graph = graph.subgraph(index)
+        adj_result = new_graph.adjacency.to_dense()
+        feat_result = new_graph.node_feature
+        adj_truth = self.adjacency[index][:, index]
+        feat_truth = self.node_feature[index]
+        self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect edge list in subgraph")
+        self.assertTrue(torch.equal(feat_result, feat_truth), "Incorrect feature in subgraph")
+
+    def test_device(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
+                           node_feature=self.node_feature, edge_feature=self.edge_feature,
+                           graph_feature=self.graph_feature)
+        graph1 = graph.cuda()
+        self.assertEqual(graph1.adjacency.device.type, "cuda", "Incorrect device")
+        graph2 = graph1.cpu()
+        self.assertEqual(graph2.adjacency.device.type, "cpu", "Incorrect device")
+        self.assert_equal(graph, graph2, "device")
+
+    def test_pack(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
+                           node_feature=self.node_feature, edge_feature=self.edge_feature,
+                           graph_feature=self.graph_feature)
+        # special case: graphs with no edges
+        graphs = [graph.edge_mask([]), graph.edge_mask([])]
+        for start in range(4):
+            index = torch.arange(start, self.num_node)
+            graphs.append(graph.subgraph(index))
+        for graph in graphs:
+            with graph.graph():
+                graph.graph_feature = torch.rand_like(self.graph_feature)
+
+        packed_graph = data.Graph.pack(graphs)
+        adj_result = packed_graph.adjacency.to_dense()
+        adj_truth = self.block_diag([graph.adjacency.to_dense() for graph in graphs])
+        node_feat_result = packed_graph.node_feature
+        node_feat_truth = torch.cat([graph.node_feature for graph in graphs])
+        edge_feat_result = packed_graph.edge_feature
+        edge_feat_truth = torch.cat([graph.edge_feature for graph in graphs])
+        graph_feat_result = packed_graph.graph_feature
+        graph_feat_truth = torch.stack([graph.graph_feature for graph in graphs])
+        self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect edge list in pack")
+        self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in pack")
+        self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in pack")
+        self.assertTrue(torch.equal(graph_feat_result, graph_feat_truth), "Incorrect feature in pack")
+
+        new_graphs = packed_graph.unpack()
+        self.assertEqual(len(graphs), len(new_graphs), "Incorrect length in unpack")
+        for graph, new_graph in zip(graphs, new_graphs):
+            self.assert_equal(graph, new_graph, "unpack")
+
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
+                           node_feature=self.node_feature, edge_feature=self.edge_feature)
+        graphs = graphs[2:]
+        packed_graph = data.Graph.pack(graphs)
+        packed_graph2 = data.Graph.pack([graph] * len(graphs))
+        mask = torch.zeros(self.num_node * len(graphs), dtype=torch.bool)
+        for start in range(4):
+            mask[start * self.num_node + start: (start + 1) * self.num_node] = 1
+        packed_graph2 = packed_graph2.subgraph(mask)
+        self.assert_equal(packed_graph, packed_graph2, "subgraph")
+
+        packed_graph = data.Graph.pack(graphs[::2])
+        packed_graph2 = data.Graph.pack(graphs)[::2]
+        self.assertEqual(len(packed_graph), len(packed_graph2), "Incorrect batch size in graph mask")
+        self.assert_equal(packed_graph, packed_graph2, "graph mask")
+
+    def test_reorder(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
+                           node_feature=self.node_feature, edge_feature=self.edge_feature)
+        order = torch.randperm(graph.num_node)
+        new_graph = graph.subgraph(order)
+        adj_result = new_graph.adjacency.to_dense()
+        adj_truth = graph.adjacency.to_dense().index_select(0, order).index_select(1, order)
+        node_feat_result = new_graph.node_feature
+        node_feat_truth = graph.node_feature[order]
+        edge_feat_result = new_graph.edge_feature
+        edge_feat_truth = graph.edge_feature
+        self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect edge list in node reorder")
+        self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in node reorder")
+        self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in node reorder")
+
+        order = torch.randperm(graph.num_edge)
+        new_graph = graph.edge_mask(order)
+        edge_result = new_graph.edge_list
+        edge_truth = graph.edge_list[order]
+        node_feat_result = new_graph.node_feature
+        node_feat_truth = graph.node_feature
+        edge_feat_result = new_graph.edge_feature
+        edge_feat_truth = graph.edge_feature[order]
+        self.assertTrue(torch.equal(edge_result, edge_truth), "Incorrect edge list in edge reorder")
+        self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in edge reorder")
+        self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in edge reorder")
+
+        graphs = []
+        for start in range(4):
+            index = torch.arange(start, self.num_node)
+            graphs.append(graph.subgraph(index))
+        packed_graph = data.Graph.pack(graphs)
+        order = torch.randperm(4)
+        packed_graph = packed_graph.subbatch(order)
+        packed_graph2 = data.Graph.pack([graphs[i] for i in order])
+        self.assert_equal(packed_graph, packed_graph2, "graph reorder")
+
+    def test_repeat(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
+                           node_feature=self.node_feature, edge_feature=self.edge_feature)
+        repeat_graph = graph.repeat(5)
+        true_graph = data.Graph.pack([graph] * 5)
+        self.assert_equal(repeat_graph, true_graph, "repeat")
+
+        # special case: graphs with no edges
+        graphs = [graph.edge_mask([]), graph.edge_mask([])]
+        for start in range(4):
+            index = torch.arange(start, self.num_node)
+            graphs.append(graph.subgraph(index))
+        packed_graph = data.Graph.pack(graphs)
+        repeat_graph = packed_graph.repeat(5)
+        true_graph = data.Graph.pack(graphs * 5)
+        self.assert_equal(repeat_graph, true_graph, "repeat")
+
+    def test_repeat_interleave(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
+                           node_feature=self.node_feature, edge_feature=self.edge_feature)
+        # special case: graphs with no edges
+        graphs = [graph.edge_mask([]), graph.edge_mask([])]
+        for start in range(4):
+            index = torch.arange(start, self.num_node)
+            graphs.append(graph.subgraph(index))
+        packed_graph = data.Graph.pack(graphs)
+        # special case: 0 repetition
+        repeats = [2, 0, 0, 2, 3, 0]
+        repeat_graph = packed_graph.repeat_interleave(repeats)
+        true_graphs = []
+        for i, graph in zip(repeats, graphs):
+            true_graphs += [graph] * i
+        true_graph = data.Graph.pack(true_graphs)
+        self.assert_equal(repeat_graph, true_graph, "repeat interleave")
+
+    def test_repeated_index(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node)
+        graphs = []
+        for start in range(4):
+            index = torch.arange(start, self.num_node)
+            graphs.append(graph.subgraph(index))
+        packed_graph = data.Graph.pack(graphs)
+        # special case: some indexes missing, not sorted
+        index = [1, 0, 2, 1, 0]
+        packed_graph = packed_graph[index]
+        packed_graph2 = data.Graph.pack([graphs[i] for i in index])
+        self.assert_equal(packed_graph, packed_graph2, "repeated index")
+
+    def test_split(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node)
+        node2graph = torch.randint(3, (10,))
+        graphs = []
+        for i in range(3):
+            subgraph = graph.subgraph(node2graph == i)
+            if subgraph.num_node > 0:
+                graphs.append(subgraph)
+        new_graphs = graph.split(node2graph).unpack()
+        self.assertEqual(len(graphs), len(new_graphs), "Incorrect length in split")
+        for graph, new_graph in zip(graphs, new_graphs):
+            adj_truth = graph.adjacency.to_dense()
+            adj_result = new_graph.adjacency.to_dense()
+            self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect split")
+
+        graphs = []
+        for i in range(5, 10):
+            # ensure connected graph
+            edge_list = torch.stack([torch.arange(i - 1), torch.arange(1, i)], dim=-1)
+            graph = data.Graph(edge_list, num_node=i)
+            graphs.append(graph)
+        packed_graph = data.Graph.pack(graphs)
+        packed_graph2, num_cc_result = packed_graph.connected_components()
+        num_cc_truth = torch.ones_like(num_cc_result)
+        self.assertTrue(torch.equal(num_cc_result, num_cc_truth), "Incorrect connected components")
+        stat_truth = sorted((graph.num_node, graph.num_edge) for graph in packed_graph)
+        stat_result = sorted((graph.num_node, graph.num_edge) for graph in packed_graph2)
+        self.assertEqual(stat_result, stat_truth, "Incorrect connected components")
+
+        # shuffle node order
+        perm = torch.randperm(packed_graph.num_node)
+        adjacency = packed_graph.adjacency.to_dense()
+        adjacency = adjacency.index_select(0, perm).index_select(1, perm)
+        packed_graph2, num_cc_result = data.Graph.from_dense(adjacency).connected_components()
+        num_cc_truth = torch.tensor([len(graphs)])
+        self.assertTrue(torch.equal(num_cc_result, num_cc_truth), "Incorrect connected components")
+        stat_truth = sorted((graph.num_node, graph.num_edge) for graph in packed_graph)
+        stat_result = sorted((graph.num_node, graph.num_edge) for graph in packed_graph2)
+        self.assertEqual(stat_result, stat_truth, "Incorrect connected components")
+
+    def test_merge(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node)
+        graph2graph = torch.randint(2, (6,))
+        graph2graph[0] = 0
+        graph2graph[-1] = 1
+        graphs = []
+        for start in range(6):
+            index = torch.arange(start, self.num_node)
+            graphs.append(graph.subgraph(index))
+        packed_graph = data.Graph.pack(graphs)
+        merged_graph = packed_graph.merge(graph2graph)
+        truth_graphs = []
+        for i in range(2):
+            index = (graph2graph == i).nonzero().flatten().tolist()
+            truth_graph = data.Graph.pack([graphs[j] for j in index])
+            truth_graphs.append(truth_graph)
+        self.assertEqual(len(merged_graph), len(truth_graphs), "Incorrect length in merge")
+        for graph, truth in zip(merged_graph, truth_graphs):
+            adj_result = graph.adjacency.to_dense()
+            adj_truth = truth.adjacency.to_dense()
+            self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect merge")
+
+    def test_directed(self):
+        digraph = data.Graph(self.edge_list, self.edge_weight, self.num_node)
+        graph = digraph.undirected()
+        adj_result = graph.adjacency.to_dense()
+        adj_truth = (digraph.adjacency + digraph.adjacency.t()).to_dense()
+        self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect conversion from directed to undirected")
+        digraph2 = graph.directed()
+        adj_result = digraph2.adjacency.to_dense()
+        adj_truth = adj_truth.triu()
+        self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect conversion from undirected to directed")
+
+    def test_match(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node)
+        index = torch.randperm(graph.num_edge)[:self.num_node]
+        edge = graph.edge_list[index]
+        mask = torch.randint(2, (len(edge), 1))
+        edge.scatter_(1, mask, -1)
+        random = torch.randint_like(edge, self.num_node)
+        edge = torch.cat([edge, random])
+        index_result, num_match_result = graph.match(edge)
+        index_results = index_result.split(num_match_result.tolist())
+        match = ((graph.edge_list.unsqueeze(0) == edge.unsqueeze(1)) | (edge.unsqueeze(1) == -1)).all(dim=-1)
+        query_index, index_truth = match.nonzero().t()
+        num_match_truth = query_index.bincount(minlength=len(edge))
+        index_truths = index_truth.split(num_match_truth.tolist())
+        self.assertTrue(torch.equal(num_match_result, num_match_truth), "Incorrect edge match")
+        for index_result, index_truth in zip(index_results, index_truths):
+            self.assertTrue(torch.equal(index_result.sort()[0], index_truth.sort()[0]), "Incorrect edge match")
+
+    def test_reference(self):
+        node_out = torch.arange(1, self.num_node)
+        node_in = torch.div(node_out - 1, 2, rounding_mode="floor")
+        edge_list = torch.stack([node_in, node_out], dim=-1)
+        tree = data.Graph(edge_list, num_node=self.num_node)
+        with tree.node(), tree.node_reference():
+            tree.dad = torch.div(torch.arange(self.num_node) - 1, 2, rounding_mode="floor")
+
+        mask = torch.arange(1, self.num_node)
+        graph = tree.subgraph(mask)
+        degree_in_result = graph.dad[graph.dad != -1].bincount(minlength=graph.num_node)
+        is_root_result = graph.dad == -1
+        node_in, node_out = graph.edge_list.t()
+        degree_in_truth = node_in.bincount(minlength=graph.num_node)
+        is_root_truth = node_out.bincount(minlength=graph.num_node) == 0
+        self.assertTrue(torch.equal(degree_in_result, degree_in_truth), "Incorrect node reference")
+        self.assertTrue(torch.equal(is_root_result, is_root_truth), "Incorrect node reference")
+
+        packed_graph = tree.repeat(4)
+        packed_graph2 = data.Graph.pack([tree] * 4)
+        self.assert_equal(packed_graph, packed_graph2, "node reference")
+
+        # special case: 0 repetition
+        repeats = [2, 0, 1, 2]
+        trees = []
+        for start in range(4):
+            index = torch.arange(start, self.num_node)
+            trees.append(tree.subgraph(index))
+        packed_graph = data.Graph.pack(trees)
+        repeat_graph = packed_graph.repeat_interleave(repeats)
+        true_graphs = []
+        for i, tree in zip(repeats, trees):
+            true_graphs += [tree] * i
+        true_graph = data.Graph.pack(true_graphs)
+        self.assert_equal(repeat_graph, true_graph, "node reference")
+
+    def test_line_graph(self):
+        graph = data.Graph(self.edge_list, self.edge_weight, self.num_node, edge_feature=self.edge_feature)
+        line_graph = graph.line_graph()
+        adj_result = line_graph.adjacency.to_dense()
+        feat_result = line_graph.node_feature
+        edge_index = torch.arange(graph.num_edge)
+        node_in, node_out = graph.edge_list.t()
+        edge2node_out = torch.zeros(graph.num_edge, graph.num_node)
+        node_in2edge = torch.zeros(graph.num_node, graph.num_edge)
+        edge2node_out[edge_index, node_out] = 1
+        node_in2edge[node_in, edge_index] = 1
+        adj_truth = edge2node_out @ node_in2edge
+        feat_truth = graph.edge_feature
+        self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect line graph")
+        self.assertTrue(torch.equal(feat_result, feat_truth), "Incorrect line graph")
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file