Diff of /CellGraph/pixelcnn.py [000000] .. [2095ed]

Switch to unified view

a b/CellGraph/pixelcnn.py
1
import torch.nn as nn
2
from layers_custom import maskConv0, MaskConvBlock
3
import torch
4
5
class MaskCNN(nn.Module):
6
    def __init__(self, n_channel=1024, h=128):
7
        """PixelCNN Model"""
8
        super(MaskCNN, self).__init__()
9
10
        self.MaskConv0 = maskConv0(n_channel, h, k_size=7, stride=1, pad=3)
11
        # large 7 x 7 masked filter with image downshift to ensure that each output neuron's receptive field only sees what is above it in the image 
12
13
14
        MaskConv = []
15
        
16
        # stack of 10 gated residual masked conv blocks
17
        for i in range(10):
18
            MaskConv.append(MaskConvBlock(h, k_size=3, stride=1, pad=1))
19
        self.MaskConv = nn.Sequential(*MaskConv)
20
21
        # 1x1 conv to upsample to required feature (channel) length
22
23
        self.out = nn.Sequential(
24
            nn.ReLU(),
25
            nn.Conv2d(h, n_channel, kernel_size=1, stride=1, padding=0),
26
            nn.BatchNorm2d(n_channel),
27
            nn.ReLU()
28
            )
29
30
31
    def forward(self, x):
32
        """
33
        Args:
34
            x: [batch_size, channel, height, width]
35
        Return:
36
            out [batch_size, channel, height, width]
37
        """
38
        # fully convolutional, feature map dimension maintained constant throughout
39
        x = self.MaskConv0(x)
40
41
        x = self.MaskConv(x)
42
43
        x = self.out(x)
44
45
        return x
46
47
if __name__ == '__main__':
48
    from torchsummary import summary
49
    model = PixelCNN(1024, 128)
50
    summary(model, (1024, 7,7))
51
    x = torch.rand(2, 1024, 7, 7)
52
    x = model(x)
53
    print(x.shape)
54