a b/unet.py
1
from collections import OrderedDict
2
3
import torch
4
import torch.nn as nn
5
6
7
class UNet(nn.Module):
    """U-Net encoder/decoder with skip connections and a sigmoid output.

    The network has four encoder stages, each followed by 2x2 max pooling,
    and then a bottleneck. Four decoder stages follow. Each one upsamples
    with a 2x2 stride-2 transposed convolution, concatenates the matching
    encoder feature map along the channel dimension, and refines the result
    with a conv block. Channel widths start at ``init_features`` and double
    at every encoder stage (up to ``16 * init_features`` in the bottleneck).
    A final 1x1 convolution maps to ``out_channels``, followed by a sigmoid.

    Args:
        in_channels: Number of channels in the input image.
        out_channels: Number of channels produced by the final 1x1 conv.
        init_features: Channel width of the first encoder stage.

    Sub-module attribute names (``encoder1``, ``upconv4``, ...) and the
    per-layer names produced by ``_block`` (e.g. ``"enc1conv1"``) determine
    the ``state_dict`` keys. Keep them stable so saved checkpoints still load.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()

        features = init_features

        # Contracting path: each stage doubles the channel count, and each
        # pool halves the spatial size (flooring odd sizes).
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        # Expanding path: each transposed conv halves the channels. Each
        # decoder block takes twice that width, because the upsampled map is
        # concatenated with the same-width encoder skip tensor.
        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        # 1x1 projection to the requested number of output channels.
        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input tensor of shape ``(N, in_channels, H, W)``. ``H`` and
                ``W`` must be at least 16 so that four 2x2 poolings leave a
                non-empty map. They do not need to be multiples of 16; the
                decoder pads upsampled maps to match their skip tensors.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)`` with values in
            ``[0, 1]`` (sigmoid of the final 1x1 convolution).
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = UNet._concat_skip(dec4, enc4)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = UNet._concat_skip(dec3, enc3)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = UNet._concat_skip(dec2, enc2)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = UNet._concat_skip(dec1, enc1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _concat_skip(upsampled, skip):
        """Concatenate a decoder feature map with its encoder skip tensor.

        ``MaxPool2d`` floors odd spatial sizes. After the 2x transposed
        convolution, ``upsampled`` can therefore be one row or column smaller
        than ``skip``, which would make ``torch.cat`` fail. This helper
        zero-pads ``upsampled`` to ``skip``'s height and width, splitting
        the padding evenly with any extra going to the bottom/right. It then
        concatenates along the channel dimension in ``(upsampled, skip)``
        order. When the sizes already match, no padding is applied.
        """
        diff_h = skip.size(-2) - upsampled.size(-2)
        diff_w = skip.size(-1) - upsampled.size(-1)
        if diff_h or diff_w:
            # pad() lists the last dim first: (left, right, top, bottom).
            upsampled = nn.functional.pad(
                upsampled,
                (
                    diff_w // 2,
                    diff_w - diff_w // 2,
                    diff_h // 2,
                    diff_h - diff_h // 2,
                ),
            )
        return torch.cat((upsampled, skip), dim=1)

    @staticmethod
    def _block(in_channels, features, name):
        """Build two (3x3 conv -> BatchNorm -> ReLU) stages.

        The convolutions use ``padding=1``, so spatial size is preserved.
        They have no bias because the following BatchNorm adds its own shift.
        Layer names are ``name`` plus ``conv1``/``norm1``/``relu1``/``conv2``/
        ``norm2``/``relu2`` and become part of the ``state_dict`` keys.

        Args:
            in_channels: Channels entering the first convolution.
            features: Output channels of both convolutions.
            name: Prefix for every layer name in the returned Sequential.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )