--- a +++ b/test/layers/test_sampler.py @@ -0,0 +1,50 @@ +import unittest + +import torch + +from torchdrug import data, layers + + +class GraphSamplerTest(unittest.TestCase): + + def setUp(self): + self.num_node = 10 + self.input_dim = 5 + self.output_dim = 7 + adjacency = torch.rand(self.num_node, self.num_node) + threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0] + adjacency = adjacency * (adjacency > threshold) + self.graph = data.Graph.from_dense(adjacency).cuda() + self.input = torch.rand(self.num_node, self.input_dim).cuda() + + def test_sampler(self): + conv = layers.GraphConv(self.input_dim, self.output_dim, activation=None).cuda() + readout = layers.SumReadout().cuda() + + sampler = layers.NodeSampler(ratio=0.8).cuda() + results = [] + for i in range(2000): + graph = sampler(self.graph) + node_feature = conv(graph, self.input) + result = readout(graph, node_feature) + results.append(result) + result = torch.stack(results).mean(dim=0) + node_feature = conv(self.graph, self.input) + truth = readout(self.graph, node_feature) + self.assertTrue(torch.allclose(result, truth, rtol=5e-2, atol=5e-2), "Found bias in node sampler") + + sampler = layers.EdgeSampler(ratio=0.8).cuda() + results = [] + for i in range(2000): + graph = sampler(self.graph) + node_feature = conv(graph, self.input) + result = readout(graph, node_feature) + results.append(result) + result = torch.stack(results).mean(dim=0) + node_feature = conv(self.graph, self.input) + truth = readout(self.graph, node_feature) + self.assertTrue(torch.allclose(result, truth, rtol=5e-2, atol=5e-2), "Found bias in edge sampler") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file