[3eb847]: test/layers/test_spmm.py

import unittest
from itertools import product

import torch
import torch_scatter

from torchdrug import data, utils
from torchdrug.layers import functional


class SPMMTest(unittest.TestCase):

    def setUp(self):
        self.num_node = 50
        self.num_relation = 10
        self.dim = 20

        # build a sparse graph by thresholding a random dense adjacency matrix
        adjacency = torch.rand(self.num_node, self.num_node)
        threshold = adjacency.flatten().kthvalue((self.num_node - 10) * self.num_node)[0]
        adjacency = adjacency * (adjacency > threshold)
        self.graph = data.Graph.from_dense(adjacency)

        # build a relational graph from a random (node, node, relation) adjacency tensor
        rel_adjacency = torch.rand(self.num_node, self.num_node, self.num_relation)
        threshold = rel_adjacency.flatten().kthvalue((self.num_node - 10) * self.num_node)[0]
        rel_adjacency = rel_adjacency * (rel_adjacency > threshold)
        self.knowledge_graph = data.Graph.from_dense(rel_adjacency)

        self.relation = torch.rand(self.num_relation, self.dim)
        self.input = torch.rand(self.num_node, self.dim)
        self.output_grad = torch.rand(self.num_node, self.dim)

        # (sum, mul) operator pairs supported by the generalized sparse matrix multiplication
        self.operators = [("add", "mul"), ("min", "mul"), ("max", "mul"), ("min", "add"), ("max", "add")]
        self.devices = ["CPU", "CUDA"]

    def test_spmm(self):
        for device, (sum_op, mul_op) in product(self.devices, self.operators):
            if device == "CUDA":
                self.graph = self.graph.cuda()
                self.input = self.input.cuda()
                self.output_grad = self.output_grad.cuda()
            self.graph.edge_weight.requires_grad_()
            self.input.requires_grad_()
            node_in, node_out = self.graph.edge_list.t()

            # forward: compare generalized spmm against explicit scatter-based message passing
            result = functional.generalized_spmm(self.graph.adjacency.t(), self.input, sum=sum_op, mul=mul_op)
            sum_func = getattr(torch_scatter, "scatter_%s" % sum_op)
            mul_func = getattr(torch, mul_op)
            edge_weight = self.graph.edge_weight.unsqueeze(-1)
            message = mul_func(edge_weight, self.input[node_in])
            truth = sum_func(message, node_out, dim=0, dim_size=self.num_node)
            if isinstance(truth, tuple):
                truth = truth[0]
            self.assertTrue(torch.allclose(result, truth),
                            "Incorrect generalized spmm forward (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))

            # backward: gradients w.r.t. edge weights and input features should match
            result_edge, result_input = torch.autograd.grad(
                result, (self.graph.edge_weight, self.input), self.output_grad)
            truth_edge, truth_input = torch.autograd.grad(
                truth, (self.graph.edge_weight, self.input), self.output_grad)
            self.assertTrue(torch.allclose(result_edge, truth_edge),
                            "Incorrect generalized spmm backward (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))
            self.assertTrue(torch.allclose(result_input, truth_input),
                            "Incorrect generalized spmm backward (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))

    def test_rspmm(self):
        for device, (sum_op, mul_op) in product(self.devices, self.operators):
            if device == "CUDA":
                self.knowledge_graph = self.knowledge_graph.cuda()
                self.relation = self.relation.cuda()
                self.input = self.input.cuda()
                self.output_grad = self.output_grad.cuda()
            self.knowledge_graph.edge_weight.requires_grad_()
            self.relation.requires_grad_()
            self.input.requires_grad_()

            # forward: compare generalized rspmm against explicit relational message passing
            result = functional.generalized_rspmm(self.knowledge_graph.adjacency.transpose(0, 1),
                                                  self.relation, self.input, sum=sum_op, mul=mul_op)
            sum_func = getattr(torch_scatter, "scatter_%s" % sum_op)
            mul_func = getattr(torch, mul_op)
            node_in, node_out, relation = self.knowledge_graph.edge_list.t()
            edge_weight = self.knowledge_graph.edge_weight.unsqueeze(-1)
            message = mul_func(self.relation[relation], self.input[node_in])
            truth = sum_func(edge_weight * message, node_out, dim=0, dim_size=self.num_node)
            if isinstance(truth, tuple):
                truth = truth[0]
            self.assertTrue(torch.allclose(result, truth),
                            "Incorrect generalized rspmm forward (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))

            # backward: gradients w.r.t. edge weights, relation embeddings and input features should match
            result_edge, result_relation, result_input = torch.autograd.grad(
                result, (self.knowledge_graph.edge_weight, self.relation, self.input), self.output_grad)
            truth_edge, truth_relation, truth_input = torch.autograd.grad(
                truth, (self.knowledge_graph.edge_weight, self.relation, self.input), self.output_grad)
            self.assertTrue(torch.allclose(result_edge, truth_edge),
                            "Incorrect generalized rspmm backward (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))
            self.assertTrue(torch.allclose(result_relation, truth_relation),
                            "Incorrect generalized rspmm backward (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))
            self.assertTrue(torch.allclose(result_input, truth_input),
                            "Incorrect generalized rspmm backward (sum=`%s`, mul=`%s`)" % (sum_op, mul_op))


if __name__ == "__main__":
    unittest.main()