|
a |
|
b/pathflowai/fast_scnn.py |
|
|
1 |
########################################################################### |
|
|
2 |
# Created by: Tramac |
|
|
3 |
# Date: 2019-03-25 |
|
|
4 |
# Copyright (c) 2017 |
|
|
5 |
########################################################################### |
|
|
6 |
|
|
|
7 |
"""Fast Segmentation Convolutional Neural Network""" |
|
|
8 |
import os |
|
|
9 |
import torch |
|
|
10 |
import torch.nn as nn |
|
|
11 |
import torch.nn.functional as F |
|
|
12 |
|
|
|
13 |
__all__ = ['FastSCNN', 'get_fast_scnn'] |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
class FastSCNN(nn.Module):
    """Fast Segmentation Convolutional Neural Network.

    Chains the learning-to-downsample stem, the global feature extractor,
    the feature fusion module and the classifier, then bilinearly resizes
    the class scores to the spatial size of the input.

    Args:
        num_classes: Number of output channels produced by the classifier
            (and by the auxiliary head when it is enabled).
        aux: When True, an auxiliary head is attached to the output of the
            learning-to-downsample stem and its prediction is returned
            alongside the main one.
        **kwargs: Accepted for call-site compatibility; not used here.
    """

    def __init__(self, num_classes, aux=False, **kwargs):
        super(FastSCNN, self).__init__()
        self.aux = aux
        self.learning_to_downsample = LearningToDownsample(32, 48, 64)
        self.global_feature_extractor = GlobalFeatureExtractor(64, [64, 96, 128], 128, 6, [3, 3, 3])
        self.feature_fusion = FeatureFusionModule(64, 128, 128)
        self.classifier = Classifer(128, num_classes)
        if self.aux:
            # Auxiliary head; it consumes the 64-channel stem output.
            self.auxlayer = nn.Sequential(
                nn.Conv2d(64, 32, 3, padding=1, bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU(True),
                nn.Dropout(0.1),
                nn.Conv2d(32, num_classes, 1)
            )

    def forward(self, x):
        """Compute the segmentation scores for ``x``.

        Args:
            x: Input tensor; its trailing two dimensions are used as the
                target size of the returned score map(s).

        Returns:
            The resized main prediction when ``aux`` is False. When ``aux``
            is True, a tuple ``(main, auxiliary)`` with both predictions
            resized to the input's spatial size.
        """
        size = x.size()[2:]
        higher_res_features = self.learning_to_downsample(x)
        x = self.global_feature_extractor(higher_res_features)
        x = self.feature_fusion(higher_res_features, x)
        x = self.classifier(x)
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        if not self.aux:
            return x
        # The previous code appended to an undefined ``outputs`` list here,
        # which raised NameError whenever aux=True; return both maps instead.
        auxout = self.auxlayer(higher_res_features)
        auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
        return x, auxout
|
|
47 |
|
|
|
48 |
|
|
|
49 |
class _ConvBNReLU(nn.Module):
    """Convolution, batch normalisation and in-place ReLU applied in sequence.

    Args:
        in_channels: Channels of the tensor fed to the convolution.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size (default 3).
        stride: Convolution stride (default 1).
        padding: Convolution padding (default 0).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, **kwargs):
        super(_ConvBNReLU, self).__init__()
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Stored as ``conv`` so the parameter names (conv.0.*, conv.1.*) stay the same.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` after the conv / batch-norm / ReLU stack."""
        activated = self.conv(x)
        return activated
|
|
62 |
|
|
|
63 |
|
|
|
64 |
class _DSConv(nn.Module):
    """Depthwise separable convolution block.

    A 3x3 convolution with ``groups=dw_channels`` is followed by a 1x1
    convolution; each is followed by batch normalisation and in-place ReLU.

    Args:
        dw_channels: Channels of the input (and of the grouped convolution).
        out_channels: Channels produced by the 1x1 convolution.
        stride: Stride of the grouped 3x3 convolution (default 1).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
        super(_DSConv, self).__init__()
        depthwise = nn.Conv2d(
            dw_channels,
            dw_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=dw_channels,
            bias=False,
        )
        pointwise = nn.Conv2d(dw_channels, out_channels, kernel_size=1, bias=False)
        # Layer order fixes the Sequential indices and therefore the parameter names.
        self.conv = nn.Sequential(
            depthwise,
            nn.BatchNorm2d(dw_channels),
            nn.ReLU(inplace=True),
            pointwise,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return ``x`` after the grouped and 1x1 convolution stages."""
        result = self.conv(x)
        return result
|
|
80 |
|
|
|
81 |
|
|
|
82 |
class _DWConv(nn.Module):
    """Grouped 3x3 convolution followed by batch normalisation and in-place ReLU.

    Args:
        dw_channels: Channels of the input; also used as the number of groups.
        out_channels: Channels produced by the convolution.
        stride: Convolution stride (default 1).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
        super(_DWConv, self).__init__()
        grouped_conv = nn.Conv2d(
            dw_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=dw_channels,
            bias=False,
        )
        layers = [grouped_conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` after the grouped convolution, batch norm and ReLU."""
        result = self.conv(x)
        return result
|
|
93 |
|
|
|
94 |
|
|
|
95 |
class LinearBottleneck(nn.Module):
    """LinearBottleneck used in MobileNetV2.

    A 1x1 expansion to ``in_channels * t`` channels, a grouped 3x3
    convolution and a 1x1 projection without activation. The input is added
    back to the result only when ``stride == 1`` and the channel counts match.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        t: Expansion factor applied to ``in_channels`` (default 6).
        stride: Stride of the grouped 3x3 convolution (default 2).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, t=6, stride=2, **kwargs):
        super(LinearBottleneck, self).__init__()
        expanded_channels = in_channels * t
        self.use_shortcut = stride == 1 and in_channels == out_channels
        stages = [
            # pointwise expansion
            _ConvBNReLU(in_channels, expanded_channels, 1),
            # grouped 3x3 convolution
            _DWConv(expanded_channels, expanded_channels, stride),
            # pointwise projection, no activation afterwards
            nn.Conv2d(expanded_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the bottleneck, adding the input back when the shortcut is enabled."""
        transformed = self.block(x)
        if not self.use_shortcut:
            return transformed
        return x + transformed
|
|
116 |
|
|
|
117 |
|
|
|
118 |
class PyramidPooling(nn.Module):
    """Pyramid pooling module.

    The input is average-pooled to 1x1, 2x2, 3x3 and 6x6 grids; each pooled
    map goes through its own 1x1 conv block, is resized back to the input's
    spatial size and concatenated with the input before a final 1x1 conv
    block.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by the final conv block.
        **kwargs: Forwarded to the four per-branch conv blocks.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(PyramidPooling, self).__init__()
        inter_channels = int(in_channels / 4)
        # Registered in the order conv1..conv4 so parameter names/order are unchanged.
        for branch_name in ('conv1', 'conv2', 'conv3', 'conv4'):
            setattr(self, branch_name, _ConvBNReLU(in_channels, inter_channels, 1, **kwargs))
        # Input of the last block is the concatenation of x with the four branches.
        self.out = _ConvBNReLU(in_channels * 2, out_channels, 1)

    def pool(self, x, size):
        """Adaptive-average-pool ``x`` to the output size ``size``."""
        return F.adaptive_avg_pool2d(x, size)

    def upsample(self, x, size):
        """Bilinearly resize ``x`` to ``size`` with ``align_corners=True``."""
        return F.interpolate(x, size, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return the fused multi-scale features for ``x``."""
        target_size = x.size()[2:]
        branches = (
            (self.conv1, 1),
            (self.conv2, 2),
            (self.conv3, 3),
            (self.conv4, 6),
        )
        pieces = [x]
        for conv, bins in branches:
            pooled = self.pool(x, bins)
            pieces.append(self.upsample(conv(pooled), target_size))
        return self.out(torch.cat(pieces, dim=1))
|
|
146 |
|
|
|
147 |
|
|
|
148 |
class LearningToDownsample(nn.Module):
    """Learning to downsample module.

    One stride-2 conv block on a 3-channel input followed by two stride-2
    depthwise separable conv blocks.

    Args:
        dw_channels1: Output channels of the first conv block (default 32).
        dw_channels2: Output channels of the first separable block (default 48).
        out_channels: Output channels of the second separable block (default 64).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dw_channels1=32, dw_channels2=48, out_channels=64, **kwargs):
        super(LearningToDownsample, self).__init__()
        downsample_stride = 2
        # The first block expects a 3-channel input and uses a 3x3 kernel.
        self.conv = _ConvBNReLU(3, dw_channels1, 3, downsample_stride)
        self.dsconv1 = _DSConv(dw_channels1, dw_channels2, downsample_stride)
        self.dsconv2 = _DSConv(dw_channels2, out_channels, downsample_stride)

    def forward(self, x):
        """Run ``x`` through the three downsampling stages in order."""
        for stage in (self.conv, self.dsconv1, self.dsconv2):
            x = stage(x)
        return x
|
|
162 |
|
|
|
163 |
|
|
|
164 |
class GlobalFeatureExtractor(nn.Module):
    """Global feature extractor module.

    Three stacks of bottleneck blocks (the first two with stride 2 in their
    first block, the third with stride 1) followed by pyramid pooling.

    Args:
        in_channels: Channels of the input (default 64).
        block_channels: Output channels of the three stacks.
        out_channels: Channels produced by the pyramid pooling module.
        t: Expansion factor passed to every bottleneck block (default 6).
        num_blocks: Number of blocks in each of the three stacks.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_channels=64, block_channels=(64, 96, 128),
                 out_channels=128, t=6, num_blocks=(3, 3, 3), **kwargs):
        super(GlobalFeatureExtractor, self).__init__()
        width1, width2, width3 = block_channels[0], block_channels[1], block_channels[2]
        depth1, depth2, depth3 = num_blocks[0], num_blocks[1], num_blocks[2]
        self.bottleneck1 = self._make_layer(LinearBottleneck, in_channels, width1, depth1, t, 2)
        self.bottleneck2 = self._make_layer(LinearBottleneck, width1, width2, depth2, t, 2)
        self.bottleneck3 = self._make_layer(LinearBottleneck, width2, width3, depth3, t, 1)
        self.ppm = PyramidPooling(width3, out_channels)

    def _make_layer(self, block, inplanes, planes, blocks, t=6, stride=1):
        """Build a sequential stack of ``block`` instances.

        The first block maps ``inplanes`` to ``planes`` with ``stride``; the
        remaining ``blocks - 1`` blocks keep ``planes`` channels and use
        stride 1.
        """
        head = block(inplanes, planes, t, stride)
        tail = [block(planes, planes, t, 1) for _ in range(1, blocks)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Apply the three bottleneck stacks and then pyramid pooling."""
        for stage in (self.bottleneck1, self.bottleneck2, self.bottleneck3, self.ppm):
            x = stage(x)
        return x
|
|
188 |
|
|
|
189 |
|
|
|
190 |
class FeatureFusionModule(nn.Module):
    """Feature fusion module.

    Upsamples the lower-resolution features, passes them through a grouped
    3x3 conv block and a 1x1 conv + batch norm, passes the higher-resolution
    features through their own 1x1 conv + batch norm, then adds the two and
    applies ReLU.

    Args:
        highter_in_channels: Channels of the higher-resolution input
            (parameter name kept as-is for call-site compatibility).
        lower_in_channels: Channels of the lower-resolution input.
        out_channels: Channels of the fused output.
        scale_factor: Upsampling factor applied to the lower-resolution
            input before fusion (default 4).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, highter_in_channels, lower_in_channels, out_channels, scale_factor=4, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.scale_factor = scale_factor
        self.dwconv = _DWConv(lower_in_channels, out_channels, 1)
        self.conv_lower_res = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels)
        )
        self.conv_higher_res = nn.Sequential(
            nn.Conv2d(highter_in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels)
        )
        self.relu = nn.ReLU(True)

    def forward(self, higher_res_feature, lower_res_feature):
        """Fuse the two feature maps and return the activated sum.

        After upsampling by ``scale_factor`` the lower-resolution map must
        have the same spatial size as the higher-resolution map, otherwise
        the element-wise addition fails.
        """
        # Honour the configured factor; it used to be hard-coded to 4 here,
        # which silently ignored the constructor argument.
        lower_res_feature = F.interpolate(
            lower_res_feature, scale_factor=self.scale_factor, mode='bilinear', align_corners=True)
        lower_res_feature = self.dwconv(lower_res_feature)
        lower_res_feature = self.conv_lower_res(lower_res_feature)

        higher_res_feature = self.conv_higher_res(higher_res_feature)
        out = higher_res_feature + lower_res_feature
        return self.relu(out)
