|
a |
|
b/unet.py |
|
|
1 |
from collections import OrderedDict |
|
|
2 |
|
|
|
3 |
import torch |
|
|
4 |
import torch.nn as nn |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
class UNet(nn.Module):
    """U-Net encoder/decoder producing a per-pixel sigmoid map.

    Architecture:
        * Encoder: four ``_block`` stages with ``init_features * {1, 2, 4, 8}``
          channels, each followed by a 2x2 max-pool.
        * Bottleneck: one ``_block`` with ``init_features * 16`` channels.
        * Decoder: four stages. Each upsamples with a stride-2
          ``ConvTranspose2d``, concatenates the matching encoder output along
          the channel axis (skip connection) and applies a ``_block``.
        * Head: a 1x1 convolution to ``out_channels``, then ``torch.sigmoid``,
          so every output value lies in [0, 1].

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output map.
        init_features: channel width of the first encoder stage; deeper
            stages double it at every level.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()

        # NOTE: modules are created in this exact order on purpose. The order
        # decides parameter registration order and RNG consumption during
        # initialisation, so keep it stable for reproducible seeded runs.
        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        # Each decoder block sees twice the upconv's output channels, because
        # the upsampled map is concatenated with the same-width encoder output.
        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: batched tensor of shape (N, in_channels, H, W). Skip
                connections are concatenated on dim 1, so a batch dimension
                is required.

        Returns:
            Tensor of shape (N, out_channels, H, W) with values in [0, 1].
            H and W need not be multiples of 16; see ``_decode``.
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = UNet._decode(bottleneck, enc4, self.upconv4, self.decoder4)
        dec3 = UNet._decode(dec4, enc3, self.upconv3, self.decoder3)
        dec2 = UNet._decode(dec3, enc2, self.upconv2, self.decoder2)
        dec1 = UNet._decode(dec2, enc1, self.upconv1, self.decoder1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _decode(x, skip, upconv, decoder):
        """Run one decoder stage: upsample ``x``, join it with ``skip``, decode.

        The 2x2 max-pool floors odd spatial sizes, so ``upconv(x)`` can be
        one pixel smaller than ``skip`` along H and/or W. When that happens
        the upsampled map is zero-padded (split as evenly as possible between
        the two sides) to the skip's size before concatenation. When sizes
        already match, nothing is padded and the computation is the plain
        ``cat`` + ``decoder``.
        """
        up = upconv(x)
        diff_h = skip.size(-2) - up.size(-2)
        diff_w = skip.size(-1) - up.size(-1)
        if diff_h or diff_w:
            # F.pad order for the last two dims: (left, right, top, bottom).
            up = nn.functional.pad(
                up,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        return decoder(torch.cat((up, skip), dim=1))

    @staticmethod
    def _block(in_channels, features, name):
        """Build a (Conv3x3 -> BatchNorm -> ReLU) x 2 stage.

        Both convolutions use padding=1, so the spatial size is preserved.

        Args:
            in_channels: channels entering the first convolution.
            features: output channels of both convolutions.
            name: prefix for the sub-module names. The names become part of
                the ``state_dict`` keys (e.g. ``encoder1.enc1conv1.weight``),
                so changing them breaks loading of existing checkpoints.

        Returns:
            An ``nn.Sequential`` holding the six named layers.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        # No bias: the following BatchNorm subtracts the mean
                        # and adds its own learned shift, making it redundant.
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )