Diff of /test/layers/test_conv.py [000000] .. [36b44b]

--- a
+++ b/test/layers/test_conv.py
@@ -0,0 +1,104 @@
+import math
+import unittest
+
+import torch
+from torch.nn import functional as F
+
+from torchdrug import data, layers
+
+
+class GraphConvTest(unittest.TestCase):
+
+    def setUp(self):
+        self.num_node = 10
+        self.num_relation = 3
+        self.input_dim = 5
+        self.output_dim = 8
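+        # build a random relational adjacency and zero out the smallest
+        # entries so the test graph is sparse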
+        adjacency = torch.rand(self.num_node, self.num_node, self.num_relation)
+        threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0]
+        adjacency = adjacency * (adjacency > threshold)
+        self.graph = data.Graph.from_dense(adjacency).cuda()
+        self.input = torch.rand(self.num_node, self.input_dim).cuda()
+
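+    # dense reference attention: edge-weighted softmax over source nodes,
+    # used as ground truth for GraphAttentionConv below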
+    def attention(self, query, key, value, mask, activation, eps=1e-10):
+        weight = F.linear(key, query).squeeze(-1)
+        weight = activation(weight)
+        infinite = torch.tensor(math.inf, device=value.device)
+        weight = torch.where(mask > 0, weight, -infinite)
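+        # numerically stable masked softmax over source nodes (dim 0)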
+        attention = (weight - weight.max(dim=0, keepdim=True)[0]).exp()
+        attention = attention * mask
+        attention = attention / (attention.sum(dim=0, keepdim=True) + eps)
+        return (attention.unsqueeze(-1) * value).sum(dim=0)
+
+    def test_graph_conv(self):
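+        # GraphConv (GCN): symmetric degree normalization with self-loops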
+        conv = layers.GraphConv(self.input_dim, self.output_dim).cuda()
+        result = conv(self.graph, self.input)
+        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
+        adjacency = adjacency + torch.eye(self.num_node, device=adjacency.device)
+        adjacency /= adjacency.sum(dim=0, keepdim=True).sqrt() * adjacency.sum(dim=1, keepdim=True).sqrt()
+        x = adjacency.t() @ self.input
+        truth = conv.activation(conv.linear(x))
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect graph convolution")
+
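+        # GraphAttentionConv (GAT): compare each head against the
+        # dense reference attention defined above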
+        num_head = 2
+        conv = layers.GraphAttentionConv(self.input_dim, self.output_dim, num_head=num_head).cuda()
+        result = conv(self.graph, self.input)
+        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
+        adjacency = adjacency + torch.eye(self.num_node, device=adjacency.device)
+        hidden = conv.linear(self.input)
+        outputs = []
+        for h, query in zip(hidden.chunk(num_head, dim=-1), conv.query.chunk(num_head, dim=0)):
+            value = h.unsqueeze(1).expand(-1, self.num_node, -1)
+            key = torch.stack([h.unsqueeze(1).expand(-1, self.num_node, -1),
+                               h.unsqueeze(0).expand(self.num_node, -1, -1)], dim=-1).flatten(-2)
+            output = self.attention(query, key, value, adjacency, conv.leaky_relu, conv.eps)
+            outputs.append(output)
+        truth = torch.cat(outputs, dim=-1)
+        truth = conv.activation(truth)
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect graph attention convolution")
+
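+        # GraphIsomorphismConv (GIN): (1 + eps) * x plus the weighted
+        # sum of neighbor features, followed by an MLP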
+        eps = 1
+        conv = layers.GraphIsomorphismConv(self.input_dim, self.output_dim, eps=eps).cuda()
+        result = conv(self.graph, self.input)
+        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
+        x = (1 + eps) * self.input + adjacency.t() @ self.input
+        truth = conv.activation(conv.mlp(x))
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-2), "Incorrect graph isomorphism convolution")
+
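+        # RelationalGraphConv (RGCN): per-relation normalized aggregation
+        # over incoming edges, plus a self-loop transform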
+        conv = layers.RelationalGraphConv(self.input_dim, self.output_dim, self.num_relation).cuda()
+        result = conv(self.graph, self.input)
+        adjacency = self.graph.adjacency.to_dense()
+        adjacency /= adjacency.sum(dim=0, keepdim=True)
+        x = torch.einsum("htr, hd -> trd", adjacency, self.input)
+        x = conv.linear(x.flatten(1)) + conv.self_loop(self.input)
+        truth = conv.activation(x)
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect relational graph convolution")
+
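+        # ChebyshevConv: Chebyshev polynomials of the normalized
+        # Laplacian, T_0 = I, T_1 = L, T_2 = 2 L^2 - I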
+        conv = layers.ChebyshevConv(self.input_dim, self.output_dim, k=2).cuda()
+        result = conv(self.graph, self.input)
+        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
+        adjacency /= adjacency.sum(dim=0, keepdim=True).sqrt() * adjacency.sum(dim=1, keepdim=True).sqrt()
+        identity = torch.eye(self.num_node, device=adjacency.device)
+        laplacian = identity - adjacency
+        bases = [self.input, laplacian.t() @ self.input, (2 * laplacian.t() @ laplacian.t() - identity) @ self.input]
+        x = conv.linear(torch.cat(bases, dim=-1))
+        truth = conv.activation(x)
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect Chebyshev graph convolution")
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file