Diff of /CellGraph/pixelcnn.py [000000] .. [2095ed]

Switch to side-by-side view

--- a
+++ b/CellGraph/pixelcnn.py
@@ -0,0 +1,54 @@
+import torch.nn as nn
+from layers_custom import maskConv0, MaskConvBlock
+import torch
+
+class MaskCNN(nn.Module):
+    def __init__(self, n_channel=1024, h=128):
+        """PixelCNN Model"""
+        super(MaskCNN, self).__init__()
+
+        self.MaskConv0 = maskConv0(n_channel, h, k_size=7, stride=1, pad=3)
+        # large 7 x 7 masked filter with image downshift to ensure that each output neuron's receptive field only sees what is above it in the image 
+
+
+        MaskConv = []
+        
+        # stack of 10 gated residual masked conv blocks
+        for i in range(10):
+            MaskConv.append(MaskConvBlock(h, k_size=3, stride=1, pad=1))
+        self.MaskConv = nn.Sequential(*MaskConv)
+
+        # 1x1 conv to upsample to required feature (channel) length
+
+        self.out = nn.Sequential(
+            nn.ReLU(),
+            nn.Conv2d(h, n_channel, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(n_channel),
+            nn.ReLU()
+            )
+
+
+    def forward(self, x):
+        """
+        Args:
+            x: [batch_size, channel, height, width]
+        Return:
+            out [batch_size, channel, height, width]
+        """
+        # fully convolutional, feature map dimension maintained constant throughout
+        x = self.MaskConv0(x)
+
+        x = self.MaskConv(x)
+
+        x = self.out(x)
+
+        return x
+
+if __name__ == '__main__':
+    from torchsummary import summary
+    model = PixelCNN(1024, 128)
+    summary(model, (1024, 7,7))
+    x = torch.rand(2, 1024, 7, 7)
+    x = model(x)
+    print(x.shape)
+