Diff of /U-Net/unet/unet_parts.py [000000] .. [6f3ba0]

Switch to side-by-side view

--- a
+++ b/U-Net/unet/unet_parts.py
@@ -0,0 +1,77 @@
+""" Parts of the U-Net model """
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DoubleConv(nn.Module):
+    """(convolution => [BN] => ReLU) * 2"""
+
+    def __init__(self, in_channels, out_channels, mid_channels=None):
+        super().__init__()
+        if not mid_channels:
+            mid_channels = out_channels
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(mid_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.double_conv(x)
+
+
+class Down(nn.Module):
+    """Downscaling with maxpool then double conv"""
+
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, out_channels)
+        )
+
+    def forward(self, x):
+        return self.maxpool_conv(x)
+
+
+class Up(nn.Module):
+    """Upscaling then double conv"""
+
+    def __init__(self, in_channels, out_channels, bilinear=True):
+        super().__init__()
+
+        # if bilinear, use the normal convolutions to reduce the number of channels
+        if bilinear:
+            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+        else:
+            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+            self.conv = DoubleConv(in_channels, out_channels)
+
+    def forward(self, x1, x2):
+        x1 = self.up(x1)
+        # input is CHW
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+
+        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                        diffY // 2, diffY - diffY // 2])
+        # if you have padding issues, see
+        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+        x = torch.cat([x2, x1], dim=1)
+        return self.conv(x)
+
+
+class OutConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(OutConv, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        return self.conv(x)