Diff of /coplenet.py [000000] .. [5b2be7]

Switch to unified view

a b/coplenet.py
1
# -*- coding: utf-8 -*-
2
# Author: Guotai Wang
3
# Date:   12 June, 2020
4
# Implementation of of COPLENet for COVID-19 pneumonia lesion segmentation from CT images.
5
# Reference: 
6
#     G. Wang et al. A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions 
7
#     from CT Images. IEEE Transactions on Medical Imaging, 2020. DOI:10.1109/TMI.2020.3000314.
8
9
from __future__ import print_function, division
10
import torch
11
import torch.nn as nn
12
13
class ConvLayer(nn.Module):
14
    def __init__(self, in_channels, out_channels, kernel_size = 1):
15
        super(ConvLayer, self).__init__()
16
        padding = int((kernel_size - 1) / 2)
17
        self.conv = nn.Sequential(
18
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
19
            nn.BatchNorm2d(out_channels),
20
            nn.LeakyReLU()
21
        )
22
       
23
    def forward(self, x):
24
        return self.conv(x)
25
26
class SEBlock(nn.Module):
27
    def __init__(self, in_channels, r):
28
        super(SEBlock, self).__init__()
29
30
        redu_chns = int(in_channels / r)
31
        self.se_layers = nn.Sequential(
32
            nn.AdaptiveAvgPool2d(1),
33
            nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0),
34
            nn.LeakyReLU(),
35
            nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0),
36
            nn.ReLU())
37
        
38
    def forward(self, x):
39
        f = self.se_layers(x)
40
        return f*x + x
41
42
class ASPPBlock(nn.Module):
43
    def __init__(self,in_channels, out_channels_list, kernel_size_list, dilation_list):
44
        super(ASPPBlock, self).__init__()
45
        self.conv_num = len(out_channels_list)
46
        assert(self.conv_num == 4)
47
        assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list))
48
        pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0])
49
        pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1])
50
        pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2])
51
        pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3])
52
        self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size = kernel_size_list[0], 
53
                    dilation = dilation_list[0], padding = pad0 )
54
        self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size = kernel_size_list[1], 
55
                    dilation = dilation_list[1], padding = pad1 )
56
        self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size = kernel_size_list[2], 
57
                    dilation = dilation_list[2], padding = pad2 )
58
        self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size = kernel_size_list[3], 
59
                    dilation = dilation_list[3], padding = pad3 )
60
61
        out_channels  = out_channels_list[0] + out_channels_list[1] + out_channels_list[2] + out_channels_list[3] 
62
        self.conv_1x1 = nn.Sequential(
63
            nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0),
64
            nn.BatchNorm2d(out_channels),
65
            nn.LeakyReLU())
66
       
67
    def forward(self, x):
68
        x1 = self.conv_1(x)
69
        x2 = self.conv_2(x)
70
        x3 = self.conv_3(x)
71
        x4 = self.conv_4(x)
72
73
        y  = torch.cat([x1, x2, x3, x4], dim=1)
74
        y  = self.conv_1x1(y)
75
        return y
76
77
class ConvBNActBlock(nn.Module):
78
    """Two convolution layers with batch norm, leaky relu, dropout and SE block"""
79
    def __init__(self,in_channels, out_channels, dropout_p):
80
        super(ConvBNActBlock, self).__init__()
81
        self.conv_conv = nn.Sequential(
82
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
83
            nn.BatchNorm2d(out_channels),
84
            nn.LeakyReLU(),
85
            nn.Dropout(dropout_p),
86
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
87
            nn.BatchNorm2d(out_channels),
88
            nn.LeakyReLU(),
89
            SEBlock(out_channels, 2)
90
        )
91
       
92
    def forward(self, x):
93
        return self.conv_conv(x)
94
95
class DownBlock(nn.Module):
96
    """Downsampling by a concantenation of max-pool and avg-pool, followed by ConvBNActBlock
97
    """
98
    def __init__(self, in_channels, out_channels, dropout_p):
99
        super(DownBlock, self).__init__()
100
        self.maxpool = nn.MaxPool2d(2)
