Diff of /model.py [000000] .. [c854d3]

Switch to unified view

a b/model.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
5
6
7
class ConvUnit(nn.Module):
8
  """
9
    Convolution Unit -
10
    for  now : (Conv3D -> BatchNorm -> ReLu) * 2
11
    Try modifying to Residual convolutions
12
  """
13
14
  def __init__(self, in_channels, out_channels):
15
    super(ConvUnit, self).__init__()
16
    self.double_conv = nn.Sequential(
17
18
        nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),
19
        nn.BatchNorm3d(out_channels),
20
        nn.ReLU(inplace=True), # inplace=True means it changes the input directly, input is lost
21
22
        nn.Conv3d(out_channels, out_channels, kernel_size = 3, padding = 1),
23
        nn.BatchNorm3d(out_channels),
24
        nn.ReLU(inplace=True)
25
      )
26
27
  def forward(self,x):
28
    return self.double_conv(x)
29
30
31
32
class EncoderUnit(nn.Module):
33
  """
34
    An Encoder Unit with the ConvUnit and MaxPool
35
  """
36
  def __init__(self, in_channels, out_channels):
37
    super(EncoderUnit, self).__init__()
38
    self.encoder = nn.Sequential(
39
        nn.MaxPool3d(2),
40
        ConvUnit(in_channels, out_channels)
41
    )
42
  def forward(self, x):
43
    return self.encoder(x)
44
45
46
class DecoderUnit(nn.Module):
47
  """
48
    ConvUnit and upsample with Upsample or convTranspose
49
50
  """
51
  def __init__(self, in_channels, out_channels, bilinear=False):
52
    super().__init__()
53
54
    if bilinear:
55
      # Only for 2D model
56
      self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
57
    else:
58
      self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
59
60
    self.conv = ConvUnit(in_channels, out_channels)
61
62
  def forward(self, x1, x2):
63
64
      x1 = self.up(x1)
65
66
      diffZ = x2.size()[2] - x1.size()[2]
67
      diffY = x2.size()[3] - x1.size()[3]
68
      diffX = x2.size()[4] - x1.size()[4]
69
      x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2, diffZ // 2, diffZ - diffZ // 2])
70
71
      x = torch.cat([x2, x1], dim=1)
72
      return self.conv(x)
73
74
class OutConv(nn.Module):
75
  def __init__(self, in_channels, out_channels):
76
    super(OutConv, self).__init__()
77
    self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 1)
78
79
  def forward(self, x):
80
    return self.conv(x)
81
82
83
84
85
###########   Model :
86
87
class UNet(nn.Module):
88
89
  def __init__(self, in_channels, n_classes, s_channels, bilinear = False):
90
    super(UNet, self).__init__()
91
    self.in_channels = in_channels
92
    self.n_classes = n_classes
93
    self.s_channels = s_channels
94
    self.bilinear = bilinear
95
96
    self.conv = ConvUnit(in_channels, s_channels)
97
    self.enc1 = EncoderUnit(s_channels, 2 * s_channels)
98
    self.enc2 = EncoderUnit(2 * s_channels, 4 * s_channels)
99
    self.enc3 = EncoderUnit(4 * s_channels, 8 * s_channels)
100
    self.enc4 = EncoderUnit(8 * s_channels, 8 * s_channels)
101
102
    self.dec1 = DecoderUnit(16 * s_channels, 4 * s_channels, self.bilinear)
103
    self.dec2 = DecoderUnit(8 * s_channels, 2 * s_channels, self.bilinear)
104
    self.dec3 = DecoderUnit(4 * s_channels, s_channels, self.bilinear)
105
    self.dec4 = DecoderUnit(2 * s_channels, s_channels, self.bilinear)
106
    self.out = OutConv(s_channels, n_classes)
107
108
  def forward(self, x):
109
    x1 = self.conv(x)
110
    x2 = self.enc1(x1)
111
    x3 = self.enc2(x2)
112
    x4 = self.enc3(x3)
113
    x5 = self.enc4(x4)
114
115
    mask = self.dec1(x5, x4)
116
    mask = self.dec2(mask, x3)
117
    mask = self.dec3(mask, x2)
118
    mask = self.dec4(mask, x1)
119
    mask = self.out(mask)
120
121
    return mask