|
a |
|
b/model.py |
|
|
1 |
""" |
|
|
2 |
UNet |
|
|
3 |
The main UNet model implementation |
|
|
4 |
""" |
|
|
5 |
|
|
|
6 |
import torch |
|
|
7 |
import torch.nn as nn |
|
|
8 |
import torch.nn.functional as F |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
# Utility Functions


def conv_bn_leru(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Build a "double conv" unit: (Conv2d -> BatchNorm2d -> ReLU) applied twice.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Both convolutions share ``kernel_size``, ``stride``
    and ``padding``. With the defaults (3x3 kernel, stride 1, padding 1) the
    spatial size of the input is preserved.

    The name ("leru") is kept as-is for existing callers.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        kernel_size: convolution kernel size, applied to both convolutions.
        stride: convolution stride, applied to both convolutions.
        padding: zero padding, applied to both convolutions.

    Returns:
        An ``nn.Sequential`` with six children in the order
        Conv2d, BatchNorm2d, ReLU, Conv2d, BatchNorm2d, ReLU.
    """
    layers = []
    # The first stage reads in_channels, the second reads the first stage's output.
    for stage_in in (in_channels, out_channels):
        layers.extend([
            nn.Conv2d(stage_in, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ])
    return nn.Sequential(*layers)
|
|
22 |
|
|
|
23 |
def down_pooling():
    """Return a new 2x2 max-pooling layer with stride 2.

    It halves the height and width of its input; odd sizes are floored.
    """
    pool = nn.MaxPool2d(kernel_size=2, stride=2)
    return pool
|
|
25 |
|
|
|
26 |
def up_pooling(in_channels, out_channels, kernel_size=2, stride=2):
    """Build a learned upsampling unit: ConvTranspose2d -> BatchNorm2d -> ReLU.

    With the defaults (kernel 2, stride 2) the transposed convolution exactly
    doubles the height and width of its input.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after the transposed convolution.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.

    Returns:
        An ``nn.Sequential`` with children ConvTranspose2d, BatchNorm2d, ReLU.
    """
    upsample = nn.ConvTranspose2d(
        in_channels, out_channels, kernel_size=kernel_size, stride=stride
    )
    norm = nn.BatchNorm2d(out_channels)
    act = nn.ReLU(inplace=True)
    return nn.Sequential(upsample, norm, act)
|
|
32 |
|
|
|
33 |
# UNet class |
|
|
34 |
|
|
|
35 |
class UNet(nn.Module):
    """U-Net encoder/decoder with skip connections.

    The encoder ("go down") has five double-conv stages of 64, 128, 256, 512
    and 1024 channels. A 2x2 max pool sits between consecutive stages.

    The decoder ("go up") works in four steps. Each step upsamples with a
    transposed convolution and concatenates the matching encoder feature map
    along the channel axis. It then applies a double conv.

    A final 1x1 convolution maps to ``nclasses`` channels. Its output goes
    through an element-wise sigmoid.

    Args:
        input_channels: number of channels in the input tensor.
        nclasses: number of output channels.
    """

    def __init__(self, input_channels, nclasses):
        super().__init__()
        # go down (encoder)
        self.conv1 = conv_bn_leru(input_channels, 64)
        self.conv2 = conv_bn_leru(64, 128)
        self.conv3 = conv_bn_leru(128, 256)
        self.conv4 = conv_bn_leru(256, 512)
        self.conv5 = conv_bn_leru(512, 1024)
        self.down_pooling = down_pooling()

        # go up (decoder). Each conv reads upsampled + skip channels, hence the
        # doubled input width (e.g. 512 upsampled + 512 skip -> 1024 for conv6).
        self.up_pool6 = up_pooling(1024, 512)
        self.conv6 = conv_bn_leru(1024, 512)
        self.up_pool7 = up_pooling(512, 256)
        self.conv7 = conv_bn_leru(512, 256)
        self.up_pool8 = up_pooling(256, 128)
        self.conv8 = conv_bn_leru(256, 128)
        self.up_pool9 = up_pooling(128, 64)
        self.conv9 = conv_bn_leru(128, 64)

        # 1x1 projection to one channel per class.
        self.conv10 = nn.Conv2d(64, nclasses, 1)

        # Weight init: Kaiming-normal (fan_out) for every nn.Conv2d, zero bias.
        # nn.ConvTranspose2d is not a subclass of nn.Conv2d, so the up-sampling
        # layers keep PyTorch's default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # In-place initializers run under no_grad internally; no `.data` needed.
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _cat_skip(up, skip):
        """Concatenate an upsampled map with its encoder skip map along dim 1.

        MaxPool2d floors odd spatial sizes, so after the 2x transposed
        convolution ``up`` can be one pixel smaller than ``skip``. In that case
        ``up`` is zero-padded (split between both sides) to ``skip``'s size. If
        the sizes already match, the inputs are concatenated unchanged.
        """
        dh = skip.size(2) - up.size(2)
        dw = skip.size(3) - up.size(3)
        if dh or dw:
            # F.pad order for the last two dims: (left, right, top, bottom).
            up = F.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return torch.cat([up, skip], dim=1)

    def forward(self, x):
        """Compute per-class sigmoid maps for ``x``.

        Args:
            x: input batch, presumably (N, input_channels, H, W) — TODO confirm
                with callers. H and W need not be multiples of 16. Upsampled
                maps are padded to the size of their skip connection before
                concatenation.

        Returns:
            Tensor with ``nclasses`` channels and the same spatial size as
            ``x``, with values in [0, 1] from ``torch.sigmoid``.
        """
        # go down
        x1 = self.conv1(x)
        x2 = self.conv2(self.down_pooling(x1))
        x3 = self.conv3(self.down_pooling(x2))
        x4 = self.conv4(self.down_pooling(x3))
        x5 = self.conv5(self.down_pooling(x4))

        # go up: upsample, join with the same-resolution encoder map, refine.
        x6 = self.conv6(self._cat_skip(self.up_pool6(x5), x4))
        x7 = self.conv7(self._cat_skip(self.up_pool7(x6), x3))
        x8 = self.conv8(self._cat_skip(self.up_pool8(x7), x2))
        x9 = self.conv9(self._cat_skip(self.up_pool9(x8), x1))

        output = self.conv10(x9)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        output = torch.sigmoid(output)

        return output