Diff of /gcn/layers/gradtest.py [000000] .. [f77492]

Switch to unified view

a b/gcn/layers/gradtest.py
1
import torch
2
from torch.autograd import gradcheck
3
from gcn.layers.GConv import GOF_Function
4
5
def gradchecking(use_cuda=False):
6
    print('-'*80)
7
    GOF = GOF_Function.apply
8
    device = torch.device("cuda" if use_cuda else "cpu")
9
10
    weight = torch.randn(8,8,4,3,3).to(device).double().requires_grad_()
11
    gfb = torch.randn(4,3,3).to(device).double()
12
13
    test = gradcheck(GOF, (weight, gfb), eps=1e-6, atol=1e-4, rtol=1e-3, raise_exception=True)
14
    print(test)
15
16
17
if __name__ == "__main__":
18
    gradchecking()
19
    if torch.cuda.is_available():
20
        gradchecking(use_cuda=True)