Diff of /gcn/csrc/GOF.h [000000] .. [f77492]

Switch to unified view

a b/gcn/csrc/GOF.h
1
#pragma once
2
3
#include "cpu/vision.h"
4
5
#ifdef WITH_CUDA
6
#include "cuda/vision.h"
7
#endif
8
9
// Interface for Python
10
at::Tensor GOF_forward(const at::Tensor& weight, 
11
                       const at::Tensor& gaborFilterBank) {
12
  if (weight.type().is_cuda()) {
13
#ifdef WITH_CUDA
14
    return GOF_forward_cuda(weight, gaborFilterBank);
15
#else
16
    AT_ERROR("Not compiled with GPU support");
17
#endif
18
  }
19
  return GOF_forward_cpu(weight, gaborFilterBank);
20
}
21
22
at::Tensor GOF_backward(const at::Tensor& grad_output,
23
                        const at::Tensor& gaborFilterBank) {
24
  if (grad_output.type().is_cuda()) {
25
#ifdef WITH_CUDA
26
    return GOF_backward_cuda(grad_output, gaborFilterBank);
27
#else
28
    AT_ERROR("Not compiled with GPU support");
29
#endif
30
  }
31
  return GOF_backward_cpu(grad_output, gaborFilterBank);
32
}
33