Diff of /model.py [000000] .. [9ff54e]

Switch to side-by-side view

--- a
+++ b/model.py
@@ -0,0 +1,205 @@
+import sys
+
+from torch.distributions.normal import Normal
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
class Encoder(nn.Module):
    """Pre-activation residual block: two 3x3x3 Conv3d + GroupNorm layers.

    The block computes ReLU -> conv1 -> GN -> ReLU -> conv2 -> GN -> ReLU
    and adds the input back onto the result (residual skip). When
    ``out_channels`` exceeds ``in_channels``, the skip path is zero-padded
    along the channel dimension so the addition is well defined.

    Note: ``bn`` is accepted for interface compatibility but is not used;
    normalization is always GroupNorm with ``num_groups`` groups.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 bn=False,
                 num_groups=8):
        super(Encoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.relu = nn.ReLU()
        # Two conv/GN pairs; module creation order matters for RNG-seeded
        # weight initialization, so keep it: conv1, gn1, conv2, gn2.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size,
                               stride=stride, padding=padding, bias=bias)
        self.gn1 = nn.GroupNorm(num_groups, out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size,
                               stride=stride, padding=padding, bias=bias)
        self.gn2 = nn.GroupNorm(num_groups, out_channels)

    def forward(self, x):
        # Conv branch (pre-activation ordering).
        out = self.gn1(self.conv1(self.relu(x)))
        out = self.relu(self.gn2(self.conv2(self.relu(out))))
        # Widen the identity path with zeros when the block adds channels.
        # F.pad's pad list is ordered last-dim-first; for a 5-D (N,C,D,H,W)
        # tensor, index 6 pads the front of the channel dimension.
        if self.in_channels != self.out_channels:
            pad = [0] * (2 * x.dim())
            pad[6] = self.out_channels - self.in_channels
            x = F.pad(input=x, pad=pad, mode='constant', value=0)
        return out + x
+
+
class UNet3D(nn.Module):
    """3D U-Net built from residual ``Encoder`` blocks.

    Four encoder levels downsample with MaxPool3d(2); three decoder levels
    upsample with stride-2 ConvTranspose3d and concatenate the matching
    encoder feature map (skip connection). A final 1x1x1 transposed conv
    (no norm/ReLU) maps to ``n_classes`` logit channels. Channel widths are
    ``inplanes * 2**i`` for levels i = 1..4.

    Args:
        in_channel: number of channels of the input volume.
        n_classes: number of output channels (class logits).
        use_bias: whether conv / transposed-conv layers carry a bias term.
        inplanes: base channel width (each level doubles it).
        num_groups: group count for every GroupNorm; must divide
            ``inplanes * 2**i`` for i = 1..4.
    """

    def __init__(self,
                 in_channel,
                 n_classes,
                 use_bias=True,
                 inplanes=32,
                 num_groups=8):
        # Fix: nn.Module must be initialized before anything is assigned on
        # self -- nn.Module.__setattr__ relies on state created here, and
        # assigning attributes first is unsupported (it raises outright if
        # the value is a Module or Parameter).
        super(UNet3D, self).__init__()
        self.in_channel = in_channel
        self.n_classes = n_classes
        self.inplanes = inplanes
        self.num_groups = num_groups
        # planes[0] is unused; the network uses planes[1]..planes[4].
        planes = [inplanes * int(pow(2, i)) for i in range(0, 5)]
        # --- Encoder path ---
        self.ec0 = Encoder(in_channel,
                           planes[1],
                           bias=use_bias,
                           num_groups=num_groups)
        self.ec1 = Encoder(planes[1],
                           planes[2],
                           bias=use_bias,
                           num_groups=num_groups)
        self.ec1_2 = Encoder(planes[2],
                             planes[2],
                             bias=use_bias,
                             num_groups=num_groups)
        self.ec2 = Encoder(planes[2],
                           planes[3],
                           bias=use_bias,
                           num_groups=num_groups)
        self.ec2_2 = Encoder(planes[3],
                             planes[3],
                             bias=use_bias,
                             num_groups=num_groups)
        self.ec3 = Encoder(planes[3],
                           planes[4],
                           bias=use_bias,
                           num_groups=num_groups)
        self.ec3_2 = Encoder(planes[4],
                             planes[4],
                             bias=use_bias,
                             num_groups=num_groups)
        self.maxpool = nn.MaxPool3d(2)
        # --- Decoder path (bottleneck blocks + learned upsampling) ---
        self.dc3 = Encoder(planes[4],
                           planes[4],
                           bias=use_bias,
                           num_groups=num_groups)
        self.dc3_2 = Encoder(planes[4],
                             planes[4],
                             bias=use_bias,
                             num_groups=num_groups)
        self.up3 = self.decoder(planes[4],
                                planes[3],
                                kernel_size=2,
                                stride=2,
                                bias=use_bias)
        # dc2 consumes planes[4] channels: planes[3] upsampled + planes[3]
        # from the skip concatenation.
        self.dc2 = Encoder(planes[4],
                           planes[3],
                           bias=use_bias,
                           num_groups=num_groups)
        self.dc2_2 = Encoder(planes[3],
                             planes[3],
                             bias=use_bias,
                             num_groups=num_groups)
        self.up2 = self.decoder(planes[3],
                                planes[2],
                                kernel_size=2,
                                stride=2,
                                bias=use_bias)
        self.dc1 = Encoder(planes[3],
                           planes[2],
                           bias=use_bias,
                           num_groups=num_groups)
        self.dc1_2 = Encoder(planes[2],
                             planes[2],
                             bias=use_bias,
                             num_groups=num_groups)
        self.up1 = self.decoder(planes[2],
                                planes[1],
                                kernel_size=2,
                                stride=2,
                                bias=use_bias)
        self.dc0a = Encoder(planes[2],
                            planes[1],
                            bias=use_bias,
                            num_groups=num_groups)
        # Final projection to class logits: 1x1x1, no GroupNorm/ReLU.
        self.dc0b = self.decoder(planes[1],
                                 n_classes,
                                 kernel_size=1,
                                 stride=1,
                                 bias=use_bias,
                                 relu=False)
        # He initialization for all conv weights; unit/zero affine for GN.
        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def decoder(self,
                in_channels,
                out_channels,
                kernel_size,
                stride=1,
                padding=0,
                output_padding=0,
                bias=True,
                relu=True):
        """Build a ConvTranspose3d upsampling stage.

        When ``relu`` is True the transposed conv is followed by GroupNorm
        and ReLU; with ``relu=False`` it is a bare (logit-producing) layer.
        Returns an ``nn.Sequential``.
        """
        layer = [
            nn.ConvTranspose3d(in_channels,
                               out_channels,
                               kernel_size,
                               stride=stride,
                               padding=padding,
                               output_padding=output_padding,
                               bias=bias),
        ]
        if relu:
            layer.append(nn.GroupNorm(self.num_groups, out_channels))
            layer.append(nn.ReLU())
        layer = nn.Sequential(*layer)
        return layer

    def forward(self, x):
        """Run the U-Net; returns per-class logits at the input resolution.

        If an upsampled decoder map does not exactly match its skip
        connection's spatial size (odd input dimensions), it is trilinearly
        resized before concatenation.
        """
        e0 = self.ec0(x)
        e1 = self.ec1_2(self.ec1(self.maxpool(e0)))
        e2 = self.ec2_2(self.ec2(self.maxpool(e1)))
        e3 = self.ec3_2(self.ec3(self.maxpool(e2)))
        d3 = self.up3(self.dc3_2(self.dc3(e3)))
        if d3.size()[2:] != e2.size()[2:]:
            d3 = F.interpolate(d3,
                               e2.size()[2:],
                               mode='trilinear',
                               align_corners=False)
        d3 = torch.cat((d3, e2), 1)
        d2 = self.up2(self.dc2_2(self.dc2(d3)))
        if d2.size()[2:] != e1.size()[2:]:
            d2 = F.interpolate(d2,
                               e1.size()[2:],
                               mode='trilinear',
                               align_corners=False)
        d2 = torch.cat((d2, e1), 1)
        d1 = self.up1(self.dc1_2(self.dc1(d2)))
        if d1.size()[2:] != e0.size()[2:]:
            d1 = F.interpolate(d1,
                               e0.size()[2:],
                               mode='trilinear',
                               align_corners=False)
        d1 = torch.cat((d1, e0), 1)
        d0 = self.dc0b(self.dc0a(d1))
        return d0