--- a +++ b/coplenet.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +# Author: Guotai Wang +# Date: 12 June, 2020 +# Implementation of of COPLENet for COVID-19 pneumonia lesion segmentation from CT images. +# Reference: +# G. Wang et al. A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions +# from CT Images. IEEE Transactions on Medical Imaging, 2020. DOI:10.1109/TMI.2020.3000314. + +from __future__ import print_function, division +import torch +import torch.nn as nn + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size = 1): + super(ConvLayer, self).__init__() + padding = int((kernel_size - 1) / 2) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU() + ) + + def forward(self, x): + return self.conv(x) + +class SEBlock(nn.Module): + def __init__(self, in_channels, r): + super(SEBlock, self).__init__() + + redu_chns = int(in_channels / r) + self.se_layers = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0), + nn.LeakyReLU(), + nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0), + nn.ReLU()) + + def forward(self, x): + f = self.se_layers(x) + return f*x + x + +class ASPPBlock(nn.Module): + def __init__(self,in_channels, out_channels_list, kernel_size_list, dilation_list): + super(ASPPBlock, self).__init__() + self.conv_num = len(out_channels_list) + assert(self.conv_num == 4) + assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list)) + pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0]) + pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1]) + pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2]) + pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3]) + self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size = kernel_size_list[0], + dilation = dilation_list[0], padding = pad0 ) + self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size = kernel_size_list[1], + dilation = dilation_list[1], padding = pad1 ) + self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size = kernel_size_list[2], + dilation = dilation_list[2], padding = pad2 ) + self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size = kernel_size_list[3], + dilation = dilation_list[3], padding = pad3 ) + + out_channels = out_channels_list[0] + out_channels_list[1] + out_channels_list[2] + out_channels_list[3] + self.conv_1x1 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU()) + + def forward(self, x): + x1 = self.conv_1(x) + x2 = self.conv_2(x) + x3 = self.conv_3(x) + x4 = self.conv_4(x) + + y = torch.cat([x1, x2, x3, x4], dim=1) + y = self.conv_1x1(y) + return y + +class ConvBNActBlock(nn.Module): + """Two convolution layers with batch norm, leaky relu, dropout and SE block""" + def __init__(self,in_channels, out_channels, dropout_p): + super(ConvBNActBlock, self).__init__() + self.conv_conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU(), + SEBlock(out_channels, 2) + ) + + def forward(self, x): + return self.conv_conv(x) + +class DownBlock(nn.Module): + """Downsampling by a concantenation of max-pool and avg-pool, followed by ConvBNActBlock + """ + def __init__(self, in_channels, out_channels, dropout_p): + super(DownBlock, self).__init__() + self.maxpool = nn.MaxPool2d(2) + self.avgpool = nn.AvgPool2d(2) + self.conv = ConvBNActBlock(2 * in_channels, out_channels, dropout_p) + + def forward(self, x): + x_max = self.maxpool(x) + x_avg = self.avgpool(x) + x_cat = torch.cat([x_max, x_avg], dim=1) + y = self.conv(x_cat) + return y + x_cat + +class UpBlock(nn.Module): + """Upssampling followed by ConvBNActBlock""" + def __init__(self, in_channels1, in_channels2, out_channels, + bilinear=True, dropout_p = 0.5): + super(UpBlock, self).__init__() + self.bilinear = bilinear + if bilinear: + self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + else: + self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p) + + def forward(self, x1, x2): + if self.bilinear: + x1 = self.conv1x1(x1) + x1 = self.up(x1) + x_cat = torch.cat([x2, x1], dim=1) + y = self.conv(x_cat) + return y + x_cat + +class COPLENet(nn.Module): + def __init__(self, params): + super(COPLENet, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.n_class = self.params['class_num'] + self.bilinear = self.params['bilinear'] + self.dropout = self.params['dropout'] + assert(len(self.ft_chns) == 5) + + f0_half = int(self.ft_chns[0] / 2) + f1_half = int(self.ft_chns[1] / 2) + f2_half = int(self.ft_chns[2] / 2) + f3_half = int(self.ft_chns[3] / 2) + self.in_conv= ConvBNActBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + + self.bridge0= ConvLayer(self.ft_chns[0], f0_half) + self.bridge1= ConvLayer(self.ft_chns[1], f1_half) + self.bridge2= ConvLayer(self.ft_chns[2], f2_half) + self.bridge3= ConvLayer(self.ft_chns[3], f3_half) + + self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3]) + self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2]) + self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1]) + self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0]) + + f4 = self.ft_chns[4] + aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)] + aspp_knls = [1, 3, 3, 3] + aspp_dila = [1, 2, 4, 6] + self.aspp = ASPPBlock(f4, aspp_chns, aspp_knls, aspp_dila) + + + self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, + kernel_size = 3, padding = 1) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + x0 = self.in_conv(x) + x0b = self.bridge0(x0) + x1 = self.down1(x0) + x1b = self.bridge1(x1) + x2 = self.down2(x1) + x2b = self.bridge2(x2) + x3 = self.down3(x2) + x3b = self.bridge3(x3) + x4 = self.down4(x3) + x4 = self.aspp(x4) + + x = self.up1(x4, x3b) + x = self.up2(x, x2b) + x = self.up3(x, x1b) + x = self.up4(x, x0b) + output = self.out_conv(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(output.shape)[1:] + output = torch.reshape(output, new_shape) + output = torch.transpose(output, 1, 2) + return output