# -*- coding: utf-8 -*-
# Author: Guotai Wang
# Date: 12 June, 2020
# Implementation of COPLENet for COVID-19 pneumonia lesion segmentation from CT images.
# Reference:
# G. Wang et al. A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions
# from CT Images. IEEE Transactions on Medical Imaging, 2020. DOI:10.1109/TMI.2020.3000314.

from __future__ import print_function, division
import torch
import torch.nn as nn

class ConvLayer(nn.Module):
    """A convolution (1x1 by default) followed by batch norm and leaky ReLU."""
    def __init__(self, in_channels, out_channels, kernel_size=1):
        super(ConvLayer, self).__init__()
        padding = int((kernel_size - 1) / 2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.conv(x)

class SEBlock(nn.Module):
    """Squeeze-and-Excitation block, applied with a residual connection.
    Note that the excitation ends with ReLU rather than the sigmoid of the
    original SE-Net, so the channel weights are unbounded above."""
    def __init__(self, in_channels, r):
        super(SEBlock, self).__init__()
        redu_chns = int(in_channels / r)
        self.se_layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0),
            nn.LeakyReLU(),
            nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0),
            nn.ReLU())

    def forward(self, x):
        f = self.se_layers(x)
        return f * x + x

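# Usage sketch: SEBlock preserves the input shape and only rescales channels, e.g.
#   se = SEBlock(in_channels=64, r=2)
#   y = se(torch.randn(1, 64, 32, 32))    # y.shape == (1, 64, 32, 32)
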
class ASPPBlock(nn.Module):
    """Atrous Spatial Pyramid Pooling: four parallel dilated convolutions whose
    outputs are concatenated and fused by a 1x1 convolution."""
    def __init__(self, in_channels, out_channels_list, kernel_size_list, dilation_list):
        super(ASPPBlock, self).__init__()
        self.conv_num = len(out_channels_list)
        assert(self.conv_num == 4)
        assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list))
        # Padding that keeps the spatial size unchanged for each dilated convolution.
        pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0])
        pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1])
        pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2])
        pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3])
        self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size=kernel_size_list[0],
                                dilation=dilation_list[0], padding=pad0)
        self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size=kernel_size_list[1],
                                dilation=dilation_list[1], padding=pad1)
        self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size=kernel_size_list[2],
                                dilation=dilation_list[2], padding=pad2)
        self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size=kernel_size_list[3],
                                dilation=dilation_list[3], padding=pad3)

        out_channels = sum(out_channels_list)
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU())

    def forward(self, x):
        x1 = self.conv_1(x)
        x2 = self.conv_2(x)
        x3 = self.conv_3(x)
        x4 = self.conv_4(x)

        y = torch.cat([x1, x2, x3, x4], dim=1)
        y = self.conv_1x1(y)
        return y

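# Usage sketch: with pad = (k - 1) / 2 * dilation, each branch preserves the
# spatial size, so ASPPBlock is shape-preserving, e.g.
#   aspp = ASPPBlock(128, [32, 32, 32, 32], [1, 3, 3, 3], [1, 2, 4, 6])
#   y = aspp(torch.randn(1, 128, 48, 48))    # y.shape == (1, 128, 48, 48)
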
class ConvBNActBlock(nn.Module):
    """Two convolution layers with batch norm, leaky ReLU, dropout and an SE block."""
    def __init__(self, in_channels, out_channels, dropout_p):
        super(ConvBNActBlock, self).__init__()
        self.conv_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Dropout(dropout_p),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            SEBlock(out_channels, 2)
        )

    def forward(self, x):
        return self.conv_conv(x)

class DownBlock(nn.Module):
    """Downsampling by a concatenation of max-pool and avg-pool, followed by
    ConvBNActBlock with a residual connection. The residual addition requires
    out_channels == 2 * in_channels."""
    def __init__(self, in_channels, out_channels, dropout_p):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.avgpool = nn.AvgPool2d(2)
        self.conv = ConvBNActBlock(2 * in_channels, out_channels, dropout_p)

    def forward(self, x):
        x_max = self.maxpool(x)
        x_avg = self.avgpool(x)
        x_cat = torch.cat([x_max, x_avg], dim=1)
        y = self.conv(x_cat)
        return y + x_cat

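# Usage sketch: pooling halves the spatial size while concatenation doubles the
# channels, e.g.
#   down = DownBlock(32, 64, dropout_p=0.0)
#   y = down(torch.randn(1, 32, 64, 64))    # y.shape == (1, 64, 32, 32)
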
class UpBlock(nn.Module):
    """Upsampling followed by ConvBNActBlock with a residual connection.
    The residual addition requires out_channels == 2 * in_channels2."""
    def __init__(self, in_channels1, in_channels2, out_channels,
                 bilinear=True, dropout_p=0.5):
        super(UpBlock, self).__init__()
        self.bilinear = bilinear
        if bilinear:
            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2)
        self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        if self.bilinear:
            x1 = self.conv1x1(x1)
        x1 = self.up(x1)
        x_cat = torch.cat([x2, x1], dim=1)
        y = self.conv(x_cat)
        return y + x_cat

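# Usage sketch: x1 is the lower-resolution decoder feature, x2 the bridged skip
# connection at the target resolution, e.g.
#   up = UpBlock(in_channels1=512, in_channels2=128, out_channels=256)
#   y = up(torch.randn(1, 512, 8, 8), torch.randn(1, 128, 16, 16))    # (1, 256, 16, 16)
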
class COPLENet(nn.Module):
    """COPLE-Net: a U-Net-like encoder-decoder with channel-halving bridge
    layers on the skip connections, an ASPP bottleneck, and residual
    down/up-sampling blocks."""
    def __init__(self, params):
        super(COPLENet, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        self.dropout = self.params['dropout']
        assert(len(self.ft_chns) == 5)

        f0_half = int(self.ft_chns[0] / 2)
        f1_half = int(self.ft_chns[1] / 2)
        f2_half = int(self.ft_chns[2] / 2)
        f3_half = int(self.ft_chns[3] / 2)
        self.in_conv = ConvBNActBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
        self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
        self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
        self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
        self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])

        # Bridge layers halve the channels of each skip connection.
        self.bridge0 = ConvLayer(self.ft_chns[0], f0_half)
        self.bridge1 = ConvLayer(self.ft_chns[1], f1_half)
        self.bridge2 = ConvLayer(self.ft_chns[2], f2_half)
        self.bridge3 = ConvLayer(self.ft_chns[3], f3_half)

        # Pass self.bilinear through; it was stored but previously unused.
        self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], self.bilinear, dropout_p=self.dropout[3])
        self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], self.bilinear, dropout_p=self.dropout[2])
        self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], self.bilinear, dropout_p=self.dropout[1])
        self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], self.bilinear, dropout_p=self.dropout[0])

        f4 = self.ft_chns[4]
        aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)]
        aspp_knls = [1, 3, 3, 3]
        aspp_dila = [1, 2, 4, 6]
        self.aspp = ASPPBlock(f4, aspp_chns, aspp_knls, aspp_dila)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
                                  kernel_size=3, padding=1)

    def forward(self, x):
        x_shape = list(x.shape)
        if len(x_shape) == 5:
            # A 5D input [N, C, D, H, W] is processed slice by slice:
            # fold the depth dimension into the batch dimension.
            [N, C, D, H, W] = x_shape
            new_shape = [N * D, C, H, W]
            x = torch.transpose(x, 1, 2)
            x = torch.reshape(x, new_shape)
        x0 = self.in_conv(x)
        x0b = self.bridge0(x0)
        x1 = self.down1(x0)
        x1b = self.bridge1(x1)
        x2 = self.down2(x1)
        x2b = self.bridge2(x2)
        x3 = self.down3(x2)
        x3b = self.bridge3(x3)
        x4 = self.down4(x3)
        x4 = self.aspp(x4)

        x = self.up1(x4, x3b)
        x = self.up2(x, x2b)
        x = self.up3(x, x1b)
        x = self.up4(x, x0b)
        output = self.out_conv(x)

        if len(x_shape) == 5:
            # Restore the depth dimension of the output.
            new_shape = [N, D] + list(output.shape)[1:]
            output = torch.reshape(output, new_shape)
            output = torch.transpose(output, 1, 2)
        return output
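
# A minimal smoke test, kept under __main__ so importing this module has no
# side effects. The hyper-parameter values below are illustrative assumptions,
# not settings prescribed by this file.
if __name__ == "__main__":
    params = {'in_chns': 1,
              'feature_chns': [32, 64, 128, 256, 512],  # assumed example widths
              'dropout': [0.0, 0.0, 0.3, 0.4, 0.5],     # assumed example rates
              'class_num': 2,
              'bilinear': True}
    net = COPLENet(params)
    x2d = torch.randn(2, 1, 64, 64)       # a batch of 2D slices: [N, C, H, W]
    print(net(x2d).shape)                 # expected: torch.Size([2, 2, 64, 64])
    x3d = torch.randn(1, 1, 4, 64, 64)    # a 3D volume: [N, C, D, H, W]
    print(net(x3d).shape)                 # expected: torch.Size([1, 2, 4, 64, 64])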