101
        self.avgpool = nn.AvgPool2d(2)
102
        self.conv    = ConvBNActBlock(2 * in_channels, out_channels, dropout_p)
103
        
104
    def forward(self, x):
105
        x_max = self.maxpool(x)
106
        x_avg = self.avgpool(x)
107
        x_cat = torch.cat([x_max, x_avg], dim=1)
108
        y     = self.conv(x_cat)
109
        return y + x_cat
110
111
class UpBlock(nn.Module):
112
    """Upssampling followed by ConvBNActBlock"""
113
    def __init__(self, in_channels1, in_channels2, out_channels, 
114
                 bilinear=True, dropout_p = 0.5):
115
        super(UpBlock, self).__init__()
116
        self.bilinear = bilinear
117
        if bilinear:
118
            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1)
119
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
120
        else:
121
            self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2)
122
        self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p)
123
124
    def forward(self, x1, x2):
125
        if self.bilinear:
126
            x1 = self.conv1x1(x1)
127
        x1    = self.up(x1)
128
        x_cat = torch.cat([x2, x1], dim=1)
129
        y     = self.conv(x_cat)
130
        return y + x_cat
131
132
class COPLENet(nn.Module):
133
    def __init__(self, params):
134
        super(COPLENet, self).__init__()
135
        self.params    = params
136
        self.in_chns   = self.params['in_chns']
137
        self.ft_chns   = self.params['feature_chns']
138
        self.n_class   = self.params['class_num']
139
        self.bilinear  = self.params['bilinear']
140
        self.dropout   = self.params['dropout']
141
        assert(len(self.ft_chns) == 5)
142
143
        f0_half = int(self.ft_chns[0] / 2)
144
        f1_half = int(self.ft_chns[1] / 2)
145
        f2_half = int(self.ft_chns[2] / 2)
146
        f3_half = int(self.ft_chns[3] / 2)
147
        self.in_conv= ConvBNActBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
148
        self.down1  = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
149
        self.down2  = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
150
        self.down3  = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
151
        self.down4  = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
152
        
153
        self.bridge0= ConvLayer(self.ft_chns[0], f0_half)
154
        self.bridge1= ConvLayer(self.ft_chns[1], f1_half)
155
        self.bridge2= ConvLayer(self.ft_chns[2], f2_half)
156
        self.bridge3= ConvLayer(self.ft_chns[3], f3_half)
157
158
        self.up1    = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3])
159
        self.up2    = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2])
160
        self.up3    = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1])
161
        self.up4    = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0])
162
163
        f4 = self.ft_chns[4]
164
        aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)]
165
        aspp_knls = [1, 3, 3, 3]
166
        aspp_dila = [1, 2, 4, 6]
167
        self.aspp = ASPPBlock(f4, aspp_chns, aspp_knls, aspp_dila)
168
        
169
            
170
        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,  
171
            kernel_size = 3, padding = 1)
172
173
    def forward(self, x):
174
        x_shape = list(x.shape)
175
        if(len(x_shape) == 5):
176
          [N, C, D, H, W] = x_shape
177
          new_shape = [N*D, C, H, W]
178
          x = torch.transpose(x, 1, 2)
179
          x = torch.reshape(x, new_shape)
180
        x0  = self.in_conv(x)
181
        x0b = self.bridge0(x0)
182
        x1  = self.down1(x0)
183
        x1b = self.bridge1(x1)
184
        x2  = self.down2(x1)
185
        x2b = self.bridge2(x2)
186
        x3  = self.down3(x2)
187
        x3b = self.bridge3(x3)
188
        x4  = self.down4(x3)
189
        x4  = self.aspp(x4) 
190
191
        x   = self.up1(x4, x3b)
192
        x   = self.up2(x, x2b)
193
        x   = self.up3(x, x1b)
194
        x   = self.up4(x, x0b)
195
        output = self.out_conv(x)
196
197
        if(len(x_shape) == 5):
198
            new_shape = [N, D] + list(output.shape)[1:]
199
            output    = torch.reshape(output, new_shape)
200
            output    = torch.transpose(output, 1, 2)
201
        return output