[f77492]: / gcn / csrc / cpu / GOF_cpu.cpp

Download this file

127 lines (112 with data), 4.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include "cpu/vision.h"
/**
 * Forward pass of the Gabor orientation filter (GOF) modulation on the CPU.
 *
 * Every weight element is multiplied by each of the nChannel Gabor filters,
 * producing nChannel modulated copies of the weight.
 *
 * Buffers are addressed as dense row-major arrays:
 *   weight : [nOutputPlane][nInputPlane][nChannel * kH * kW]
 *   gabor  : [nChannel][kH * kW]
 *   output : [nOutputPlane][nChannel][nInputPlane][nChannel * kH * kW]
 *
 * @param weight_data           Source weights (read only).
 * @param gaborFilterBank_data  Gabor filter bank (read only).
 * @param nOutputPlane          Extent of the outermost weight dimension.
 * @param nInputPlane           Extent of the second weight dimension.
 * @param nChannel              Number of Gabor filters / weight channels.
 * @param kH                    Kernel height.
 * @param kW                    Kernel width.
 * @param output_data           Destination buffer; every element is overwritten.
 */
template <typename T>
void GOFForward_cpu_kernel(
    const T* weight_data,
    const T* gaborFilterBank_data,
    const int nOutputPlane,
    const int nInputPlane,
    const int nChannel,
    const int kH,
    const int kW,
    T* output_data) {
  const int kernelArea = kW * kH;                // elements in one Gabor filter
  const int nEntry = nChannel * kH * kW;         // elements per (output, input) weight slice
  const int planeStride = nInputPlane * nEntry;  // elements per output plane of the weight
  const int outStride = nChannel * planeStride;  // elements per output plane of the result

  for (int outIdx = 0; outIdx < nOutputPlane; ++outIdx) {
    const T* weightOut = weight_data + outIdx * planeStride;
    T* resultOut = output_data + outIdx * outStride;

    for (int inIdx = 0; inIdx < nInputPlane; ++inIdx) {
      const T* weightSlice = weightOut + inIdx * nEntry;

      for (int entry = 0; entry < nEntry; ++entry) {
        const T weightVal = weightSlice[entry];
        // The spatial position inside the kernel selects the Gabor pixel;
        // it is the same for every weight channel, hence the modulo.
        const T* gaborPixel = gaborFilterBank_data + entry % kernelArea;

        for (int filter = 0; filter < nChannel; ++filter) {
          resultOut[filter * planeStride + inIdx * nEntry + entry] =
              weightVal * gaborPixel[filter * kernelArea];
        }
      }
    }
  }
}
/**
 * Backward pass of the Gabor orientation filter (GOF) modulation on the CPU.
 *
 * For every weight element, sums the matching gradient entries of all
 * nChannel modulated copies, each scaled by the corresponding Gabor pixel.
 *
 * Buffers are addressed as dense row-major arrays:
 *   grad_output : [nOutputPlane][nChannel][nInputPlane][nChannel * kH * kW]
 *   gabor       : [nChannel][kH * kW]
 *   grad_weight : [nOutputPlane][nInputPlane][nChannel * kH * kW]
 *
 * @param grad_output_data      Gradient w.r.t. the modulated filters (read only).
 * @param gaborFilterBank_data  Gabor filter bank (read only).
 * @param nOutputPlane          Extent of the outermost weight dimension.
 * @param nInputPlane           Extent of the second weight dimension.
 * @param nChannel              Number of Gabor filters / weight channels.
 * @param kH                    Kernel height.
 * @param kW                    Kernel width.
 * @param grad_weight_data      Destination buffer; every element is reset to
 *                              zero and then accumulated into.
 */
template <typename T>
void GOFBackward_cpu_kernel(
    const T* grad_output_data,
    const T* gaborFilterBank_data,
    const int nOutputPlane,
    const int nInputPlane,
    const int nChannel,
    const int kH,
    const int kW,
    T* grad_weight_data) {
  const int kernelArea = kW * kH;                // elements in one Gabor filter
  const int nEntry = nChannel * kH * kW;         // elements per (output, input) weight slice
  const int planeStride = nInputPlane * nEntry;  // elements per output plane of grad_weight

  for (int outIdx = 0; outIdx < nOutputPlane; ++outIdx) {
    const T* gradOutPlane = grad_output_data + outIdx * (nChannel * planeStride);

    for (int inIdx = 0; inIdx < nInputPlane; ++inIdx) {
      T* gradWeightSlice = grad_weight_data + outIdx * planeStride + inIdx * nEntry;

      for (int entry = 0; entry < nEntry; ++entry) {
        // Accumulate directly in the destination element, in ascending
        // filter order, so the summation order is fixed.
        T& accum = gradWeightSlice[entry];
        accum = 0;

        for (int filter = 0; filter < nChannel; ++filter) {
          const T gaborVal =
              gaborFilterBank_data[filter * kernelArea + entry % kernelArea];
          const T gradVal =
              gradOutPlane[filter * planeStride + inIdx * nEntry + entry];
          accum = accum + gradVal * gaborVal;
        }
      }
    }
  }
}
/**
 * CPU entry point for the GOF forward pass.
 *
 * Reads the sizes (nOutputPlane, nInputPlane, nChannel, kH, kW) from the five
 * dimensions of `weight`, allocates an output of shape
 * (nOutputPlane * nChannel, nInputPlane * nChannel, kH, kW) with the options of
 * `weight`, and fills it with GOFForward_cpu_kernel.
 *
 * @param weight           5-D CPU tensor of floating type.
 * @param gaborFilterBank  CPU tensor with the same dtype as `weight` holding at
 *                         least nChannel * kH * kW elements.
 * @return The newly allocated output tensor. An empty output is returned
 *         without running the kernel.
 */
