# test/layers/test_pool.py
import unittest
import torch
from torch.nn import functional as F
from torchdrug import data, layers
class GraphPoolTest(unittest.TestCase):
    """Tests for torchdrug graph pooling layers (DiffPool and MinCutPool).

    Each test builds a small batch of random sparse graphs on GPU, runs the
    pooling layer, and recomputes the expected pooled node features and
    pooled adjacency by hand with einsum to compare against the layer output.
    NOTE(review): these tests require a CUDA device — every graph and layer
    is moved with .cuda().
    """

    def setUp(self):
        # Fixture sizes: 5 graphs of 10 nodes each, pooled down to 6 clusters.
        self.num_node = 10
        self.num_graph = 5
        self.input_dim = 5
        self.output_dim = 8
        self.output_node = 6
        self.graphs = []
        for i in range(self.num_graph):
            # Random dense adjacency, then keep only the top 3*num_node
            # entries: kthvalue picks the (num_node-3)*num_node-th smallest
            # value as a threshold, and everything <= it is zeroed out.
            adjacency = torch.rand(self.num_node, self.num_node)
            threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0]
            adjacency = adjacency * (adjacency > threshold)
            node_feature = torch.rand(self.num_node, self.input_dim)
            graph = data.Graph.from_dense(adjacency, node_feature).cuda()
            self.graphs.append(graph)
        # Pack the individual graphs into one batched PackedGraph.
        self.graph = data.Graph.pack(self.graphs)
        # Shared GCN layers reused by both pooling layers under test.
        self.feature_layer = layers.GraphConv(self.input_dim, self.output_dim).cuda()
        self.pool_layer = layers.GraphConv(self.input_dim, self.output_dim).cuda()

    def test_pool(self):
        """Check DiffPool (dense and sparse assignment) and MinCutPool outputs."""
        for sparse in [False, True]:
            pool = layers.DiffPool(self.input_dim, self.output_node, self.feature_layer, self.pool_layer,
                                   zero_diagonal=True, sparse=sparse).cuda()
            # --- batched graphs ---
            pooled, result, assignment = pool(self.graph, self.graph.node_feature)
            result = result.view(self.num_graph, self.output_node, -1)
            result_adj = torch.stack([g.adjacency.to_dense() for g in pooled])
            # Recompute the expected result by hand: node features from the
            # feature GCN, soft assignment from the pool GCN + linear head.
            feature = pool.feature_layer(self.graph, self.graph.node_feature).view(self.num_graph, self.num_node, -1)
            x = pool.linear(pool.pool_layer(self.graph, self.graph.node_feature))
            if not sparse:
                # Dense mode: assignment is a softmax over clusters. In sparse
                # mode the (stochastic) assignment returned by pool() is reused.
                assignment = F.softmax(x, dim=-1)
            assignment = assignment.view(self.num_graph, self.num_node, -1)
            adjacency = torch.stack([g.adjacency.to_dense() for g in self.graph])
            # Pooled features: S^T X; pooled adjacency: S^T A S (per graph).
            truth = torch.einsum("bna, bnd -> bad", assignment, feature)
            truth_adj = torch.einsum("bna, bnm, bmc -> bac", assignment, adjacency, assignment)
            # zero_diagonal=True: the layer removes self-loops, so do the same.
            index = torch.arange(self.output_node, device=truth.device)
            truth_adj[:, index, index] = 0
            self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect diffpool node feature")
            self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect diffpool adjacency")

            # --- single graph ---
            graph = self.graph[0]
            # Save/restore the RNG state so the recomputation below sees the
            # same random draws as the pool() call (relevant when sparse=True,
            # where the assignment is sampled — presumably; confirm against
            # the DiffPool implementation).
            rng_state = torch.get_rng_state()
            pooled, result, assignment = pool(graph, graph.node_feature)
            result_adj = pooled.adjacency.to_dense()
            torch.set_rng_state(rng_state)
            feature = pool.feature_layer(graph, graph.node_feature)
            x = pool.linear(pool.pool_layer(graph, graph.node_feature))
            if not sparse:
                assignment = F.softmax(x, dim=-1)
            adjacency = graph.adjacency.to_dense()
            truth = torch.einsum("na, nd -> ad", assignment, feature)
            truth_adj = torch.einsum("na, nm, mc -> ac", assignment, adjacency, assignment)
            index = torch.arange(self.output_node, device=truth.device)
            truth_adj[index, index] = 0
            self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect diffpool node feature")
            self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect diffpool adjacency")

        # --- MinCutPool: also checks the two auxiliary losses it records ---
        pool = layers.MinCutPool(self.input_dim, self.output_node, self.feature_layer, self.pool_layer).cuda()
        all_loss = torch.tensor(0, dtype=torch.float32, device="cuda")
        result_metric = {}
        pooled, result, assignment = pool(self.graph, self.graph.node_feature, all_loss, result_metric)
        result = result.view(self.num_graph, self.output_node, -1)
        result_adj = torch.stack([g.adjacency.to_dense() for g in pooled])
        feature = pool.feature_layer(self.graph, self.graph.node_feature).view(self.num_graph, self.num_node, -1)
        x = pool.linear(pool.pool_layer(self.graph, self.graph.node_feature))
        assignment = F.softmax(x, dim=-1)
        assignment = assignment.view(self.num_graph, self.num_node, -1)
        adjacency = torch.stack([g.adjacency.to_dense() for g in self.graph])
        truth = torch.einsum("bna, bnd -> bad", assignment, feature)
        # Pooled adjacency S^T A S; keep the version WITH diagonal for the
        # cut loss, and a diagonal-zeroed clone as the expected output.
        adjacency = torch.einsum("bna, bnm, bmc -> bac", assignment, adjacency, assignment)
        truth_adj = adjacency.clone()
        index = torch.arange(self.output_node, device=truth.device)
        truth_adj[:, index, index] = 0
        # Normalized cut loss: 1 - tr(S^T A S) / tr(S^T D S), averaged.
        num_intra = torch.einsum("baa -> b", adjacency)
        degree = self.graph.degree_in.view(self.num_graph, self.num_node)
        degree = torch.einsum("bna, bn, bnc -> bac", assignment, degree, assignment)
        num_all = torch.einsum("baa -> b", degree)
        cut_loss = (1 - num_intra / num_all).mean()
        # Orthogonality regularization: || S^T S / ||S^T S||_F - I/sqrt(k) ||_F.
        x = torch.einsum("bna, bnc -> bac", assignment, assignment)
        x = x / x.flatten(-2).norm(dim=-1, keepdim=True).unsqueeze(-1)
        x = x - torch.eye(self.output_node, device=x.device) / (self.output_node ** 0.5)
        regularization = x.flatten(-2).norm(dim=-1).mean()
        truth_metric = {"normalized cut loss": cut_loss, "orthogonal regularization": regularization}
        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect min cut pool feature")
        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect min cut pool adjcency")
        for key in result_metric:
            self.assertTrue(torch.allclose(result_metric[key], truth_metric[key], rtol=1e-3, atol=1e-4),
                            "Incorrect min cut pool metric")
# Allow running this test module directly: python test_pool.py
if __name__ == "__main__":
    unittest.main()