|
a |
|
b/gcn/layers/GConv.py |
|
|
1 |
from __future__ import division |
|
|
2 |
import math |
|
|
3 |
import torch |
|
|
4 |
from torch import nn |
|
|
5 |
import torch.nn.functional as F |
|
|
6 |
from torch.autograd import Function |
|
|
7 |
from torch.nn.modules.utils import _pair |
|
|
8 |
from torch.nn.modules.conv import _ConvNd |
|
|
9 |
from torch.autograd.function import once_differentiable |
|
|
10 |
|
|
|
11 |
from gcn import _C |
|
|
12 |
# from gcn import _C |
|
|
13 |
|
|
|
14 |
class GOF_Function(Function): |
|
|
15 |
@staticmethod |
|
|
16 |
def forward(ctx, weight, gaborFilterBank): |
|
|
17 |
ctx.save_for_backward(weight, gaborFilterBank) |
|
|
18 |
output = _C.gof_forward(weight, gaborFilterBank) |
|
|
19 |
return output |
|
|
20 |
|
|
|
21 |
@staticmethod |
|
|
22 |
@once_differentiable |
|
|
23 |
def backward(ctx, grad_output): |
|
|
24 |
weight, gaborFilterBank = ctx.saved_tensors |
|
|
25 |
grad_weight = _C.gof_backward(grad_output, gaborFilterBank) |
|
|
26 |
return grad_weight, None |
|
|
27 |
|
|
|
28 |
class MConv(_ConvNd): |
|
|
29 |
''' |
|
|
30 |
Baee layer class for modulated convolution |
|
|
31 |
''' |
|
|
32 |
def __init__(self, in_channels, out_channels, kernel_size, M=4, nScale=3, stride=1, |
|
|
33 |
padding=0, dilation=1, groups=1, bias=True, expand=False, padding_mode='zeros'): |
|
|
34 |
if groups != 1: |
|
|
35 |
raise ValueError('Group-conv not supported!') |
|
|
36 |
kernel_size = (M,) + _pair(kernel_size) |
|
|
37 |
stride = _pair(stride) |
|
|
38 |
padding = _pair(padding) |
|
|
39 |
dilation = _pair(dilation) |
|
|
40 |
super(MConv, self).__init__( |
|
|
41 |
in_channels, out_channels, kernel_size, stride, padding, dilation, |
|
|
42 |
False, _pair(0), groups, bias, padding_mode) |
|
|
43 |
self.expand = expand |
|
|
44 |
self.M = M |
|
|
45 |
self.need_bias = bias |
|
|
46 |
self.generate_MFilters(nScale, kernel_size) |
|
|
47 |
self.GOF_Function = GOF_Function.apply |
|
|
48 |
|
|
|
49 |
def generate_MFilters(self, nScale, kernel_size): |
|
|
50 |
raise NotImplementedError |
|
|
51 |
|
|
|
52 |
def forward(self, x): |
|
|
53 |
if self.expand: |
|
|
54 |
x = self.do_expanding(x) |
|
|
55 |
new_weight = self.GOF_Function(self.weight, self.MFilters) |
|
|
56 |
new_bias = self.expand_bias(self.bias) if self.need_bias else self.bias |
|
|
57 |
if self.padding_mode == 'circular': |
|
|
58 |
expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2, |
|
|
59 |
(self.padding[0] + 1) // 2, self.padding[0] // 2) |
|
|
60 |
return F.conv2d(F.pad(input, expanded_padding, mode='circular'), |
|
|
61 |
self.weight, self.bias, self.stride, |
|
|
62 |
_pair(0), self.dilation, self.groups) |
|
|
63 |
return F.conv2d(x, new_weight, new_bias, self.stride, |
|
|
64 |
self.padding, self.dilation, self.groups) |
|
|
65 |
|
|
|
66 |
def do_expanding(self, x): |
|
|
67 |
index = [] |
|
|
68 |
for i in range(x.size(1)): |
|
|
69 |
for _ in range(self.M): |
|
|
70 |
index.append(i) |
|
|
71 |
index = torch.LongTensor(index).cuda() if x.is_cuda else torch.LongTensor(index) |
|
|
72 |
return x.index_select(1, index) |
|
|
73 |
|
|
|
74 |
def expand_bias(self, bias): |
|
|
75 |
index = [] |
|
|
76 |
for i in range(bias.size()): |
|
|
77 |
for _ in range(self.M): |
|
|
78 |
index.append(i) |
|
|
79 |
index = torch.LongTensor(index).cuda() if bias.is_cuda else torch.LongTensor(index) |
|
|
80 |
return bias.index_select(0, index) |
|
|
81 |
|
|
|
82 |
class GConv(MConv): |
|
|
83 |
''' |
|
|
84 |
Gabor Convolutional Operation Layer |
|
|
85 |
''' |
|
|
86 |
def __init__(self, in_channels, out_channels, kernel_size, M=4, nScale=3, stride=1, |
|
|
87 |
padding=0, dilation=1, groups=1, bias=True, expand=False, padding_mode='zeros'): |
|
|
88 |
super(GConv, self).__init__(in_channels, out_channels, kernel_size, M, nScale, stride, |
|
|
89 |
padding, dilation, groups, bias, expand, padding_mode) |
|
|
90 |
|
|
|
91 |
def generate_MFilters(self, nScale, kernel_size): |
|
|
92 |
# To generate Gabor Filters |
|
|
93 |
self.register_buffer('MFilters', getGaborFilterBank(nScale, *kernel_size)) |
|
|
94 |
|
|
|
95 |
def getGaborFilterBank(nScale, M, h, w): |
|
|
96 |
Kmax = math.pi / 2 |
|
|
97 |
f = math.sqrt(2) |
|
|
98 |
sigma = math.pi |
|
|
99 |
sqsigma = sigma ** 2 |
|
|
100 |
postmean = math.exp(-sqsigma / 2) |
|
|
101 |
if h != 1: |
|
|
102 |
gfilter_real = torch.zeros(M, h, w) |
|
|
103 |
for i in range(M): |
|
|
104 |
theta = i / M * math.pi |
|
|
105 |
k = Kmax / f ** (nScale - 1) |
|
|
106 |
xymax = -1e309 |
|
|
107 |
xymin = 1e309 |
|
|
108 |
for y in range(h): |
|
|
109 |
for x in range(w): |
|
|
110 |
y1 = y + 1 - ((h + 1) / 2) |
|
|
111 |
x1 = x + 1 - ((w + 1) / 2) |
|
|
112 |
tmp1 = math.exp(-(k * k * (x1 * x1 + y1 * y1) / (2 * sqsigma))) |
|
|
113 |
tmp2 = math.cos(k * math.cos(theta) * x1 + k * math.sin(theta) * y1) - postmean # For real part |
|
|
114 |
# tmp3 = math.sin(k*math.cos(theta)*x1+k*math.sin(theta)*y1) # For imaginary part |
|
|
115 |
gfilter_real[i][y][x] = k * k * tmp1 * tmp2 / sqsigma |
|
|
116 |
xymax = max(xymax, gfilter_real[i][y][x]) |
|
|
117 |
xymin = min(xymin, gfilter_real[i][y][x]) |
|
|
118 |
gfilter_real[i] = (gfilter_real[i] - xymin) / (xymax - xymin) |
|
|
119 |
else: |
|
|
120 |
gfilter_real = torch.ones(M, h, w) |
|
|
121 |
return gfilter_real |