at::Tensor GOF_forward_cpu(const at::Tensor& weight,
                           const at::Tensor& gaborFilterBank) {
  AT_ASSERTM(!weight.type().is_cuda(), "weight must be a CPU tensor");
  AT_ASSERTM(!gaborFilterBank.type().is_cuda(), "gaborFilterBank must be a CPU tensor");
  AT_ASSERTM(weight.dim() == 5, "weight must be a 5-D tensor");

  auto nOutputPlane = weight.size(0);
  auto nInputPlane = weight.size(1);
  auto nChannel = weight.size(2);
  auto kH = weight.size(3);
  auto kW = weight.size(4);

  auto output = at::empty({nOutputPlane * nChannel, nInputPlane * nChannel, kH, kW}, weight.options());
  if (output.numel() == 0) {
    return output;
  }

  // The kernel reinterprets both buffers as scalar_t and reads
  // nChannel * kH * kW elements from the bank, so reject mismatched inputs
  // here instead of reading garbage or running past the end of the bank.
  AT_ASSERTM(gaborFilterBank.type().scalarType() == weight.type().scalarType(),
             "gaborFilterBank must have the same dtype as weight");
  AT_ASSERTM(gaborFilterBank.numel() >= nChannel * kH * kW,
             "gaborFilterBank must hold at least nChannel * kH * kW elements");

  // The kernel addresses its inputs as dense row-major arrays; a strided view
  // would be read at the wrong offsets. contiguous() is a no-op for tensors
  // that are already dense.
  const auto weight_c = weight.contiguous();
  const auto gaborFilterBank_c = gaborFilterBank.contiguous();

  // NOTE: the kernel takes `int` sizes, so the 64-bit sizes narrow here.
  AT_DISPATCH_FLOATING_TYPES(weight.type(), "GOF_forward", [&] {
    GOFForward_cpu_kernel<scalar_t>(
        weight_c.data<scalar_t>(),
        gaborFilterBank_c.data<scalar_t>(),
        nOutputPlane,
        nInputPlane,
        nChannel,
        kH,
        kW,
        output.data<scalar_t>());
  });
  return output;
}
/**
 * CPU entry point for the GOF backward pass.
 *
 * Takes nChannel from the first dimension of `gaborFilterBank`, derives
 * nOutputPlane and nInputPlane by dividing the first two dimensions of
 * `grad_output` by nChannel, allocates grad_weight of shape
 * (nOutputPlane, nInputPlane, nChannel, kH, kW) with the options of
 * `grad_output`, and fills it with GOFBackward_cpu_kernel.
 *
 * @param grad_output      4-D CPU tensor of floating type whose first two
 *                         dimensions are multiples of nChannel.
 * @param gaborFilterBank  CPU tensor with the same dtype as `grad_output`,
 *                         a non-empty first dimension, and at least
 *                         nChannel * kH * kW elements.
 * @return The newly allocated gradient tensor. An empty gradient is returned
 *         without running the kernel.
 */
at::Tensor GOF_backward_cpu(const at::Tensor& grad_output,
                            const at::Tensor& gaborFilterBank) {
  AT_ASSERTM(!grad_output.type().is_cuda(), "grad_output must be a CPU tensor");
  AT_ASSERTM(!gaborFilterBank.type().is_cuda(), "gaborFilterBank must be a CPU tensor");
  AT_ASSERTM(grad_output.dim() == 4, "grad_output must be a 4-D tensor");
  AT_ASSERTM(gaborFilterBank.dim() >= 1, "gaborFilterBank must have at least one dimension");

  auto nChannel = gaborFilterBank.size(0);
  // nChannel is a divisor below; zero would be an integer division by zero.
  AT_ASSERTM(nChannel > 0, "gaborFilterBank must contain at least one filter");
  // A remainder would be silently truncated and the kernel's strides would no
  // longer match the real layout of grad_output.
  AT_ASSERTM(grad_output.size(0) % nChannel == 0 && grad_output.size(1) % nChannel == 0,
             "the first two dimensions of grad_output must be multiples of the number of Gabor filters");

  auto nOutputPlane = grad_output.size(0) / nChannel;
  auto nInputPlane = grad_output.size(1) / nChannel;
  auto kH = grad_output.size(2);
  auto kW = grad_output.size(3);

  auto grad_weight = at::empty({nOutputPlane, nInputPlane, nChannel, kH, kW}, grad_output.options());
  if (grad_weight.numel() == 0) {
    return grad_weight;
  }

  // The kernel reinterprets both buffers as scalar_t and reads
  // nChannel * kH * kW elements from the bank, so reject mismatched inputs
  // here instead of reading garbage or running past the end of the bank.
  AT_ASSERTM(gaborFilterBank.type().scalarType() == grad_output.type().scalarType(),
             "gaborFilterBank must have the same dtype as grad_output");
  AT_ASSERTM(gaborFilterBank.numel() >= nChannel * kH * kW,
             "gaborFilterBank must hold at least nChannel * kH * kW elements");

  // The kernel addresses its inputs as dense row-major arrays; a strided view
  // would be read at the wrong offsets. contiguous() is a no-op for tensors
  // that are already dense.
  const auto grad_output_c = grad_output.contiguous();
  const auto gaborFilterBank_c = gaborFilterBank.contiguous();

  // NOTE: the kernel takes `int` sizes, so the 64-bit sizes narrow here.
  AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "GOF_backward", [&] {
    GOFBackward_cpu_kernel<scalar_t>(
        grad_output_c.data<scalar_t>(),
        gaborFilterBank_c.data<scalar_t>(),
        nOutputPlane,
        nInputPlane,
        nChannel,
        kH,
        kW,
        grad_weight.data<scalar_t>());
  });
  return grad_weight;
}