Switch to side-by-side view

--- a
+++ b/CellGraph/layers_custom.py
@@ -0,0 +1,132 @@
+import torch
+import torch.nn as nn
+import pdb
+
+def down_shift(x, pad=None):
+    # Pytorch ordering
+    xs = [int(y) for y in x.size()]
+    # when downshifting, the last row is removed 
+    x = x[:, :, :xs[2] - 1, :]
+    # padding left, padding right, padding top, padding bottom
+    pad = nn.ZeroPad2d((0, 0, 1, 0)) if pad is None else pad
+    return pad(x)
+
+class MaskedConv2d(nn.Conv2d):
+    def __init__(self, c_in, c_out, k_size, stride, pad, use_down_shift=False):
+        super(MaskedConv2d, self).__init__(
+            c_in, c_out, k_size, stride, pad, bias=False)
+
+        ch_out, ch_in, height, width = self.weight.size()
+
+                # Mask
+        #         -------------------------------------
+        #        |  1       1       1       1       1 |
+        #        |  1       1       1       1       1 |
+        #        |  1       1       1       1       1 |   H // 2
+        #        |  0       0       0       0       0 |   H // 2 + 1
+        #        |  0       0       0       0       0 |
+        #         -------------------------------------
+        #  index    0       1     W//2    W//2+1
+        
+        mask = torch.ones(ch_out, ch_in, height, width)
+        mask[:, :, height // 2 + 1:] = 0
+        self.register_buffer('mask', mask)
+        self.use_down_shift = use_down_shift
+
+    def forward(self, x):
+        self.weight.data *= self.mask
+        if self.use_down_shift:
+            x = down_shift(x)
+        return super(MaskedConv2d, self).forward(x)
+
+def maskConv0(c_in=3, c_out=256, k_size=7, stride=1, pad=3):
+    """2D Masked Convolution first layer"""
+    return nn.Sequential(
+        MaskedConv2d(c_in, c_out * 2, k_size, stride, pad, use_down_shift=True),
+        nn.BatchNorm2d(c_out * 2),
+        Gate()
+        )
+
+class Gate(nn.Module):
+    def __init__(self):
+        super(Gate, self).__init__()
+
+    def forward(self, x):
+        # gated activation 
+        xf, xg = torch.chunk(x, 2, dim=1)
+        f = torch.tanh(xf)
+        g = torch.sigmoid(xg)
+        return f * g
+
+class MaskConvBlock(nn.Module):
+    def __init__(self, h=128, k_size=3, stride=1, pad=1):
+        """1x1 Conv + 2D Masked Convolution (type B) + 1x1 Conv"""
+        super(MaskConvBlock, self).__init__()
+
+        self.net = nn.Sequential(
+            MaskedConv2d(h, 2 * h, k_size, stride, pad),
+            nn.BatchNorm2d(2 * h),
+            Gate()
+        )
+
+    def forward(self, x):
+        """Residual connection"""
+        return self.net(x) + x
+
+if __name__ == '__main__':
+    def conv(x, kernel):
+        return nn.functional.conv2d(x, kernel, padding=1)
+    x = torch.ones((1, 1, 5, 5)) * 0.1
+    x[:,:,1,0] = 1000
+    
+    print("blindspot experiment")
+    normal_kernel = torch.ones(1, 1, 3, 3)
+    mask_kernel = torch.zeros(1, 1, 3, 3)
+    mask_kernel[:,:,0,:] = 1
+    mask_b = mask_kernel.clone()
+    mask_b[:,:,1,1] = 1
+    # mask_kernel[:,:,1,1] = 1
+
+    print("unmasked kernel:", "\n",normal_kernel.squeeze(), "\n")
+    print("masked kernel:", "\n", mask_kernel.squeeze(), "\n")
+
+    print("normal conv")
+    print("orig image", "\n", x.squeeze(), "\n")
+
+    y = conv(x, normal_kernel)
+    print(y[:,0, :,:], "\n")
+
+    y = conv(y, normal_kernel)
+    print(y[:,0, :,:], "\n")
+
+    print("with mask")
+    print("orig image", "\n", x.squeeze(), "\n")
+
+    y = conv(x, mask_kernel)
+    print(y[:,0, :,:], "\n")
+    
+    y = conv(y, mask_b)
+    print(y[:,0, :,:], "\n")
+
+    y = conv(y, mask_b)
+    print(y[:,0, :,:],"\n")
+
+    print("with down_shift")
+    print("orig image", x.squeeze(), "\n")
+    c_kernel = mask_kernel
+    c_kernel[:,:,1,:] = 1
+
+    print("custom kernel:", "\n", c_kernel.squeeze(), "\n")
+    y = conv(down_shift(x), c_kernel)
+    print(y[:,0, :,:],"\n")
+    y = conv(y, c_kernel)
+    print(y[:,0, :,:],"\n")
+    y = conv(y, c_kernel)
+    print(y[:,0, :,:],"\n")
+    y = conv(y, c_kernel)
+    print(y[:,0, :,:],"\n")
+
+
+
+
+