Diff of /pathflowai/fast_scnn.py [000000] .. [e9500f]

###########################################################################
# Created by: Tramac
# Date: 2019-03-25
# Copyright (c) 2017
###########################################################################

"""Fast Segmentation Convolutional Neural Network"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['FastSCNN', 'get_fast_scnn']


class FastSCNN(nn.Module):
    def __init__(self, num_classes, aux=False, **kwargs):
        super(FastSCNN, self).__init__()
        self.aux = aux
        self.learning_to_downsample = LearningToDownsample(32, 48, 64)
        self.global_feature_extractor = GlobalFeatureExtractor(64, [64, 96, 128], 128, 6, [3, 3, 3])
        self.feature_fusion = FeatureFusionModule(64, 128, 128)
        self.classifier = Classifer(128, num_classes)
        if self.aux:
            self.auxlayer = nn.Sequential(
                nn.Conv2d(64, 32, 3, padding=1, bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU(True),
                nn.Dropout(0.1),
                nn.Conv2d(32, num_classes, 1)
            )

    def forward(self, x):
        size = x.size()[2:]
        higher_res_features = self.learning_to_downsample(x)
        x = self.global_feature_extractor(higher_res_features)
        x = self.feature_fusion(higher_res_features, x)
        x = self.classifier(x)
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        if self.aux:
            # Auxiliary segmentation head on the downsampled features,
            # upsampled to the input resolution.
            auxout = self.auxlayer(higher_res_features)
            auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
            return x, auxout
        return x


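# Illustrative shape walk-through (added note, not part of the original file),
# assuming a 1 x 3 x 256 x 256 input:
#   learning_to_downsample   -> 1 x 64  x 32 x 32   (high-resolution branch, ~1/8)
#   global_feature_extractor -> 1 x 128 x 8  x 8    (low-resolution branch, ~1/32)
#   feature_fusion           -> 1 x 128 x 32 x 32
#   classifier + interpolate -> 1 x num_classes x 256 x 256

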
class _ConvBNReLU(nn.Module):
    """Conv-BN-ReLU"""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, **kwargs):
        super(_ConvBNReLU, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.conv(x)


class _DSConv(nn.Module):
    """Depthwise Separable Convolutions"""

    def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
        super(_DSConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(dw_channels, dw_channels, 3, stride, 1, groups=dw_channels, bias=False),
            nn.BatchNorm2d(dw_channels),
            nn.ReLU(True),
            nn.Conv2d(dw_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.conv(x)


class _DWConv(nn.Module):
    """Depthwise (grouped 3x3) convolution followed by BN and ReLU"""

    def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
        super(_DWConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(dw_channels, out_channels, 3, stride, 1, groups=dw_channels, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.conv(x)


class LinearBottleneck(nn.Module):
    """LinearBottleneck used in MobileNetV2"""

    def __init__(self, in_channels, out_channels, t=6, stride=2, **kwargs):
        super(LinearBottleneck, self).__init__()
        self.use_shortcut = stride == 1 and in_channels == out_channels
        self.block = nn.Sequential(
            # pw
            _ConvBNReLU(in_channels, in_channels * t, 1),
            # dw
            _DWConv(in_channels * t, in_channels * t, stride),
            # pw-linear
            nn.Conv2d(in_channels * t, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        out = self.block(x)
        if self.use_shortcut:
            out = x + out
        return out


class PyramidPooling(nn.Module):
    """Pyramid pooling module"""

    def __init__(self, in_channels, out_channels, **kwargs):
        super(PyramidPooling, self).__init__()
        inter_channels = int(in_channels / 4)
        self.conv1 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv2 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv3 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv4 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.out = _ConvBNReLU(in_channels * 2, out_channels, 1)

    def pool(self, x, size):
        avgpool = nn.AdaptiveAvgPool2d(size)
        return avgpool(x)

    def upsample(self, x, size):
        return F.interpolate(x, size, mode='bilinear', align_corners=True)

    def forward(self, x):
        size = x.size()[2:]
        feat1 = self.upsample(self.conv1(self.pool(x, 1)), size)
        feat2 = self.upsample(self.conv2(self.pool(x, 2)), size)
        feat3 = self.upsample(self.conv3(self.pool(x, 3)), size)
        feat4 = self.upsample(self.conv4(self.pool(x, 6)), size)
        x = torch.cat([x, feat1, feat2, feat3, feat4], dim=1)
        x = self.out(x)
        return x


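# Note on PyramidPooling channel bookkeeping (added explanatory sketch, not
# part of the original file): each of the four branches pools the input to a
# fixed grid (1x1, 2x2, 3x3, 6x6), projects it to in_channels // 4 channels,
# and upsamples it back to the input size. Concatenating the input with the
# four branches gives in_channels + 4 * (in_channels // 4) = 2 * in_channels
# channels, which is why the fusion conv above is declared with
# in_channels * 2 input channels.

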
class LearningToDownsample(nn.Module):
    """Learning to downsample module"""

    def __init__(self, dw_channels1=32, dw_channels2=48, out_channels=64, **kwargs):
        super(LearningToDownsample, self).__init__()
        self.conv = _ConvBNReLU(3, dw_channels1, 3, 2)
        self.dsconv1 = _DSConv(dw_channels1, dw_channels2, 2)
        self.dsconv2 = _DSConv(dw_channels2, out_channels, 2)

    def forward(self, x):
        x = self.conv(x)
        x = self.dsconv1(x)
        x = self.dsconv2(x)
        return x


class GlobalFeatureExtractor(nn.Module):
    """Global feature extractor module"""

    def __init__(self, in_channels=64, block_channels=(64, 96, 128),
                 out_channels=128, t=6, num_blocks=(3, 3, 3), **kwargs):
        super(GlobalFeatureExtractor, self).__init__()
        self.bottleneck1 = self._make_layer(LinearBottleneck, in_channels, block_channels[0], num_blocks[0], t, 2)
        self.bottleneck2 = self._make_layer(LinearBottleneck, block_channels[0], block_channels[1], num_blocks[1], t, 2)
        self.bottleneck3 = self._make_layer(LinearBottleneck, block_channels[1], block_channels[2], num_blocks[2], t, 1)
        self.ppm = PyramidPooling(block_channels[2], out_channels)

    def _make_layer(self, block, inplanes, planes, blocks, t=6, stride=1):
        # The first block changes the channel count and applies the stride;
        # the remaining blocks are stride-1 bottlenecks with a shortcut.
        layers = [block(inplanes, planes, t, stride)]
        for _ in range(1, blocks):
            layers.append(block(planes, planes, t, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        x = self.ppm(x)
        return x


class FeatureFusionModule(nn.Module):
    """Feature fusion module"""

    def __init__(self, higher_in_channels, lower_in_channels, out_channels, scale_factor=4, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.scale_factor = scale_factor
        self.dwconv = _DWConv(lower_in_channels, out_channels, 1)
        self.conv_lower_res = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels)
        )
        self.conv_higher_res = nn.Sequential(
            nn.Conv2d(higher_in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels)
        )
        self.relu = nn.ReLU(True)

    def forward(self, higher_res_feature, lower_res_feature):
        lower_res_feature = F.interpolate(lower_res_feature, scale_factor=self.scale_factor,
                                          mode='bilinear', align_corners=True)
        lower_res_feature = self.dwconv(lower_res_feature)
        lower_res_feature = self.conv_lower_res(lower_res_feature)

        higher_res_feature = self.conv_higher_res(higher_res_feature)
        out = higher_res_feature + lower_res_feature
        return self.relu(out)


class Classifer(nn.Module):
    """Classifier head"""

    def __init__(self, dw_channels, num_classes, stride=1, **kwargs):
        super(Classifer, self).__init__()
        self.dsconv1 = _DSConv(dw_channels, dw_channels, stride)
        self.dsconv2 = _DSConv(dw_channels, dw_channels, stride)
        self.conv = nn.Sequential(
            nn.Dropout(0.1),
            nn.Conv2d(dw_channels, num_classes, 1)
        )

    def forward(self, x):
        x = self.dsconv1(x)
        x = self.dsconv2(x)
        x = self.conv(x)
        return x


def get_fast_scnn(n_classes, **kwargs):
    """Build a FastSCNN model with `n_classes` output channels."""
    model = FastSCNN(n_classes, **kwargs)
    return model
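

# Minimal usage sketch (an added example, not part of the original file).
# The spatial size below is an assumption; sizes that are multiples of 32
# keep the high- and low-resolution branches aligned for feature fusion.
if __name__ == '__main__':
    model = get_fast_scnn(n_classes=4, aux=True)
    model.eval()
    dummy = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        out, auxout = model(dummy)
    # Both heads are upsampled back to the input resolution.
    print(out.shape)     # torch.Size([1, 4, 256, 256])
    print(auxout.shape)  # torch.Size([1, 4, 256, 256])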