Diff of /model.py [000000] .. [c854d3]

Switch to side-by-side view

--- a
+++ b/model.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
class ConvUnit(nn.Module):
    """Two stacked (Conv3d -> BatchNorm3d -> ReLU) stages.

    Kernel size 3 with padding 1 keeps the spatial dimensions unchanged;
    only the channel count changes (in_channels -> out_channels on the
    first stage, out_channels -> out_channels on the second).
    A residual variant could be explored later.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvUnit, self).__init__()
        stages = []
        for c_in, c_out in ((in_channels, out_channels),
                            (out_channels, out_channels)):
            stages.append(nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm3d(c_out))
            # inplace ReLU overwrites its input tensor to save memory
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
+
+
+
class EncoderUnit(nn.Module):
    """One downsampling step of the encoder path.

    MaxPool3d(2) halves every spatial dimension, then a ConvUnit maps
    in_channels to out_channels at the reduced resolution.
    """

    def __init__(self, in_channels, out_channels):
        super(EncoderUnit, self).__init__()
        pool = nn.MaxPool3d(2)
        conv = ConvUnit(in_channels, out_channels)
        self.encoder = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.encoder(x)
+
+
class DecoderUnit(nn.Module):
    """One upsampling step of the decoder path.

    Upsamples the deeper feature map ``x1``, pads it so its spatial size
    matches the skip-connection tensor ``x2``, concatenates the two along
    the channel axis, and applies a ConvUnit.

    Args:
        in_channels: channel count AFTER concatenation (x1 plus x2).
        out_channels: channels produced by the final ConvUnit.
        bilinear: if True, use parameter-free interpolation to upsample;
            otherwise use a learned ConvTranspose3d.
    """

    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()

        if bilinear:
            # Fix: this network processes 5-D tensors (N, C, D, H, W), and
            # nn.Upsample(mode="bilinear") only accepts 4-D input — it raised
            # a runtime error here. "trilinear" is the 3-D equivalent.
            self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
        else:
            # x1 carries in_channels // 2 channels before concatenation;
            # the transposed conv doubles each spatial dimension.
            self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2,
                                         kernel_size=2, stride=2)

        self.conv = ConvUnit(in_channels, out_channels)

    def forward(self, x1, x2):
        """x1: deeper map to upsample; x2: skip connection to merge with."""
        x1 = self.up(x1)

        # Odd input sizes make the upsampled tensor up to one voxel smaller
        # per axis than the skip tensor; pad x1 to match x2 exactly.
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
+
class OutConv(nn.Module):
    """Final 1x1x1 convolution projecting features to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        # kernel_size=1 mixes channels only; spatial shape is untouched.
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
+
+
+
+
+###########   Model :
+
class UNet(nn.Module):
    """3-D U-Net: a ConvUnit stem, four encoder stages, four decoder
    stages with skip connections, and a 1x1x1 output head.

    Args:
        in_channels: channels of the input volume.
        n_classes: output channels (segmentation classes).
        s_channels: base feature width; doubled at each encoder stage.
        bilinear: forwarded to every DecoderUnit.
    """

    def __init__(self, in_channels, n_classes, s_channels, bilinear=False):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.s_channels = s_channels
        self.bilinear = bilinear

        c = s_channels
        self.conv = ConvUnit(in_channels, c)
        self.enc1 = EncoderUnit(c, 2 * c)
        self.enc2 = EncoderUnit(2 * c, 4 * c)
        self.enc3 = EncoderUnit(4 * c, 8 * c)
        # Deepest stage keeps 8*c channels, so the first decoder's
        # concatenation (with the 8*c skip) sees 16*c input channels.
        self.enc4 = EncoderUnit(8 * c, 8 * c)

        self.dec1 = DecoderUnit(16 * c, 4 * c, self.bilinear)
        self.dec2 = DecoderUnit(8 * c, 2 * c, self.bilinear)
        self.dec3 = DecoderUnit(4 * c, c, self.bilinear)
        self.dec4 = DecoderUnit(2 * c, c, self.bilinear)
        self.out = OutConv(c, n_classes)

    def forward(self, x):
        # Encoder path; every stage's output is kept for a skip connection.
        skip0 = self.conv(x)
        skip1 = self.enc1(skip0)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        bottom = self.enc4(skip3)

        # Decoder path consumes the skips in reverse order.
        up = self.dec1(bottom, skip3)
        up = self.dec2(up, skip2)
        up = self.dec3(up, skip1)
        up = self.dec4(up, skip0)
        return self.out(up)