a b/CellGraph/layers_custom.py
1
import torch
2
import torch.nn as nn
3
import pdb
4
5
def down_shift(x, pad=None):
    """Shift a 4-D tensor one step down along dim 2.

    The last slice along dim 2 is dropped and the padding layer is then
    applied to what remains.  With the default padding a single row of
    zeros is inserted at the top, so the result has the same shape as `x`.

    Args:
        x: tensor indexed with four axes (presumably N, C, H, W as in the
            usual PyTorch layout -- confirm against callers).
        pad: optional callable applied to the trimmed tensor.  When None,
            ``nn.ZeroPad2d((0, 0, 1, 0))`` is used.

    Returns:
        Whatever `pad` returns for the trimmed tensor.
    """
    # Discard the bottom row; everything else moves down once the top is padded.
    trimmed = x[:, :, :-1, :]
    if pad is None:
        # ZeroPad2d takes (left, right, top, bottom): one zero row on top only.
        pad = nn.ZeroPad2d((0, 0, 1, 0))
    return pad(trimmed)
13
14
class MaskedConv2d(nn.Conv2d):
    """Bias-free ``nn.Conv2d`` whose kernel rows below the centre row are zeroed.

    A constant 0/1 mask with the same shape as the weight is stored as the
    buffer ``mask``.  Rows ``0 .. H // 2`` of the kernel are kept and rows
    ``H // 2 + 1 ..`` are set to zero; every column is treated alike.
    Example for a 5 x 5 kernel:

         -------------------------------------
        |  1       1       1       1       1 |
        |  1       1       1       1       1 |
        |  1       1       1       1       1 |   row H // 2
        |  0       0       0       0       0 |   row H // 2 + 1
        |  0       0       0       0       0 |
         -------------------------------------

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        k_size: kernel size, forwarded to ``nn.Conv2d``.
        stride: stride, forwarded to ``nn.Conv2d``.
        pad: padding, forwarded to ``nn.Conv2d``.
        use_down_shift: when True, the input goes through ``down_shift``
            before the convolution.
    """

    def __init__(self, c_in, c_out, k_size, stride, pad, use_down_shift=False):
        super().__init__(c_in, c_out, k_size, stride, pad, bias=False)
        self.use_down_shift = use_down_shift

        # Start from an all-zero mask and switch on the rows that are kept:
        # everything from the top of the kernel down to the centre row.
        kernel_height = self.weight.size(2)
        kept_rows = kernel_height // 2 + 1
        mask = torch.zeros(self.weight.size())
        mask[:, :, :kept_rows] = 1
        # A buffer follows the module across devices and into the state
        # dict without being treated as a trainable parameter.
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Zero the masked weights in place, then run the convolution."""
        # The raw weights are overwritten on every call, so any values an
        # optimiser step wrote into the masked rows are cleared again.
        self.weight.data *= self.mask
        conv_input = down_shift(x) if self.use_down_shift else x
        return super().forward(conv_input)
41
42
def maskConv0(c_in=3, c_out=256, k_size=7, stride=1, pad=3):
    """Build the first masked-convolution stage.

    The stage is a down-shifted ``MaskedConv2d`` producing ``c_out * 2``
    channels, followed by batch normalisation over those channels and a
    ``Gate``.

    Args:
        c_in: number of input channels.
        c_out: base channel count; the convolution emits twice as many.
        k_size: kernel size of the masked convolution.
        stride: stride of the masked convolution.
        pad: padding of the masked convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    doubled = c_out * 2
    stages = [
        MaskedConv2d(c_in, doubled, k_size, stride, pad, use_down_shift=True),
        nn.BatchNorm2d(doubled),
        Gate(),
    ]
    return nn.Sequential(*stages)
49
50
class Gate(nn.Module):
    """Gated activation: ``tanh`` of one channel half times ``sigmoid`` of the other."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Split `x` in two along dim 1 and combine the halves.

        The first chunk goes through ``tanh``, the second through
        ``sigmoid``, and their elementwise product is returned.
        """
        filter_half, gate_half = x.chunk(2, dim=1)
        return torch.tanh(filter_half) * torch.sigmoid(gate_half)
60
61
class MaskConvBlock(nn.Module):
    """Residual block: masked convolution, batch norm and gate, plus a skip."""

    def __init__(self, h=128, k_size=3, stride=1, pad=1):
        """Assemble the inner pipeline.

        The pipeline is a ``MaskedConv2d`` mapping ``h`` channels to
        ``2 * h``, a ``BatchNorm2d`` over those ``2 * h`` channels, and a
        ``Gate``.

        Args:
            h: number of channels entering the block.
            k_size: kernel size of the masked convolution.
            stride: stride of the masked convolution.
            pad: padding of the masked convolution.
        """
        super().__init__()

        widened = 2 * h
        layers = [
            MaskedConv2d(h, widened, k_size, stride, pad),
            nn.BatchNorm2d(widened),
            Gate(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the pipeline output added to the input (residual connection)."""
        residual = self.net(x)
        return residual + x
75
76
if __name__ == '__main__':
    # Demo: push a single large pixel through stacks of 3x3 convolutions and
    # print each intermediate result, comparing an unmasked kernel, masked
    # kernels, and a masked kernel applied after down_shift.

    def conv(x, kernel):
        """Convolve `x` with `kernel` using a padding of 1."""
        return nn.functional.conv2d(x, kernel, padding=1)

    # 5x5 image of 0.1 everywhere except one large value at row 1, column 0.
    x = torch.ones((1, 1, 5, 5)) * 0.1
    x[:, :, 1, 0] = 1000

    print("blindspot experiment")

    # Full 3x3 kernel of ones.
    normal_kernel = torch.ones(1, 1, 3, 3)

    # Kernel that only looks at the row above the current pixel.
    mask_kernel = torch.zeros(1, 1, 3, 3)
    mask_kernel[:, :, 0, :] = 1

    # Same as mask_kernel, with the centre tap switched on as well.
    mask_b = mask_kernel.clone()
    mask_b[:, :, 1, 1] = 1

    print("unmasked kernel:", "\n", normal_kernel.squeeze(), "\n")
    print("masked kernel:", "\n", mask_kernel.squeeze(), "\n")

    # --- two passes with the unmasked kernel ---
    print("normal conv")
    print("orig image", "\n", x.squeeze(), "\n")
    y = x
    for _ in range(2):
        y = conv(y, normal_kernel)
        print(y[:, 0, :, :], "\n")

    # --- one pass with mask_kernel, then two with mask_b ---
    print("with mask")
    print("orig image", "\n", x.squeeze(), "\n")
    y = conv(x, mask_kernel)
    print(y[:, 0, :, :], "\n")
    for _ in range(2):
        y = conv(y, mask_b)
        print(y[:, 0, :, :], "\n")

    # --- down-shifted input, kernel covering the top and centre rows ---
    print("with down_shift")
    print("orig image", x.squeeze(), "\n")
    # c_kernel is the same tensor object as mask_kernel, so filling the
    # centre row here modifies mask_kernel too (mask_b was cloned earlier
    # and is unaffected).
    c_kernel = mask_kernel
    c_kernel[:, :, 1, :] = 1

    print("custom kernel:", "\n", c_kernel.squeeze(), "\n")
    y = conv(down_shift(x), c_kernel)
    print(y[:, 0, :, :], "\n")
    for _ in range(3):
        y = conv(y, c_kernel)
        print(y[:, 0, :, :], "\n")
128
129
130
131
132