--- a +++ b/unet.py @@ -0,0 +1,98 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn + + +class UNet(nn.Module): + + def __init__(self, in_channels=3, out_channels=1, init_features=32): + super(UNet, self).__init__() + + features = init_features + self.encoder1 = UNet._block(in_channels, features, name="enc1") + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.encoder2 = UNet._block(features, features * 2, name="enc2") + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.encoder3 = UNet._block(features * 2, features * 4, name="enc3") + self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) + self.encoder4 = UNet._block(features * 4, features * 8, name="enc4") + self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) + + self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck") + + self.upconv4 = nn.ConvTranspose2d( + features * 16, features * 8, kernel_size=2, stride=2 + ) + self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4") + self.upconv3 = nn.ConvTranspose2d( + features * 8, features * 4, kernel_size=2, stride=2 + ) + self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3") + self.upconv2 = nn.ConvTranspose2d( + features * 4, features * 2, kernel_size=2, stride=2 + ) + self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2") + self.upconv1 = nn.ConvTranspose2d( + features * 2, features, kernel_size=2, stride=2 + ) + self.decoder1 = UNet._block(features * 2, features, name="dec1") + + self.conv = nn.Conv2d( + in_channels=features, out_channels=out_channels, kernel_size=1 + ) + + def forward(self, x): + enc1 = self.encoder1(x) + enc2 = self.encoder2(self.pool1(enc1)) + enc3 = self.encoder3(self.pool2(enc2)) + enc4 = self.encoder4(self.pool3(enc3)) + + bottleneck = self.bottleneck(self.pool4(enc4)) + + dec4 = self.upconv4(bottleneck) + dec4 = torch.cat((dec4, enc4), dim=1) + dec4 = self.decoder4(dec4) + dec3 = self.upconv3(dec4) + dec3 = torch.cat((dec3, enc3), dim=1) + dec3 = self.decoder3(dec3) + dec2 = self.upconv2(dec3) + dec2 = torch.cat((dec2, enc2), dim=1) + dec2 = self.decoder2(dec2) + dec1 = self.upconv1(dec2) + dec1 = torch.cat((dec1, enc1), dim=1) + dec1 = self.decoder1(dec1) + return torch.sigmoid(self.conv(dec1)) + + @staticmethod + def _block(in_channels, features, name): + return nn.Sequential( + OrderedDict( + [ + ( + name + "conv1", + nn.Conv2d( + in_channels=in_channels, + out_channels=features, + kernel_size=3, + padding=1, + bias=False, + ), + ), + (name + "norm1", nn.BatchNorm2d(num_features=features)), + (name + "relu1", nn.ReLU(inplace=True)), + ( + name + "conv2", + nn.Conv2d( + in_channels=features, + out_channels=features, + kernel_size=3, + padding=1, + bias=False, + ), + ), + (name + "norm2", nn.BatchNorm2d(num_features=features)), + (name + "relu2", nn.ReLU(inplace=True)), + ] + ) + )