b/CellGraph/layers_custom.py

import torch
import torch.nn as nn


def down_shift(x, pad=None):
    # PyTorch ordering: (batch, channels, height, width)
    xs = [int(y) for y in x.size()]
    # when down-shifting, the last row is removed ...
    x = x[:, :, :xs[2] - 1, :]
    # ... and a row of zeros is padded on top; ZeroPad2d takes
    # (padding_left, padding_right, padding_top, padding_bottom)
    pad = nn.ZeroPad2d((0, 0, 1, 0)) if pad is None else pad
    return pad(x)
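

def _demo_down_shift():
    # Illustrative helper added for exposition (not part of the original
    # module): down_shift keeps the spatial shape but moves every row down
    # by one, zero-filling the new top row.
    t = torch.arange(4.).view(1, 1, 2, 2)  # rows [[0., 1.], [2., 3.]]
    print(down_shift(t))                   # rows [[0., 0.], [0., 1.]]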
|
|
class MaskedConv2d(nn.Conv2d):
    """Conv2d whose kernel rows strictly below the centre row are masked to zero."""

    def __init__(self, c_in, c_out, k_size, stride, pad, use_down_shift=False):
        super(MaskedConv2d, self).__init__(
            c_in, c_out, k_size, stride, pad, bias=False)

        ch_out, ch_in, height, width = self.weight.size()

        # Mask (shown for a 5x5 kernel): rows up to and including the
        # centre row H // 2 are kept, rows below it are zeroed.
        # -------------------------------------
        # | 1 1 1 1 1 |
        # | 1 1 1 1 1 |
        # | 1 1 1 1 1 |   row H // 2
        # | 0 0 0 0 0 |   row H // 2 + 1
        # | 0 0 0 0 0 |
        # -------------------------------------
        #   col 0 1 ... W // 2  W // 2 + 1 ...
        mask = torch.ones(ch_out, ch_in, height, width)
        mask[:, :, height // 2 + 1:] = 0
        self.register_buffer('mask', mask)
        self.use_down_shift = use_down_shift

    def forward(self, x):
        # re-apply the mask so masked weights stay zero even after updates
        self.weight.data *= self.mask
        if self.use_down_shift:
            x = down_shift(x)
        return super(MaskedConv2d, self).forward(x)
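

def _demo_masked_conv2d():
    # Illustrative helper added for exposition (not part of the original
    # module): after one forward pass, every kernel row strictly below the
    # centre is zero, so the convolution never reads pixels below the
    # current row.
    m = MaskedConv2d(1, 1, 3, stride=1, pad=1)
    m(torch.zeros(1, 1, 5, 5))
    print(m.weight[0, 0])  # bottom row of the 3x3 kernel is all zeros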
|
|
def maskConv0(c_in=3, c_out=256, k_size=7, stride=1, pad=3):
    """First-layer 2D masked convolution: down-shifted, batch-normed, gated.

    The convolution emits 2 * c_out channels; Gate halves them back to c_out.
    """
    return nn.Sequential(
        MaskedConv2d(c_in, c_out * 2, k_size, stride, pad, use_down_shift=True),
        nn.BatchNorm2d(c_out * 2),
        Gate()
    )
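

def _demo_maskConv0():
    # Illustrative helper added for exposition (not part of the original
    # module): the first layer doubles the channels internally (2 * c_out)
    # and Gate halves them again, so an RGB batch comes out with c_out maps.
    layer = maskConv0(c_in=3, c_out=16)
    out = layer(torch.randn(4, 3, 28, 28))
    print(out.shape)  # torch.Size([4, 16, 28, 28])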
|
|
class Gate(nn.Module):
    """Gated activation: tanh(first half of channels) * sigmoid(second half)."""

    def forward(self, x):
        xf, xg = torch.chunk(x, 2, dim=1)
        f = torch.tanh(xf)
        g = torch.sigmoid(xg)
        return f * g
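

def _demo_gate():
    # Illustrative helper added for exposition (not part of the original
    # module): Gate halves the channel dimension, multiplying a tanh half
    # with a sigmoid half.
    out = Gate()(torch.randn(2, 8, 4, 4))
    print(out.shape)  # torch.Size([2, 4, 4, 4])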
|
|
class MaskConvBlock(nn.Module):
    def __init__(self, h=128, k_size=3, stride=1, pad=1):
        """Gated 2D masked convolution with a residual connection."""
        super(MaskConvBlock, self).__init__()

        self.net = nn.Sequential(
            MaskedConv2d(h, 2 * h, k_size, stride, pad),
            nn.BatchNorm2d(2 * h),
            Gate()
        )

    def forward(self, x):
        """Residual connection: Gate halves 2 * h back to h, so shapes match."""
        return self.net(x) + x
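

def _demo_mask_conv_block():
    # Illustrative helper added for exposition (not part of the original
    # module): the block maps h channels to h channels, so the residual
    # addition is shape-safe.
    block = MaskConvBlock(h=8)
    out = block(torch.randn(2, 8, 16, 16))
    print(out.shape)  # torch.Size([2, 8, 16, 16])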
|
|
if __name__ == '__main__':
    def conv(x, kernel):
        return nn.functional.conv2d(x, kernel, padding=1)

    # 5x5 image of 0.1s with one "hot" pixel at row 1, column 0
    x = torch.ones((1, 1, 5, 5)) * 0.1
    x[:, :, 1, 0] = 1000

    print("blindspot experiment")
    normal_kernel = torch.ones(1, 1, 3, 3)
    # type-A-style mask: only the row above the centre is visible
    mask_kernel = torch.zeros(1, 1, 3, 3)
    mask_kernel[:, :, 0, :] = 1
    # type-B-style mask: additionally lets a pixel see itself
    mask_b = mask_kernel.clone()
    mask_b[:, :, 1, 1] = 1

    print("unmasked kernel:", "\n", normal_kernel.squeeze(), "\n")
    print("masked kernel:", "\n", mask_kernel.squeeze(), "\n")

    print("normal conv")
    print("orig image", "\n", x.squeeze(), "\n")

    y = conv(x, normal_kernel)
    print(y[:, 0, :, :], "\n")

    y = conv(y, normal_kernel)
    print(y[:, 0, :, :], "\n")

    print("with mask")
    print("orig image", "\n", x.squeeze(), "\n")

    y = conv(x, mask_kernel)
    print(y[:, 0, :, :], "\n")

    y = conv(y, mask_b)
    print(y[:, 0, :, :], "\n")

    y = conv(y, mask_b)
    print(y[:, 0, :, :], "\n")

    print("with down_shift")
    print("orig image", x.squeeze(), "\n")
    # clone so mask_kernel is not silently modified through shared storage
    c_kernel = mask_kernel.clone()
    c_kernel[:, :, 1, :] = 1

    print("custom kernel:", "\n", c_kernel.squeeze(), "\n")
    y = conv(down_shift(x), c_kernel)
    print(y[:, 0, :, :], "\n")
    y = conv(y, c_kernel)
    print(y[:, 0, :, :], "\n")
    y = conv(y, c_kernel)
    print(y[:, 0, :, :], "\n")
    y = conv(y, c_kernel)
    print(y[:, 0, :, :], "\n")
|
|