import unittest

import torch

from torchdrug.layers import functional


class VariadicTest(unittest.TestCase):
    """Test variadic (ragged-batch) ops against padded dense references.

    A "variadic" tensor packs samples of different sizes along dim 0, with
    ``self.size`` giving each sample's length. Each test builds an equivalent
    padded dense tensor (padded with -inf so that descending sort/topk rank
    the padding last) and compares the variadic result against it.
    """

    def setUp(self):
        self.num_graph = 4
        self.size = torch.randint(3, 6, (self.num_graph,))
        self.num_node = self.size.sum()
        self.feature_dim = 2
        self.input = torch.rand(self.num_node, self.feature_dim)
        # Dense reference: one row per graph, -inf beyond each graph's size.
        self.padded = torch.zeros(self.num_graph, self.size.max(), self.feature_dim)
        self.padded[:] = float("-inf")
        offset = 0
        for i, size in enumerate(self.size):
            self.padded[i, :size] = self.input[offset: offset + size]
            offset += size

    def test_arange(self):
        result = functional.variadic_arange(self.size)
        truth = torch.cat([torch.arange(x) for x in self.size])
        self.assertTrue(torch.equal(result, truth), "Incorrect variadic arange")

    def test_sort(self):
        result_value, result_index = functional.variadic_sort(self.input, self.size, descending=True)
        truth_value, truth_index = self.padded.sort(dim=1, descending=True)
        # Drop the -inf padding entries to recover the packed (variadic) layout.
        mask = ~torch.isinf(self.padded)
        truth_value = truth_value[mask].view(-1, self.feature_dim)
        truth_index = truth_index[mask].view(-1, self.feature_dim)
        self.assertTrue(torch.equal(result_value, truth_value), "Incorrect variadic sort")
        self.assertTrue(torch.equal(result_index, truth_index), "Incorrect variadic sort")

    def test_topk(self):
        # Scalar k, covering both k <= every size and k >= some sizes.
        # .item() converts the 0-dim tensors to Python ints, which is what
        # Tensor.topk expects across all torch versions.
        for k in [self.size.min().item(), self.size.max().item()]:
            result_value, result_index = functional.variadic_topk(self.input, self.size, k)
            truth_value, truth_index = self.padded.topk(k, dim=1)
            # When a graph has fewer than k nodes, variadic_topk repeats its
            # last valid entry; mirror that in the dense reference.
            for i, size in enumerate(self.size):
                for j in range(size, k):
                    truth_value[i, j] = truth_value[i, j - 1]
                    truth_index[i, j] = truth_index[i, j - 1]
            self.assertTrue(torch.equal(result_value, truth_value), "Incorrect variadic topk")
            self.assertTrue(torch.equal(result_index, truth_index), "Incorrect variadic topk")

        # Per-graph k supplied as a tensor, randomized over several trials.
        for _ in range(10):
            k = torch.randint(self.size.min(), self.size.max(), (self.num_graph,))
            result_value, result_index = functional.variadic_topk(self.input, self.size, k)
            _truth_value, _truth_index = self.padded.topk(self.size.max().item(), dim=1)
            truth_value, truth_index = [], []
            for i, size in enumerate(self.size):
                truth_value.append(_truth_value[i, :k[i]])
                truth_index.append(_truth_index[i, :k[i]])
                # Same last-entry repetition rule as above, per graph.
                for j in range(size, k[i].item()):
                    truth_value[i][j] = truth_value[i][j - 1]
                    truth_index[i][j] = truth_index[i][j - 1]
            truth_value = torch.cat(truth_value, dim=0)
            truth_index = torch.cat(truth_index, dim=0)
            self.assertTrue(torch.equal(result_value, truth_value), "Incorrect variadic topk")
            self.assertTrue(torch.equal(result_index, truth_index), "Incorrect variadic topk")


if __name__ == "__main__":
    unittest.main()