# test/layers/test_conv.py
import math
import unittest

import torch
from torch.nn import functional as F

from torchdrug import data, layers


class GraphConvTest(unittest.TestCase):

    def setUp(self):
        self.num_node = 10
        self.num_relation = 3
        self.input_dim = 5
        self.output_dim = 8
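        # sample a dense random adjacency over self.num_relation relation types;
        # the thresholding below zeroes out the smallest entries before the
        # graph is built on the GPU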
        adjacency = torch.rand(self.num_node, self.num_node, self.num_relation)
        threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0]
        adjacency = adjacency * (adjacency > threshold)
        self.graph = data.Graph.from_dense(adjacency).cuda()
        self.input = torch.rand(self.num_node, self.input_dim).cuda()

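    # dense reference implementation of masked attention; reduces over dim 0,
    # i.e. over in-nodes, and is used to verify GraphAttentionConv below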
    def attention(self, query, key, value, mask, activation, eps=1e-10):
        weight = F.linear(key, query).squeeze(-1)
        weight = activation(weight)
        infinite = torch.tensor(math.inf, device=value.device)
        weight = torch.where(mask > 0, weight, -infinite)
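        # numerically stable masked softmax over in-nodes: shift by the max
        # before exponentiating, then renormalize under the mask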
        attention = (weight - weight.max(dim=0, keepdim=True)[0]).exp()
        attention = attention * mask
        attention = attention / (attention.sum(dim=0, keepdim=True) + eps)
        return (attention.unsqueeze(-1) * value).sum(dim=0)

    def test_graph_conv(self):
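        # GraphConv (GCN): truth adds self loops, symmetrically normalizes the
        # adjacency, then applies the layer's linear transform and activation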
        conv = layers.GraphConv(self.input_dim, self.output_dim).cuda()
        result = conv(self.graph, self.input)
        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
        adjacency = adjacency + torch.eye(self.num_node, device=adjacency.device)
        adjacency /= adjacency.sum(dim=0, keepdim=True).sqrt() * adjacency.sum(dim=1, keepdim=True).sqrt()
        x = adjacency.t() @ self.input
        truth = conv.activation(conv.linear(x))
        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect graph convolution")

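        # GraphAttentionConv (GAT): build the dense truth per attention head
        # from the matching chunks of the hidden features and the query vector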
        num_head = 2
        conv = layers.GraphAttentionConv(self.input_dim, self.output_dim, num_head=num_head).cuda()
        result = conv(self.graph, self.input)
        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
        adjacency = adjacency + torch.eye(self.num_node, device=adjacency.device)
        hidden = conv.linear(self.input)
        outputs = []
        for h, query in zip(hidden.chunk(num_head, dim=-1), conv.query.chunk(num_head, dim=0)):
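            # value[i, j] = h[i]; the stack + flatten below interleaves in-node
            # and out-node features to match the layout of the query vector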
            value = h.unsqueeze(1).expand(-1, self.num_node, -1)
            key = torch.stack([h.unsqueeze(1).expand(-1, self.num_node, -1),
                               h.unsqueeze(0).expand(self.num_node, -1, -1)], dim=-1).flatten(-2)
            output = self.attention(query, key, value, adjacency, conv.leaky_relu, conv.eps)
            outputs.append(output)
        truth = torch.cat(outputs, dim=-1)
        truth = conv.activation(truth)
        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect graph attention convolution")

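        # GraphIsomorphismConv (GIN): truth aggregates (1 + eps) * x + A^T @ x
        # and feeds it through the layer's MLP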
        eps = 1
        conv = layers.GraphIsomorphismConv(self.input_dim, self.output_dim, eps=eps).cuda()
        result = conv(self.graph, self.input)
        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
        x = (1 + eps) * self.input + adjacency.t() @ self.input
        truth = conv.activation(conv.mlp(x))
        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-2), "Incorrect graph isomorphism convolution")

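        # RelationalGraphConv (RGCN): normalize incoming edge weights per
        # relation, aggregate with einsum, and add the learned self-loop term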
        conv = layers.RelationalGraphConv(self.input_dim, self.output_dim, self.num_relation).cuda()
        result = conv(self.graph, self.input)
        adjacency = self.graph.adjacency.to_dense()
        adjacency /= adjacency.sum(dim=0, keepdim=True)
        x = torch.einsum("htr, hd -> trd", adjacency, self.input)
        x = conv.linear(x.flatten(1)) + conv.self_loop(self.input)
        truth = conv.activation(x)
        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect relational graph convolution")

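        # ChebyshevConv with k=2: bases are the Chebyshev polynomials T0 = I,
        # T1 = L and T2 = 2 L^2 - I of the normalized Laplacian, applied to x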
        conv = layers.ChebyshevConv(self.input_dim, self.output_dim, k=2).cuda()
        result = conv(self.graph, self.input)
        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
        adjacency /= adjacency.sum(dim=0, keepdim=True).sqrt() * adjacency.sum(dim=1, keepdim=True).sqrt()
        identity = torch.eye(self.num_node, device=adjacency.device)
        laplacian = identity - adjacency
        bases = [self.input, laplacian.t() @ self.input, (2 * laplacian.t() @ laplacian.t() - identity) @ self.input]
        x = conv.linear(torch.cat(bases, dim=-1))
        truth = conv.activation(x)
        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect Chebyshev graph convolution")


if __name__ == "__main__":
    unittest.main()