b/CellGraph/pixelcnn.py
|
|
import torch
import torch.nn as nn

from layers_custom import maskConv0, MaskConvBlock


class MaskCNN(nn.Module):
    """PixelCNN model."""

    def __init__(self, n_channel=1024, h=128):
        super(MaskCNN, self).__init__()

        # large 7 x 7 masked filter with an image down-shift, so that each output
        # neuron's receptive field only sees the pixels above it in the image
        # (an illustrative sketch of such a layer follows the listing)
        self.MaskConv0 = maskConv0(n_channel, h, k_size=7, stride=1, pad=3)

        # stack of 10 gated residual masked conv blocks (also sketched below)
        MaskConv = []
        for i in range(10):
            MaskConv.append(MaskConvBlock(h, k_size=3, stride=1, pad=1))
        self.MaskConv = nn.Sequential(*MaskConv)

        # 1x1 conv to bring the features back up to the required channel length
        self.out = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(h, n_channel, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(n_channel),
            nn.ReLU()
        )
    def forward(self, x):
        """
        Args:
            x: [batch_size, channel, height, width]
        Return:
            out: [batch_size, channel, height, width]
        """
        # fully convolutional, feature map dimensions maintained constant throughout
        x = self.MaskConv0(x)
        x = self.MaskConv(x)
        x = self.out(x)
        return x


if __name__ == '__main__':
    from torchsummary import summary

    model = MaskCNN(1024, 128)
    summary(model, (1024, 7, 7))

    x = torch.rand(2, 1024, 7, 7)
    x = model(x)
    print(x.shape)
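
The layers_custom module is not part of this file, so the helper layers are not visible here. As a minimal sketch only, assuming that maskConv0 is a down-shifted (vertically causal) convolution of the kind the comment in __init__ describes, such a layer could look roughly like the following; the class name DownShiftedConv2d and the argument names c_in/c_out are hypothetical and not taken from the repository.

import torch
import torch.nn as nn

class DownShiftedConv2d(nn.Module):
    """Hypothetical sketch of a down-shifted (vertically causal) convolution."""

    def __init__(self, c_in, c_out, k_size=7, stride=1, pad=3):
        super().__init__()
        # pad order is (left, right, top, bottom): all vertical context comes from above
        self.shift_and_pad = nn.ZeroPad2d((pad, pad, k_size, 0))
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=k_size, stride=stride, padding=0)

    def forward(self, x):
        # shift the image down by one row (extra top padding, crop the bottom row),
        # so each output position depends only on rows strictly above it
        x = self.shift_and_pad(x)[:, :, :-1, :]
        return self.conv(x)

With k_size=7 and pad=3 the spatial size is preserved, which matches the 7 x 7 feature maps used in the __main__ smoke test above.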
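
MaskConvBlock is likewise defined in layers_custom and not shown. Purely as an illustration of what a gated residual masked conv block can look like (in the spirit of the gated PixelCNN), here is a sketch that reuses the hypothetical DownShiftedConv2d from above; it is not the repository's implementation.

import torch
import torch.nn as nn

class GatedResidualMaskConvBlock(nn.Module):
    """Hypothetical sketch of one gated residual masked conv block."""

    def __init__(self, h, k_size=3, stride=1, pad=1):
        super().__init__()
        # a single causal conv produces both the feature half and the gate half
        self.conv = DownShiftedConv2d(h, 2 * h, k_size=k_size, stride=stride, pad=pad)

    def forward(self, x):
        a, b = self.conv(x).chunk(2, dim=1)      # split channels into feature / gate
        out = torch.tanh(a) * torch.sigmoid(b)   # gated activation
        return x + out                           # residual connection keeps the shape

Stacking ten such blocks, as MaskCNN does, keeps the 128-channel feature map at the same spatial resolution while enlarging the strictly-above receptive field.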
|
|