[3eb847]: /test/layers/test_readout.py

Download this file

52 lines (41 with data), 2.0 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import unittest
import torch
from torch.nn import functional as F
from torchdrug import data, layers
class GraphReadoutTest(unittest.TestCase):
    """Check torchdrug graph readout layers against per-graph reference results.

    Several random graphs are packed into a single batched graph. Each layer is
    applied to the packed batch, and its output is compared with the same
    reduction computed graph by graph using plain torch operations.
    """

    @staticmethod
    def _maybe_cuda(obj):
        """Return ``obj.cuda()`` when CUDA is available, otherwise ``obj`` unchanged.

        This lets the test run on CPU-only machines instead of failing inside
        ``.cuda()``; with a GPU present the behavior matches calling ``.cuda()``.
        """
        return obj.cuda() if torch.cuda.is_available() else obj

    def setUp(self):
        """Build ``self.graphs`` (individual random graphs) and ``self.graph`` (their pack)."""
        # Fixed seed so that a failing comparison can be reproduced.
        torch.manual_seed(0)
        self.num_node = 10
        self.num_graph = 5
        self.feature_dim = 5
        self.graphs = []
        for _ in range(self.num_graph):
            adjacency = torch.rand(self.num_node, self.num_node)
            # Keep only entries strictly above the k-th smallest value, with
            # k = (num_node - 3) * num_node, i.e. roughly the largest
            # 3 * num_node entries (about 3 edges per node on average).
            threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0]
            adjacency = adjacency * (adjacency > threshold)
            node_feature = torch.rand(self.num_node, self.feature_dim)
            graph = self._maybe_cuda(data.Graph.from_dense(adjacency, node_feature))
            self.graphs.append(graph)
        # Batch all graphs into one packed graph, which is what the layers consume.
        self.graph = data.Graph.pack(self.graphs)

    def test_readout(self):
        """Compare each layer's output on the packed batch with a per-graph reference."""
        # Each case: (layer class, per-graph reference on node features,
        # function joining per-graph references, failure message).
        cases = [
            (layers.SumReadout, lambda x: x.sum(0), torch.stack, "Incorrect sum readout"),
            (layers.MeanReadout, lambda x: x.mean(0), torch.stack, "Incorrect mean readout"),
            (layers.MaxReadout, lambda x: x.max(0)[0], torch.stack, "Incorrect max readout"),
            # The expected Softmax output is a softmax over the nodes of each graph,
            # i.e. one row per node, so per-graph results are concatenated, not stacked.
            (layers.Softmax, lambda x: F.softmax(x, dim=0), torch.cat, "Incorrect softmax"),
        ]
        for layer_cls, reference, combine, message in cases:
            # subTest reports every failing layer instead of stopping at the first.
            with self.subTest(layer=layer_cls.__name__):
                layer = self._maybe_cuda(layer_cls())
                result = layer(self.graph, self.graph.node_feature)
                truth = combine([reference(graph.node_feature) for graph in self.graphs])
                self.assertTrue(torch.allclose(result, truth), message)
if __name__ == "__main__":
    # Executed directly (not imported): discover and run this module's tests.
    unittest.main()