|
|
215 |
|
|
|
216 |
|
|
|
217 |
class Classifer(nn.Module):
    """Classifier head.

    Two depthwise separable conv blocks that keep ``dw_channels`` channels,
    followed by dropout and a 1x1 convolution to ``num_classes`` channels.

    Args:
        dw_channels: Channels of the input and of both separable blocks.
        num_classes: Channels produced by the final 1x1 convolution.
        stride: Stride passed to both separable blocks (default 1).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dw_channels, num_classes, stride=1, **kwargs):
        super(Classifer, self).__init__()
        self.dsconv1 = _DSConv(dw_channels, dw_channels, stride)
        self.dsconv2 = _DSConv(dw_channels, dw_channels, stride)
        head = [
            nn.Dropout(p=0.1),
            nn.Conv2d(dw_channels, num_classes, kernel_size=1),
        ]
        self.conv = nn.Sequential(*head)

    def forward(self, x):
        """Return the per-class score map computed from ``x``."""
        for stage in (self.dsconv1, self.dsconv2, self.conv):
            x = stage(x)
        return x
|
|
234 |
|
|
|
235 |
|
|
|
236 |
def get_fast_scnn(n_classes, **kwargs):
    """Build a ``FastSCNN`` model.

    Args:
        n_classes: Passed to ``FastSCNN`` as its ``num_classes`` argument.
        **kwargs: Forwarded unchanged to the ``FastSCNN`` constructor.

    Returns:
        The newly constructed ``FastSCNN`` instance.
    """
    return FastSCNN(n_classes, **kwargs)