import unittest

import torch
from torch.nn import functional as F

from torchdrug import data, layers


class GraphPoolTest(unittest.TestCase):
    """Checks DiffPool and MinCutPool against dense einsum reference computations.

    Requires a CUDA device: every graph and layer is moved to GPU in setUp.
    """

    def setUp(self):
        # Fixture sizes: 5 random graphs of 10 nodes each, pooled down to 6 nodes.
        self.num_node = 10
        self.num_graph = 5
        self.input_dim = 5
        self.output_dim = 8
        self.output_node = 6
        self.graphs = []
        for _ in range(self.num_graph):
            adjacency = torch.rand(self.num_node, self.num_node)
            # Keep roughly the top 3*num_node entries: zero everything at or
            # below the (num_node - 3) * num_node-th smallest value.
            threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0]
            adjacency = adjacency * (adjacency > threshold)
            node_feature = torch.rand(self.num_node, self.input_dim)
            graph = data.Graph.from_dense(adjacency, node_feature).cuda()
            self.graphs.append(graph)
        # Pack once after all graphs are built so batched views below line up.
        self.graph = data.Graph.pack(self.graphs)
        self.feature_layer = layers.GraphConv(self.input_dim, self.output_dim).cuda()
        self.pool_layer = layers.GraphConv(self.input_dim, self.output_dim).cuda()

    def test_pool(self):
        """Compare pooled features/adjacency against dense einsum ground truth."""
        for sparse in [False, True]:
            pool = layers.DiffPool(self.input_dim, self.output_node, self.feature_layer, self.pool_layer,
                                   zero_diagonal=True, sparse=sparse).cuda()

            # --- batched DiffPool ---
            pooled, result, assignment = pool(self.graph, self.graph.node_feature)
            result = result.view(self.num_graph, self.output_node, -1)
            result_adj = torch.stack([g.adjacency.to_dense() for g in pooled])
            feature = pool.feature_layer(self.graph, self.graph.node_feature).view(self.num_graph, self.num_node, -1)
            x = pool.linear(pool.pool_layer(self.graph, self.graph.node_feature))
            if not sparse:
                # Dense mode: recompute the soft assignment; sparse mode reuses
                # the (sampled) assignment returned by the pool itself.
                assignment = F.softmax(x, dim=-1)
            assignment = assignment.view(self.num_graph, self.num_node, -1)
            adjacency = torch.stack([g.adjacency.to_dense() for g in self.graph])
            # Reference pooling: S^T X for features, S^T A S for adjacency.
            truth = torch.einsum("bna, bnd -> bad", assignment, feature)
            truth_adj = torch.einsum("bna, bnm, bmc -> bac", assignment, adjacency, assignment)
            index = torch.arange(self.output_node, device=truth.device)
            # zero_diagonal=True: the pooled adjacency has no self loops.
            truth_adj[:, index, index] = 0
            self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4),
                            "Incorrect diffpool node feature")
            self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4),
                            "Incorrect diffpool adjacency")

            # --- single-graph DiffPool ---
            graph = self.graph[0]
            # Save/restore the RNG state so the reference path sees the same
            # random draws as the pooling call (matters for sparse sampling).
            rng_state = torch.get_rng_state()
            pooled, result, assignment = pool(graph, graph.node_feature)
            result_adj = pooled.adjacency.to_dense()
            torch.set_rng_state(rng_state)
            feature = pool.feature_layer(graph, graph.node_feature)
            x = pool.linear(pool.pool_layer(graph, graph.node_feature))
            if not sparse:
                assignment = F.softmax(x, dim=-1)
            adjacency = graph.adjacency.to_dense()
            truth = torch.einsum("na, nd -> ad", assignment, feature)
            truth_adj = torch.einsum("na, nm, mc -> ac", assignment, adjacency, assignment)
            index = torch.arange(self.output_node, device=truth.device)
            truth_adj[index, index] = 0
            self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4),
                            "Incorrect diffpool node feature")
            self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4),
                            "Incorrect diffpool adjacency")

        # --- MinCutPool (dense assignment only; no sparse variant) ---
        pool = layers.MinCutPool(self.input_dim, self.output_node, self.feature_layer, self.pool_layer).cuda()
        all_loss = torch.tensor(0, dtype=torch.float32, device="cuda")
        result_metric = {}
        pooled, result, assignment = pool(self.graph, self.graph.node_feature, all_loss, result_metric)
        result = result.view(self.num_graph, self.output_node, -1)
        result_adj = torch.stack([g.adjacency.to_dense() for g in pooled])
        feature = pool.feature_layer(self.graph, self.graph.node_feature).view(self.num_graph, self.num_node, -1)
        x = pool.linear(pool.pool_layer(self.graph, self.graph.node_feature))
        assignment = F.softmax(x, dim=-1)
        assignment = assignment.view(self.num_graph, self.num_node, -1)
        adjacency = torch.stack([g.adjacency.to_dense() for g in self.graph])
        truth = torch.einsum("bna, bnd -> bad", assignment, feature)
        # Keep the un-zeroed pooled adjacency around: its trace feeds the cut loss.
        adjacency = torch.einsum("bna, bnm, bmc -> bac", assignment, adjacency, assignment)
        truth_adj = adjacency.clone()
        index = torch.arange(self.output_node, device=truth.device)
        truth_adj[:, index, index] = 0
        # Normalized cut loss: 1 - tr(S^T A S) / tr(S^T D S), averaged over graphs.
        num_intra = torch.einsum("baa -> b", adjacency)
        degree = self.graph.degree_in.view(self.num_graph, self.num_node)
        degree = torch.einsum("bna, bn, bnc -> bac", assignment, degree, assignment)
        num_all = torch.einsum("baa -> b", degree)
        cut_loss = (1 - num_intra / num_all).mean()
        # Orthogonality regularization: || S^T S / ||S^T S||_F - I / sqrt(k) ||_F.
        x = torch.einsum("bna, bnc -> bac", assignment, assignment)
        x = x / x.flatten(-2).norm(dim=-1, keepdim=True).unsqueeze(-1)
        x = x - torch.eye(self.output_node, device=x.device) / (self.output_node ** 0.5)
        regularization = x.flatten(-2).norm(dim=-1).mean()
        truth_metric = {"normalized cut loss": cut_loss, "orthogonal regularization": regularization}
        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4),
                        "Incorrect min cut pool feature")
        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4),
                        "Incorrect min cut pool adjacency")
        for key in result_metric:
            self.assertTrue(torch.allclose(result_metric[key], truth_metric[key], rtol=1e-3, atol=1e-4),
                            "Incorrect min cut pool metric")


if __name__ == "__main__":
    unittest.main()