test/layers/test_conv.py

import math
import unittest

import torch
from torch.nn import functional as F

from torchdrug import data, layers

class GraphConvTest(unittest.TestCase):

    def setUp(self):
        self.num_node = 10
        self.num_relation = 3
        self.input_dim = 5
        self.output_dim = 8
        # build a random multi-relational adjacency, sparsify it by thresholding,
        # and move the resulting graph and node features to GPU
        adjacency = torch.rand(self.num_node, self.num_node, self.num_relation)
        threshold = adjacency.flatten().kthvalue((self.num_node - 3) * self.num_node)[0]
        adjacency = adjacency * (adjacency > threshold)
        self.graph = data.Graph.from_dense(adjacency).cuda()
        self.input = torch.rand(self.num_node, self.input_dim).cuda()

    def attention(self, query, key, value, mask, activation, eps=1e-10):
        # reference implementation of masked single-head attention:
        # score each key against the query, mask out non-edges, and take a
        # numerically stable softmax over the source-node dimension
        weight = F.linear(key, query).squeeze(-1)
        weight = activation(weight)
        infinite = torch.tensor(math.inf, device=value.device)
        weight = torch.where(mask > 0, weight, -infinite)
        attention = (weight - weight.max(dim=0, keepdim=True)[0]).exp()
        attention = attention * mask
        attention = attention / (attention.sum(dim=0, keepdim=True) + eps)
        return (attention.unsqueeze(-1) * value).sum(dim=0)

    def test_graph_conv(self):
        # GCN: symmetrically normalized adjacency with self loops, then linear + activation
        conv = layers.GraphConv(self.input_dim, self.output_dim).cuda()
        result = conv(self.graph, self.input)
        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
        adjacency = adjacency + torch.eye(self.num_node, device=adjacency.device)
        adjacency /= adjacency.sum(dim=0, keepdim=True).sqrt() * adjacency.sum(dim=1, keepdim=True).sqrt()
        x = adjacency.t() @ self.input
        truth = conv.activation(conv.linear(x))
        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect graph convolution")

        # GAT: per-head attention over neighbors, compared against the reference attention above
        num_head = 2
        conv = layers.GraphAttentionConv(self.input_dim, self.output_dim, num_head=num_head).cuda()
        result = conv(self.graph, self.input)
        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
        adjacency = adjacency + torch.eye(self.num_node, device=adjacency.device)
        hidden = conv.linear(self.input)
        outputs = []
        for h, query in zip(hidden.chunk(num_head, dim=-1), conv.query.chunk(num_head, dim=0)):
            value = h.unsqueeze(1).expand(-1, self.num_node, -1)
            key = torch.stack([h.unsqueeze(1).expand(-1, self.num_node, -1),
                               h.unsqueeze(0).expand(self.num_node, -1, -1)], dim=-1).flatten(-2)
            output = self.attention(query, key, value, adjacency, conv.leaky_relu, conv.eps)
            outputs.append(output)
        truth = torch.cat(outputs, dim=-1)
        truth = conv.activation(truth)
        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect graph attention convolution")

        # GIN: (1 + eps) * x plus neighbor sum, followed by an MLP
        eps = 1
        conv = layers.GraphIsomorphismConv(self.input_dim, self.output_dim, eps=eps).cuda()
        result = conv(self.graph, self.input)
        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
        x = (1 + eps) * self.input + adjacency.t() @ self.input
        truth = conv.activation(conv.mlp(x))
        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-2), "Incorrect graph isomorphism convolution")

        # RGCN: per-relation normalized aggregation plus a self-loop transform
        conv = layers.RelationalGraphConv(self.input_dim, self.output_dim, self.num_relation).cuda()
        result = conv(self.graph, self.input)
        adjacency = self.graph.adjacency.to_dense()
        adjacency /= adjacency.sum(dim=0, keepdim=True)
        x = torch.einsum("htr, hd -> trd", adjacency, self.input)
        x = conv.linear(x.flatten(1)) + conv.self_loop(self.input)
        truth = conv.activation(x)
        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect relational graph convolution")

        # ChebNet: Chebyshev polynomial basis of the normalized Laplacian up to order k=2
        conv = layers.ChebyshevConv(self.input_dim, self.output_dim, k=2).cuda()
        result = conv(self.graph, self.input)
        adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
        adjacency /= adjacency.sum(dim=0, keepdim=True).sqrt() * adjacency.sum(dim=1, keepdim=True).sqrt()
        identity = torch.eye(self.num_node, device=adjacency.device)
        laplacian = identity - adjacency
        bases = [self.input, laplacian.t() @ self.input, (2 * laplacian.t() @ laplacian.t() - identity) @ self.input]
        x = conv.linear(torch.cat(bases, dim=-1))
        truth = conv.activation(x)
        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect Chebyshev graph convolution")


if __name__ == "__main__":
    unittest.main()