[f77492]: / gcn / layers / GConv.py

Download this file

121 lines (110 with data), 4.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import division
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from torch.nn.modules.conv import _ConvNd
from torch.autograd.function import once_differentiable
from gcn import _C
# from gcn import _C
class GOF_Function(Function):
    '''
    Autograd function that combines a learned weight with a Gabor filter bank.

    Both passes are delegated to the compiled ``_C`` extension
    (``_C.gof_forward`` / ``_C.gof_backward``). Only the weight receives a
    gradient; the filter bank is treated as a constant.
    '''

    @staticmethod
    def forward(ctx, weight, gaborFilterBank):
        '''
        Return ``_C.gof_forward(weight, gaborFilterBank)``.

        Both inputs are stashed on ``ctx`` so the backward pass can retrieve
        the filter bank.
        '''
        ctx.save_for_backward(weight, gaborFilterBank)
        return _C.gof_forward(weight, gaborFilterBank)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        '''
        Return ``(grad_weight, None)``.

        ``grad_weight`` comes from ``_C.gof_backward``; the ``None`` marks the
        filter bank as non-differentiable.
        '''
        # The saved weight is not needed by the extension's backward kernel.
        _, filter_bank = ctx.saved_tensors
        weight_grad = _C.gof_backward(grad_output, filter_bank)
        return weight_grad, None
