Diff of /gcn/layers/GConv.py [000000] .. [f77492]

Switch to unified view

a b/gcn/layers/GConv.py
1
from __future__ import division
2
import math
3
import torch
4
from torch import nn
5
import torch.nn.functional as F
6
from torch.autograd import Function
7
from torch.nn.modules.utils import _pair
8
from torch.nn.modules.conv import _ConvNd
9
from torch.autograd.function import once_differentiable
10
11
from gcn import _C
12
# from gcn import _C
13
14
class GOF_Function(Function):
15
    @staticmethod
16
    def forward(ctx, weight, gaborFilterBank):
17
        ctx.save_for_backward(weight, gaborFilterBank)
18
        output = _C.gof_forward(weight, gaborFilterBank)
19
        return output
20
21
    @staticmethod
22
    @once_differentiable
23
    def backward(ctx, grad_output):
24
        weight, gaborFilterBank = ctx.saved_tensors
25
        grad_weight = _C.gof_backward(grad_output, gaborFilterBank)
26
        return grad_weight, None 
27
28
class MConv(_ConvNd):
29
    '''
30
    Baee layer class for modulated convolution
31
    '''
32
    def __init__(self, in_channels, out_channels, kernel_size, M=4, nScale=3, stride=1,
33
                    padding=0, dilation=1, groups=1, bias=True, expand=False, padding_mode='zeros'):
34
        if groups != 1:
35
            raise ValueError('Group-conv not supported!')
36
        kernel_size = (M,) + _pair(kernel_size)
37
        stride = _pair(stride)
38
        padding = _pair(padding)
39
        dilation = _pair(dilation)
40
        super(MConv, self).__init__(
41
            in_channels, out_channels, kernel_size, stride, padding, dilation,
42
            False, _pair(0), groups, bias, padding_mode)
43
        self.expand = expand
44
        self.M = M
45
        self.need_bias = bias
46
        self.generate_MFilters(nScale, kernel_size)
47
        self.GOF_Function = GOF_Function.apply
48
49
    def generate_MFilters(self, nScale, kernel_size):
50
        raise NotImplementedError
51
52
    def forward(self, x):
53
        if self.expand:
54
            x = self.do_expanding(x)
55
        new_weight = self.GOF_Function(self.weight, self.MFilters)
56
        new_bias = self.expand_bias(self.bias) if self.need_bias else self.bias
57
        if self.padding_mode == 'circular':
58
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
59
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
60
            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
61
                            self.weight, self.bias, self.stride,
62
                            _pair(0), self.dilation, self.groups)
63
        return F.conv2d(x, new_weight, new_bias, self.stride,
64
                self.padding, self.dilation, self.groups)
65
66
    def do_expanding(self, x):
67
        index = []
68
        for i in range(x.size(1)):
69
            for _ in range(self.M):
70
                index.append(i)
71
        index = torch.LongTensor(index).cuda() if x.is_cuda else torch.LongTensor(index)
72
        return x.index_select(1, index)
73
    
74
    def expand_bias(self, bias):
75
        index = []
76
        for i in range(bias.size()):
77
            for _ in range(self.M):
78
                index.append(i)
79
        index = torch.LongTensor(index).cuda() if bias.is_cuda else torch.LongTensor(index)
80
        return bias.index_select(0, index)
81
82
class GConv(MConv):
83
    '''
84
    Gabor Convolutional Operation Layer
85
    '''
86
    def __init__(self, in_channels, out_channels, kernel_size, M=4, nScale=3, stride=1,
87
                    padding=0, dilation=1, groups=1, bias=True, expand=False, padding_mode='zeros'):
88
        super(GConv, self).__init__(in_channels, out_channels, kernel_size, M, nScale, stride,
89
                    padding, dilation, groups, bias, expand, padding_mode)
90
91
    def generate_MFilters(self, nScale, kernel_size):
92
        # To generate Gabor Filters
93
        self.register_buffer('MFilters', getGaborFilterBank(nScale, *kernel_size))
94
95
def getGaborFilterBank(nScale, M, h, w):
96
    Kmax = math.pi / 2
97
    f = math.sqrt(2)
98
    sigma = math.pi
99
    sqsigma = sigma ** 2
100
    postmean = math.exp(-sqsigma / 2)
101
    if h != 1:
102
        gfilter_real = torch.zeros(M, h, w)
103
        for i in range(M):
104
            theta = i / M * math.pi
105
            k = Kmax / f ** (nScale - 1)
106
            xymax = -1e309
107
            xymin = 1e309
108
            for y in range(h):
109
                for x in range(w):
110
                    y1 = y + 1 - ((h + 1) / 2)
111
                    x1 = x + 1 - ((w + 1) / 2)
112
                    tmp1 = math.exp(-(k * k * (x1 * x1 + y1 * y1) / (2 * sqsigma)))
113
                    tmp2 = math.cos(k * math.cos(theta) * x1 + k * math.sin(theta) * y1) - postmean # For real part
114
                    # tmp3 = math.sin(k*math.cos(theta)*x1+k*math.sin(theta)*y1) # For imaginary part
115
                    gfilter_real[i][y][x] = k * k * tmp1 * tmp2 / sqsigma           
116
                    xymax = max(xymax, gfilter_real[i][y][x])
117
                    xymin = min(xymin, gfilter_real[i][y][x])
118
            gfilter_real[i] = (gfilter_real[i] - xymin) / (xymax - xymin)
119
    else:
120
        gfilter_real = torch.ones(M, h, w)
121
    return gfilter_real