Diff of /coplenet.py [000000] .. [5b2be7]

Switch to side-by-side view

--- a
+++ b/coplenet.py
@@ -0,0 +1,201 @@
+# -*- coding: utf-8 -*-
+# Author: Guotai Wang
+# Date:   12 June, 2020
+# Implementation of of COPLENet for COVID-19 pneumonia lesion segmentation from CT images.
+# Reference: 
+#     G. Wang et al. A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions 
+#     from CT Images. IEEE Transactions on Medical Imaging, 2020. DOI:10.1109/TMI.2020.3000314.
+
+from __future__ import print_function, division
+import torch
+import torch.nn as nn
+
+class ConvLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size = 1):
+        super(ConvLayer, self).__init__()
+        padding = int((kernel_size - 1) / 2)
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU()
+        )
+       
+    def forward(self, x):
+        return self.conv(x)
+
+class SEBlock(nn.Module):
+    def __init__(self, in_channels, r):
+        super(SEBlock, self).__init__()
+
+        redu_chns = int(in_channels / r)
+        self.se_layers = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0),
+            nn.LeakyReLU(),
+            nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0),
+            nn.ReLU())
+        
+    def forward(self, x):
+        f = self.se_layers(x)
+        return f*x + x
+
+class ASPPBlock(nn.Module):
+    def __init__(self,in_channels, out_channels_list, kernel_size_list, dilation_list):
+        super(ASPPBlock, self).__init__()
+        self.conv_num = len(out_channels_list)
+        assert(self.conv_num == 4)
+        assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list))
+        pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0])
+        pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1])
+        pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2])
+        pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3])
+        self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size = kernel_size_list[0], 
+                    dilation = dilation_list[0], padding = pad0 )
+        self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size = kernel_size_list[1], 
+                    dilation = dilation_list[1], padding = pad1 )
+        self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size = kernel_size_list[2], 
+                    dilation = dilation_list[2], padding = pad2 )
+        self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size = kernel_size_list[3], 
+                    dilation = dilation_list[3], padding = pad3 )
+
+        out_channels  = out_channels_list[0] + out_channels_list[1] + out_channels_list[2] + out_channels_list[3] 
+        self.conv_1x1 = nn.Sequential(
+            nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU())
+       
+    def forward(self, x):
+        x1 = self.conv_1(x)
+        x2 = self.conv_2(x)
+        x3 = self.conv_3(x)
+        x4 = self.conv_4(x)
+
+        y  = torch.cat([x1, x2, x3, x4], dim=1)
+        y  = self.conv_1x1(y)
+        return y
+
+class ConvBNActBlock(nn.Module):
+    """Two convolution layers with batch norm, leaky relu, dropout and SE block"""
+    def __init__(self,in_channels, out_channels, dropout_p):
+        super(ConvBNActBlock, self).__init__()
+        self.conv_conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(),
+            nn.Dropout(dropout_p),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(),
+            SEBlock(out_channels, 2)
+        )
+       
+    def forward(self, x):
+        return self.conv_conv(x)
+
+class DownBlock(nn.Module):
+    """Downsampling by a concantenation of max-pool and avg-pool, followed by ConvBNActBlock
+    """
+    def __init__(self, in_channels, out_channels, dropout_p):
+        super(DownBlock, self).__init__()
+        self.maxpool = nn.MaxPool2d(2)
+        self.avgpool = nn.AvgPool2d(2)
+        self.conv    = ConvBNActBlock(2 * in_channels, out_channels, dropout_p)
+        
+    def forward(self, x):
+        x_max = self.maxpool(x)
+        x_avg = self.avgpool(x)
+        x_cat = torch.cat([x_max, x_avg], dim=1)
+        y     = self.conv(x_cat)
+        return y + x_cat
+
+class UpBlock(nn.Module):
+    """Upssampling followed by ConvBNActBlock"""
+    def __init__(self, in_channels1, in_channels2, out_channels, 
+                 bilinear=True, dropout_p = 0.5):
+        super(UpBlock, self).__init__()
+        self.bilinear = bilinear
+        if bilinear:
+            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1)
+            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+        else:
+            self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2)
+        self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p)
+
+    def forward(self, x1, x2):
+        if self.bilinear:
+            x1 = self.conv1x1(x1)
+        x1    = self.up(x1)
+        x_cat = torch.cat([x2, x1], dim=1)
+        y     = self.conv(x_cat)
+        return y + x_cat
+
+class COPLENet(nn.Module):
+    def __init__(self, params):
+        super(COPLENet, self).__init__()
+        self.params    = params
+        self.in_chns   = self.params['in_chns']
+        self.ft_chns   = self.params['feature_chns']
+        self.n_class   = self.params['class_num']
+        self.bilinear  = self.params['bilinear']
+        self.dropout   = self.params['dropout']
+        assert(len(self.ft_chns) == 5)
+
+        f0_half = int(self.ft_chns[0] / 2)
+        f1_half = int(self.ft_chns[1] / 2)
+        f2_half = int(self.ft_chns[2] / 2)
+        f3_half = int(self.ft_chns[3] / 2)
+        self.in_conv= ConvBNActBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
+        self.down1  = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
+        self.down2  = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
+        self.down3  = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
+        self.down4  = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
+        
+        self.bridge0= ConvLayer(self.ft_chns[0], f0_half)
+        self.bridge1= ConvLayer(self.ft_chns[1], f1_half)
+        self.bridge2= ConvLayer(self.ft_chns[2], f2_half)
+        self.bridge3= ConvLayer(self.ft_chns[3], f3_half)
+
+        self.up1    = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3])
+        self.up2    = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2])
+        self.up3    = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1])
+        self.up4    = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0])
+
+        f4 = self.ft_chns[4]
+        aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)]
+        aspp_knls = [1, 3, 3, 3]
+        aspp_dila = [1, 2, 4, 6]
+        self.aspp = ASPPBlock(f4, aspp_chns, aspp_knls, aspp_dila)
+        
+            
+        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,  
+            kernel_size = 3, padding = 1)
+
+    def forward(self, x):
+        x_shape = list(x.shape)
+        if(len(x_shape) == 5):
+          [N, C, D, H, W] = x_shape
+          new_shape = [N*D, C, H, W]
+          x = torch.transpose(x, 1, 2)
+          x = torch.reshape(x, new_shape)
+        x0  = self.in_conv(x)
+        x0b = self.bridge0(x0)
+        x1  = self.down1(x0)
+        x1b = self.bridge1(x1)
+        x2  = self.down2(x1)
+        x2b = self.bridge2(x2)
+        x3  = self.down3(x2)
+        x3b = self.bridge3(x3)
+        x4  = self.down4(x3)
+        x4  = self.aspp(x4) 
+
+        x   = self.up1(x4, x3b)
+        x   = self.up2(x, x2b)
+        x   = self.up3(x, x1b)
+        x   = self.up4(x, x0b)
+        output = self.out_conv(x)
+
+        if(len(x_shape) == 5):
+            new_shape = [N, D] + list(output.shape)[1:]
+            output    = torch.reshape(output, new_shape)
+            output    = torch.transpose(output, 1, 2)
+        return output