--- a +++ b/gcn/layers/gradtest.py @@ -0,0 +1,20 @@ +import torch +from torch.autograd import gradcheck +from gcn.layers.GConv import GOF_Function + +def gradchecking(use_cuda=False): + print('-'*80) + GOF = GOF_Function.apply + device = torch.device("cuda" if use_cuda else "cpu") + + weight = torch.randn(8,8,4,3,3).to(device).double().requires_grad_() + gfb = torch.randn(4,3,3).to(device).double() + + test = gradcheck(GOF, (weight, gfb), eps=1e-6, atol=1e-4, rtol=1e-3, raise_exception=True) + print(test) + + +if __name__ == "__main__": + gradchecking() + if torch.cuda.is_available(): + gradchecking(use_cuda=True) \ No newline at end of file