
mmseg/models/backbones/stdc.py

# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential

from mmseg.ops import resize
from ..builder import BACKBONES, build_backbone
from .bisenetv1 import AttentionRefinementModule


class STDCModule(BaseModule):
    """STDCModule.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The stride of the first conv layer.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers. Default: 4.
        fusion_type (str): Type of fusion operation. Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
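
    Example:
        >>> # A minimal shape-check sketch; the sizes below are illustrative,
        >>> # not from the paper. With ``fusion_type='cat'`` and the default
        >>> # ``num_convs=4``, ``out_channels=64`` is split into
        >>> # 32 + 16 + 8 + 8 channels, and ``stride=2`` halves the spatial
        >>> # size.
        >>> import torch
        >>> self = STDCModule(32, 64, stride=2,
        ...                   norm_cfg=dict(type='BN'),
        ...                   fusion_type='cat').eval()
        >>> inputs = torch.rand(1, 32, 64, 64)
        >>> outputs = self.forward(inputs)
        >>> outputs.shape
        torch.Size([1, 64, 32, 32])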
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super(STDCModule, self).__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        self.with_downsample = self.stride == 2
        self.fusion_type = fusion_type

        self.layers = ModuleList()
        conv_0 = ConvModule(
            in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)

        if self.with_downsample:
            # Depthwise conv that halves the spatial size of the first
            # conv's output.
            self.downsample = ConvModule(
                out_channels // 2,
                out_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=out_channels // 2,
                norm_cfg=norm_cfg,
                act_cfg=None)

            if self.fusion_type == 'add':
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(conv_0)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layers.append(conv_0)

        # Each following conv halves the channel number, except that the
        # last one keeps the channel number of its predecessor, so the
        # concatenated outputs sum to `out_channels`.
        for i in range(1, num_convs):
            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // out_factor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        if self.fusion_type == 'add':
            out = self.forward_add(inputs)
        else:
            out = self.forward_cat(inputs)
        return out

    def forward_add(self, inputs):
        layer_outputs = []
        x = inputs.clone()
        for layer in self.layers:
            x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            inputs = self.skip(inputs)
        # Fuse by adding the (possibly downsampled) identity branch.
        return torch.cat(layer_outputs, dim=1) + inputs

    def forward_cat(self, inputs):
        x0 = self.layers[0](inputs)
        layer_outputs = [x0]
        for i, layer in enumerate(self.layers[1:]):
            if i == 0:
                if self.with_downsample:
                    x = layer(self.downsample(x0))
                else:
                    x = layer(x0)
            else:
                x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            # Match the spatial size of the first output to the others.
            layer_outputs[0] = self.skip(x0)
        return torch.cat(layer_outputs, dim=1)


class FeatureFusionModule(BaseModule):
    """Feature Fusion Module. This module differs from the
    FeatureFusionModule in BiSeNetV1: it uses two ConvModules in
    `self.attention`, whose intermediate channel number is computed from the
    given `scale_factor`, while the FeatureFusionModule in BiSeNetV1 uses
    only one ConvModule in `self.conv_atten`.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        scale_factor (int): The channel scale factor. Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
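
    Example:
        >>> # A minimal shape-check sketch; the sizes are illustrative. The
        >>> # two inputs are concatenated along the channel dim, so
        >>> # `in_channels` must equal the sum of their channel numbers
        >>> # (here 256 + 128 = 384).
        >>> import torch
        >>> self = FeatureFusionModule(384, 256).eval()
        >>> spatial = torch.rand(1, 256, 64, 128)
        >>> context = torch.rand(1, 128, 64, 128)
        >>> self.forward(spatial, context).shape
        torch.Size([1, 256, 64, 128])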
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(FeatureFusionModule, self).__init__(init_cfg=init_cfg)
        channels = out_channels // scale_factor
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                out_channels,
                channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=act_cfg),
            ConvModule(
                channels,
                out_channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=None), nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
        inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
        x = self.conv0(inputs)
        attn = self.attention(x)
        # Channel-wise attention with an identity shortcut.
        x_attn = x * attn
        return x_attn + x


@BACKBONES.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For
    Real-time Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (str): The type of backbone structure. `STDCNet1` and
            `STDCNet2` denote the two main backbones in the paper, whose
            FLOPs are 813M and 1446M, respectively.
        in_channels (int): The number of input channels.
        channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module; the value must be
            'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers in each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to add a conv layer at the module
            output. Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels,
        ...                channels, bottleneck_type,
        ...                norm_cfg=dict(type='BN'),
        ...                act_cfg=dict(type='ReLU')).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """

    # For each stage, the tuple lists the stride of every STDCModule in it.
    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super(STDCNet, self).__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'], \
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'
        assert len(channels) == 5, \
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, which are noted as `Stage1` and `Stage2` in the original
        # paper. Neither is used by the following modules such as the
        # Attention Refinement Module and the Feature Fusion Module, so they
        # are cut from `outs`. Please refer to Figure 4 of the original
        # paper for more details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList containing the
        # shallow modules followed by the STDC stages, i.e.
        # len(self.stages) ==
        #     self.num_shallow_features + len(self.stage_strides).
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        # Drop the shallow `Stage1`/`Stage2` outputs (see comment above).
        outs = outs[self.num_shallow_features:]
        return tuple(outs)


@BACKBONES.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path. The `outs` below is a list of three feature
    maps from deep to shallow, whose heights and widths go from small to big,
    respectively. The biggest feature map of `outs` is fed to `STDCHead`,
    where the Detail Loss is calculated against the detail ground truth. The
    other two feature maps each pass through an Attention Refinement Module.
    Besides, the biggest feature map of `outs` and the last output of the
    Attention Refinement Modules are fused by the Feature Fusion Module, and
    the fused feature map `feat_fuse` is fed to the `decode_head`. For more
    details, please refer to Figure 4 of the original paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple(int)): The number of channels of the last
            two feature maps from the stdc backbone. Default: (1024, 512).
        out_channels (int): The channels of output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool, optional): align_corners argument of
            F.interpolate. It must be `None` if upsample_mode is
            ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        outputs (tuple): The tuple of lists of output feature maps for
            auxiliary heads and the decode head.
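
    Example:
        >>> # A minimal shape-check sketch; the config values below are
        >>> # illustrative. Note that `ffm_cfg['in_channels']` must equal
        >>> # the channel sum of `outs[0]` and the last ARM output
        >>> # (256 + 128 = 384 here).
        >>> import torch
        >>> backbone_cfg = dict(
        ...     type='STDCNet',
        ...     stdc_type='STDCNet1',
        ...     in_channels=3,
        ...     channels=(32, 64, 256, 512, 1024),
        ...     bottleneck_type='cat',
        ...     norm_cfg=dict(type='BN'),
        ...     act_cfg=dict(type='ReLU'))
        >>> self = STDCContextPathNet(
        ...     backbone_cfg,
        ...     ffm_cfg=dict(in_channels=384, out_channels=256,
        ...                  scale_factor=4)).eval()
        >>> inputs = torch.rand(1, 3, 512, 1024)
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 64, 128])
        outputs[1].shape = torch.Size([1, 128, 32, 64])
        outputs[2].shape = torch.Size([1, 128, 64, 128])
        outputs[3].shape = torch.Size([1, 256, 64, 128])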
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone_cfg)
        self.arms = ModuleList()
        self.convs = ModuleList()
        for channels in last_in_channels:
            self.arms.append(AttentionRefinementModule(channels, out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        outs = list(self.backbone(x))
        # Global context: global average pooling of the deepest feature map.
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)

        feature_up = resize(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)
        arms_out = []
        for i in range(len(self.arms)):
            # Refine features from deep to shallow, upsampling at each step.
            x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
            feature_up = resize(
                x_arm,
                size=outs[len(outs) - 1 - i - 1].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = self.convs[i](feature_up)
            arms_out.append(feature_up)

        feat_fuse = self.ffm(outs[0], arms_out[1])

        # `outputs` contains four feature maps:
        # `outs[0]` is fed to the `STDCHead` auxiliary head,
        # the two feature maps of `arms_out` are fed to auxiliary heads,
        # and `feat_fuse` is fed to the decode head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]
        return tuple(outputs)