pathflowai/unet.py
|
# From https://raw.githubusercontent.com/milesial/Pytorch-UNet/master/unet/unet_model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x
|
|
class inconv(nn.Module):
    '''Input block: initial double convolution at full resolution.'''
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x
|
|
class down(nn.Module):
    '''Downscaling block: 2x2 max pooling followed by a double convolution.'''
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x
|
|
class up(nn.Module):
    '''Upscaling block: upsample, pad to match the skip connection, then double conv.'''
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        # would be a nice idea if the upsampling could be learned too,
        # but my machine does not have enough memory to handle all those weights
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # input is NCHW; pad x1 so its spatial size matches the skip connection x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # F.pad takes (left, right, top, bottom); an odd difference pads
        # asymmetrically, e.g. diffX == 3 pads 1 on the left and 2 on the right
        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))

        # for padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x
|
|
class outconv(nn.Module):
    '''Output head: 1x1 convolution mapping features to per-class scores.'''
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x
|
|
class UNet(nn.Module):
    '''U-Net (Ronneberger et al. 2015) with optional sigmoid / log-softmax output.'''
    def __init__(self, n_channels, n_classes, use_sigmoid=False, use_softmax=False):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        # 512 (not 1024) output channels: bilinear upsampling does not reduce
        # channels, so concatenating with x4 (512) yields the 1024 up1 expects
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)
        # optional output activation; despite the attribute name, this applies
        # sigmoid and/or log-softmax as configured (nn.Identity() is a no-op)
        self.sigmoid = nn.Sequential(
            nn.Sigmoid() if use_sigmoid else nn.Identity(),
            nn.LogSoftmax(dim=1) if use_softmax else nn.Identity()
        )

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return self.sigmoid(x)
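
# A minimal smoke test, not part of the upstream file: builds the model and
# checks the output shape on a random batch. Input sizes divisible by 16 keep
# the four poolings clean, so no asymmetric padding occurs in the up blocks.
if __name__ == '__main__':
    model = UNet(n_channels=3, n_classes=4, use_softmax=True)
    x = torch.randn(2, 3, 256, 256)  # NCHW batch of RGB tiles
    with torch.no_grad():
        y = model(x)
    # per-pixel log-probabilities over 4 classes, same spatial size as input
    assert y.shape == (2, 4, 256, 256)
    print(y.shape)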