--- a +++ b/CellGraph/layers_custom.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +import pdb + +def down_shift(x, pad=None): + # Pytorch ordering + xs = [int(y) for y in x.size()] + # when downshifting, the last row is removed + x = x[:, :, :xs[2] - 1, :] + # padding left, padding right, padding top, padding bottom + pad = nn.ZeroPad2d((0, 0, 1, 0)) if pad is None else pad + return pad(x) + +class MaskedConv2d(nn.Conv2d): + def __init__(self, c_in, c_out, k_size, stride, pad, use_down_shift=False): + super(MaskedConv2d, self).__init__( + c_in, c_out, k_size, stride, pad, bias=False) + + ch_out, ch_in, height, width = self.weight.size() + + # Mask + # ------------------------------------- + # | 1 1 1 1 1 | + # | 1 1 1 1 1 | + # | 1 1 1 1 1 | H // 2 + # | 0 0 0 0 0 | H // 2 + 1 + # | 0 0 0 0 0 | + # ------------------------------------- + # index 0 1 W//2 W//2+1 + + mask = torch.ones(ch_out, ch_in, height, width) + mask[:, :, height // 2 + 1:] = 0 + self.register_buffer('mask', mask) + self.use_down_shift = use_down_shift + + def forward(self, x): + self.weight.data *= self.mask + if self.use_down_shift: + x = down_shift(x) + return super(MaskedConv2d, self).forward(x) + +def maskConv0(c_in=3, c_out=256, k_size=7, stride=1, pad=3): + """2D Masked Convolution first layer""" + return nn.Sequential( + MaskedConv2d(c_in, c_out * 2, k_size, stride, pad, use_down_shift=True), + nn.BatchNorm2d(c_out * 2), + Gate() + ) + +class Gate(nn.Module): + def __init__(self): + super(Gate, self).__init__() + + def forward(self, x): + # gated activation + xf, xg = torch.chunk(x, 2, dim=1) + f = torch.tanh(xf) + g = torch.sigmoid(xg) + return f * g + +class MaskConvBlock(nn.Module): + def __init__(self, h=128, k_size=3, stride=1, pad=1): + """1x1 Conv + 2D Masked Convolution (type B) + 1x1 Conv""" + super(MaskConvBlock, self).__init__() + + self.net = nn.Sequential( + MaskedConv2d(h, 2 * h, k_size, stride, pad), + nn.BatchNorm2d(2 * h), + Gate() + ) + + def forward(self, x): + """Residual connection""" + return self.net(x) + x + +if __name__ == '__main__': + def conv(x, kernel): + return nn.functional.conv2d(x, kernel, padding=1) + x = torch.ones((1, 1, 5, 5)) * 0.1 + x[:,:,1,0] = 1000 + + print("blindspot experiment") + normal_kernel = torch.ones(1, 1, 3, 3) + mask_kernel = torch.zeros(1, 1, 3, 3) + mask_kernel[:,:,0,:] = 1 + mask_b = mask_kernel.clone() + mask_b[:,:,1,1] = 1 + # mask_kernel[:,:,1,1] = 1 + + print("unmasked kernel:", "\n",normal_kernel.squeeze(), "\n") + print("masked kernel:", "\n", mask_kernel.squeeze(), "\n") + + print("normal conv") + print("orig image", "\n", x.squeeze(), "\n") + + y = conv(x, normal_kernel) + print(y[:,0, :,:], "\n") + + y = conv(y, normal_kernel) + print(y[:,0, :,:], "\n") + + print("with mask") + print("orig image", "\n", x.squeeze(), "\n") + + y = conv(x, mask_kernel) + print(y[:,0, :,:], "\n") + + y = conv(y, mask_b) + print(y[:,0, :,:], "\n") + + y = conv(y, mask_b) + print(y[:,0, :,:],"\n") + + print("with down_shift") + print("orig image", x.squeeze(), "\n") + c_kernel = mask_kernel + c_kernel[:,:,1,:] = 1 + + print("custom kernel:", "\n", c_kernel.squeeze(), "\n") + y = conv(down_shift(x), c_kernel) + print(y[:,0, :,:],"\n") + y = conv(y, c_kernel) + print(y[:,0, :,:],"\n") + y = conv(y, c_kernel) + print(y[:,0, :,:],"\n") + y = conv(y, c_kernel) + print(y[:,0, :,:],"\n") + + + + +