class MConv(_ConvNd):
    '''
    Base layer class for modulated convolution.

    On every forward pass the learned weight is combined with the
    ``MFilters`` buffer through ``GOF_Function`` and the result is used as
    the kernel of a regular ``F.conv2d``.

    Subclasses must override ``generate_MFilters`` so that ``self.MFilters``
    exists before ``forward`` is called.

    Args:
        in_channels, out_channels: forwarded unchanged to ``_ConvNd``.
        kernel_size: int or pair; the layer's kernel size becomes
            ``(M,) + _pair(kernel_size)``.
        M: number of modulation filters; also the factor by which channels
            and bias entries are repeated.
        nScale: scale argument handed to ``generate_MFilters``.
        stride, padding, dilation: int or pair, as for a 2-D convolution.
        groups: must be 1.
        bias: whether a bias is used (it is repeated ``M`` times per entry).
        expand: if True, every input channel is repeated ``M`` times before
            the convolution.
        padding_mode: forwarded to ``_ConvNd``; ``'circular'`` pads the input
            explicitly before convolving.

    Raises:
        ValueError: if ``groups != 1``.
    '''

    def __init__(self, in_channels, out_channels, kernel_size, M=4, nScale=3, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, expand=False, padding_mode='zeros'):
        if groups != 1:
            raise ValueError('Group-conv not supported!')
        kernel_size = (M,) + _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Positional layout: ..., transposed=False, output_padding=_pair(0), ...
        super(MConv, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)
        self.expand = expand
        self.M = M
        self.need_bias = bias
        self.generate_MFilters(nScale, kernel_size)
        self.GOF_Function = GOF_Function.apply

    def generate_MFilters(self, nScale, kernel_size):
        '''Create the ``MFilters`` buffer; must be provided by subclasses.'''
        raise NotImplementedError

    def forward(self, x):
        '''Convolve ``x`` with the modulated weight and the expanded bias.'''
        if self.expand:
            x = self.do_expanding(x)
        new_weight = self.GOF_Function(self.weight, self.MFilters)
        new_bias = self.expand_bias(self.bias) if self.need_bias else self.bias
        if self.padding_mode == 'circular':
            # NOTE(review): this splits each padding amount across the two
            # sides of the axis; confirm that matches the targeted PyTorch
            # version's circular-padding convention.
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            # Pad the actual input and convolve with the *modulated* kernel
            # and expanded bias, exactly like the non-circular path below.
            return F.conv2d(F.pad(x, expanded_padding, mode='circular'),
                            new_weight, new_bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(x, new_weight, new_bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def _repeat_each(self, tensor, dim):
        '''Repeat every slice of ``tensor`` along ``dim`` ``self.M`` times, consecutively.'''
        # Build the index on the tensor's own device so multi-GPU setups work.
        index = torch.arange(tensor.size(dim), dtype=torch.long, device=tensor.device)
        return tensor.index_select(dim, index.repeat_interleave(self.M))

    def do_expanding(self, x):
        '''Return ``x`` with each channel (dim 1) repeated ``self.M`` times.'''
        return self._repeat_each(x, 1)

    def expand_bias(self, bias):
        '''Return the 1-D ``bias`` with each entry repeated ``self.M`` times.'''
        return self._repeat_each(bias, 0)
class GConv(MConv):
    '''
    Gabor convolutional layer: an ``MConv`` whose modulation filters are the
    Gabor filter bank returned by ``getGaborFilterBank``.
    '''

    def __init__(self, in_channels, out_channels, kernel_size, M=4, nScale=3, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, expand=False, padding_mode='zeros'):
        # Every argument is handed straight to the base class.
        super(GConv, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            M=M,
            nScale=nScale,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            expand=expand,
            padding_mode=padding_mode)

    def generate_MFilters(self, nScale, kernel_size):
        '''Register the Gabor filter bank as the non-trainable ``MFilters`` buffer.'''
        # kernel_size is unpacked into the remaining positional arguments.
        gabor_bank = getGaborFilterBank(nScale, *kernel_size)
        self.register_buffer('MFilters', gabor_bank)
def getGaborFilterBank(nScale, M, h, w):
    '''
    Build a bank of ``M`` real-valued Gabor filters of size ``h`` x ``w``.

    Orientation ``i`` uses the angle ``theta = i / M * pi`` and all
    orientations share the wave number ``k = (pi / 2) / sqrt(2) ** (nScale - 1)``.
    Each orientation slice is min-max normalised to the range [0, 1].

    Args:
        nScale: scale index; larger values give a smaller ``k``.
        M: number of orientations.
        h: filter height.
        w: filter width.

    Returns:
        A ``torch`` tensor of shape ``(M, h, w)``. When ``h == 1`` the bank is
        all ones (the width is not inspected for this shortcut). A slice
        whose values are all identical is also set to ones, since min-max
        normalisation is undefined for it.
    '''
    if h == 1:
        return torch.ones(M, h, w)

    gfilter_real = torch.zeros(M, h, w)
    if h == 0 or w == 0:
        # Nothing to fill or normalise in an empty bank.
        return gfilter_real

    Kmax = math.pi / 2
    f = math.sqrt(2)
    sigma = math.pi
    sqsigma = sigma ** 2
    postmean = math.exp(-sqsigma / 2)
    # The wave number depends only on the scale, so compute it once.
    k = Kmax / f ** (nScale - 1)

    for i in range(M):
        theta = i / M * math.pi
        kx = k * math.cos(theta)
        ky = k * math.sin(theta)
        for y in range(h):
            # Coordinates are centred on the middle of the kernel.
            y1 = y + 1 - ((h + 1) / 2)
            for x in range(w):
                x1 = x + 1 - ((w + 1) / 2)
                envelope = math.exp(-(k * k * (x1 * x1 + y1 * y1) / (2 * sqsigma)))
                # Real part only; the imaginary (sine) part is not used.
                carrier = math.cos(kx * x1 + ky * y1) - postmean
                gfilter_real[i, y, x] = k * k * envelope * carrier / sqsigma

        lowest = gfilter_real[i].min()
        span = gfilter_real[i].max() - lowest
        if float(span) > 0:
            gfilter_real[i] = (gfilter_real[i] - lowest) / span
        else:
            # A constant slice would divide by zero and turn into NaNs;
            # fall back to the same all-ones filter used for h == 1.
            gfilter_real[i] = 1.0
    return gfilter_real