Diff of /networks/vnet_sdf.py [000000] .. [903821]

--- a
+++ b/networks/vnet_sdf.py
@@ -0,0 +1,266 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+"""
+Differences with V-Net
+Adding nn.Tanh in the end of the conv. to make the outputs in [-1, 1].
+"""
+
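+# A stack of n_stages (Conv3d 3x3x3 -> optional norm -> ReLU) layers.
+# Note: GroupNorm uses num_groups=16, so channel counts must be divisible by 16.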
+class ConvBlock(nn.Module):
+    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
+        super(ConvBlock, self).__init__()
+
+        ops = []
+        for i in range(n_stages):
+            if i == 0:
+                input_channel = n_filters_in
+            else:
+                input_channel = n_filters_out
+
+            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
+            if normalization == 'batchnorm':
+                ops.append(nn.BatchNorm3d(n_filters_out))
+            elif normalization == 'groupnorm':
+                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
+            elif normalization == 'instancenorm':
+                ops.append(nn.InstanceNorm3d(n_filters_out))
+            elif normalization != 'none':
+                assert False, f'Unsupported normalization: {normalization}'
+            ops.append(nn.ReLU(inplace=True))
+
+        self.conv = nn.Sequential(*ops)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
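+# Same stack as ConvBlock, but with an identity shortcut added before the final ReLU.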
+class ResidualConvBlock(nn.Module):
+    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
+        super(ResidualConvBlock, self).__init__()
+
+        ops = []
+        for i in range(n_stages):
+            if i == 0:
+                input_channel = n_filters_in
+            else:
+                input_channel = n_filters_out
+
+            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
+            if normalization == 'batchnorm':
+                ops.append(nn.BatchNorm3d(n_filters_out))
+            elif normalization == 'groupnorm':
+                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
+            elif normalization == 'instancenorm':
+                ops.append(nn.InstanceNorm3d(n_filters_out))
+            elif normalization != 'none':
+                assert False, f'Unsupported normalization: {normalization}'
+
+            if i != n_stages-1:
+                ops.append(nn.ReLU(inplace=True))
+
+        self.conv = nn.Sequential(*ops)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        # Identity shortcut: assumes n_filters_in == n_filters_out so the shapes match.
+        x = self.conv(x) + x
+        x = self.relu(x)
+        return x
+
+
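+# Downsampling via a strided conv (kernel_size == stride, so spatial dims shrink by `stride`).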
+class DownsamplingConvBlock(nn.Module):
+    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
+        super(DownsamplingConvBlock, self).__init__()
+
+        ops = []
+        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=stride, stride=stride, padding=0))
+        if normalization == 'batchnorm':
+            ops.append(nn.BatchNorm3d(n_filters_out))
+        elif normalization == 'groupnorm':
+            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
+        elif normalization == 'instancenorm':
+            ops.append(nn.InstanceNorm3d(n_filters_out))
+        elif normalization != 'none':
+            assert False, f'Unsupported normalization: {normalization}'
+
+        ops.append(nn.ReLU(inplace=True))
+
+        self.conv = nn.Sequential(*ops)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
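+# Upsampling via a transposed conv (kernel_size == stride, so spatial dims grow by `stride`).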
+class UpsamplingDeconvBlock(nn.Module):
+    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
+        super(UpsamplingDeconvBlock, self).__init__()
+
+        ops = []
+        ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, kernel_size=stride, stride=stride, padding=0))
+        if normalization == 'batchnorm':
+            ops.append(nn.BatchNorm3d(n_filters_out))
+        elif normalization == 'groupnorm':
+            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
+        elif normalization == 'instancenorm':
+            ops.append(nn.InstanceNorm3d(n_filters_out))
+        elif normalization != 'none':
+            assert False, f'Unsupported normalization: {normalization}'
+
+        ops.append(nn.ReLU(inplace=True))
+
+        self.conv = nn.Sequential(*ops)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
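+# Alternative upsampling path (trilinear interpolation + 3x3x3 conv);
+# defined here but not used by the VNet below, which uses UpsamplingDeconvBlock.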
+class Upsampling(nn.Module):
+    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
+        super(Upsampling, self).__init__()
+
+        ops = []
+        ops.append(nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False))
+        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
+        if normalization == 'batchnorm':
+            ops.append(nn.BatchNorm3d(n_filters_out))
+        elif normalization == 'groupnorm':
+            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
+        elif normalization == 'instancenorm':
+            ops.append(nn.InstanceNorm3d(n_filters_out))
+        elif normalization != 'none':
+            assert False, f'Unsupported normalization: {normalization}'
+        ops.append(nn.ReLU(inplace=True))
+
+        self.conv = nn.Sequential(*ops)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
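+# V-Net encoder-decoder with additive skip connections and two 1x1x1 output heads:
+# out_conv + tanh regresses the signed distance field, out_conv2 produces
+# segmentation logits.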
+class VNet(nn.Module):
+    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False):
+        super(VNet, self).__init__()
+        self.has_dropout = has_dropout
+        convBlock = ConvBlock if not has_residual else ResidualConvBlock
+
+        self.block_one = convBlock(1, n_channels, n_filters, normalization=normalization)
+        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
+
+        self.block_two = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
+        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
+
+        self.block_three = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
+        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
+
+        self.block_four = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
+        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
+
+        self.block_five = convBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
+        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
+
+        self.block_six = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
+        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
+
+        self.block_seven = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
+        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
+
+        self.block_eight = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
+        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
+
+        self.block_nine = convBlock(1, n_filters, n_filters, normalization=normalization)
+        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
+        self.out_conv2 = nn.Conv3d(n_filters, n_classes, 1, padding=0)
+        self.tanh = nn.Tanh()
+
+        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
+        # self.__init_weight()
+
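+    # Encoder: five conv stages with four downsamplings (each halves the spatial dims,
+    # so input sizes should be divisible by 16); returns all stage outputs for the skips.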
+    def encoder(self, input):
+        x1 = self.block_one(input)
+        x1_dw = self.block_one_dw(x1)
+
+        x2 = self.block_two(x1_dw)
+        x2_dw = self.block_two_dw(x2)
+
+        x3 = self.block_three(x2_dw)
+        x3_dw = self.block_three_dw(x3)
+
+        x4 = self.block_four(x3_dw)
+        x4_dw = self.block_four_dw(x4)
+
+        x5 = self.block_five(x4_dw)
+        # x5 = F.dropout3d(x5, p=0.5, training=True)
+        if self.has_dropout:
+            x5 = self.dropout(x5)
+
+        res = [x1, x2, x3, x4, x5]
+
+        return res
+
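+    # Decoder: mirrors the encoder; skips are merged by element-wise addition
+    # rather than channel concatenation.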
+    def decoder(self, features):
+        x1 = features[0]
+        x2 = features[1]
+        x3 = features[2]
+        x4 = features[3]
+        x5 = features[4]
+
+        x5_up = self.block_five_up(x5)
+        x5_up = x5_up + x4
+
+        x6 = self.block_six(x5_up)
+        x6_up = self.block_six_up(x6)
+        x6_up = x6_up + x3
+
+        x7 = self.block_seven(x6_up)
+        x7_up = self.block_seven_up(x7)
+        x7_up = x7_up + x2
+
+        x8 = self.block_eight(x7_up)
+        x8_up = self.block_eight_up(x8)
+        x8_up = x8_up + x1
+        x9 = self.block_nine(x8_up)
+        # x9 = F.dropout3d(x9, p=0.5, training=True)
+        if self.has_dropout:
+            x9 = self.dropout(x9)
+        out = self.out_conv(x9)
+        out_tanh = self.tanh(out)     # SDF regression head, squashed to [-1, 1]
+        out_seg = self.out_conv2(x9)  # segmentation logits head
+        return out_tanh, out_seg
+
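+    # turnoff_drop temporarily disables dropout for this forward pass only
+    # (e.g. for clean inference while keeping has_dropout set for training).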
+    def forward(self, input, turnoff_drop=False):
+        if turnoff_drop:
+            has_dropout = self.has_dropout
+            self.has_dropout = False
+        features = self.encoder(input)
+        out_tanh, out_seg = self.decoder(features)
+        if turnoff_drop:
+            self.has_dropout = has_dropout
+        return out_tanh, out_seg
+
+    # def __init_weight(self):
+    #     for m in self.modules():
+    #         if isinstance(m, nn.Conv3d):
+    #             torch.nn.init.kaiming_normal_(m.weight)
+    #         elif isinstance(m, nn.BatchNorm3d):
+    #             m.weight.data.fill_(1)
+
+if __name__ == '__main__':
+    # Minimal smoke test (a sketch; spatial sizes chosen to be divisible by 16).
+    model = VNet(n_channels=1, n_classes=2, normalization='batchnorm')
+    x = torch.randn(1, 1, 112, 112, 80)
+    out_tanh, out_seg = model(x)
+    print(out_tanh.shape, out_seg.shape)  # both: torch.Size([1, 2, 112, 112, 80])
+
+    # Optional: compute FLOPs & parameters with thop, if installed.
+    # from thop import profile, clever_format
+    # flops, params = profile(model, inputs=(x,))
+    # macs, params = clever_format([flops, params], "%.3f")
+    # print(macs, params)