--- a +++ b/test/layers/test_readout.py @@ -0,0 +1,52 @@ +import unittest + +import torch +from torch.nn import functional as F + +from torchdrug import data, layers + + +class GraphReadoutTest(unittest.TestCase): + + def setUp(self): + self.num_node = 10 + self.num_graph = 5 + self.feature_dim = 5 + self.graphs = [] + for i in range(self.num_graph): + adjacency = torch.rand(self.num_node, self.num_node) + threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0] + adjacency = adjacency * (adjacency > threshold) + node_feature = torch.rand(self.num_node, self.feature_dim) + graph = data.Graph.from_dense(adjacency, node_feature).cuda() + self.graphs.append(graph) + self.graph = data.Graph.pack(self.graphs) + + def test_readout(self): + readout = layers.SumReadout().cuda() + result = readout(self.graph, self.graph.node_feature) + truth = [graph.node_feature.sum(0) for graph in self.graphs] + truth = torch.stack(truth) + self.assertTrue(torch.allclose(result, truth), "Incorrect sum readout") + + readout = layers.MeanReadout().cuda() + result = readout(self.graph, self.graph.node_feature) + truth = [graph.node_feature.mean(0) for graph in self.graphs] + truth = torch.stack(truth) + self.assertTrue(torch.allclose(result, truth), "Incorrect mean readout") + + readout = layers.MaxReadout().cuda() + result = readout(self.graph, self.graph.node_feature) + truth = [graph.node_feature.max(0)[0] for graph in self.graphs] + truth = torch.stack(truth) + self.assertTrue(torch.allclose(result, truth), "Incorrect max readout") + + softmax = layers.Softmax().cuda() + result = softmax(self.graph, self.graph.node_feature) + truth = [F.softmax(graph.node_feature, dim=0) for graph in self.graphs] + truth = torch.cat(truth) + self.assertTrue(torch.allclose(result, truth), "Incorrect softmax") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file