|
a |
|
b/gcn/layers/gradtest.py |
|
|
1 |
import torch |
|
|
2 |
from torch.autograd import gradcheck |
|
|
3 |
from gcn.layers.GConv import GOF_Function |
|
|
4 |
|
|
|
5 |
def gradchecking(use_cuda=False): |
|
|
6 |
print('-'*80) |
|
|
7 |
GOF = GOF_Function.apply |
|
|
8 |
device = torch.device("cuda" if use_cuda else "cpu") |
|
|
9 |
|
|
|
10 |
weight = torch.randn(8,8,4,3,3).to(device).double().requires_grad_() |
|
|
11 |
gfb = torch.randn(4,3,3).to(device).double() |
|
|
12 |
|
|
|
13 |
test = gradcheck(GOF, (weight, gfb), eps=1e-6, atol=1e-4, rtol=1e-3, raise_exception=True) |
|
|
14 |
print(test) |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
if __name__ == "__main__": |
|
|
18 |
gradchecking() |
|
|
19 |
if torch.cuda.is_available(): |
|
|
20 |
gradchecking(use_cuda=True) |