--- a
+++ b/mmseg/models/backbones/stdc.py
@@ -0,0 +1,422 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
+
+from mmseg.ops import resize
+from ..builder import BACKBONES, build_backbone
+from .bisenetv1 import AttentionRefinementModule
+
+
+class STDCModule(BaseModule):
+    """STDCModule.
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels before scaling.
+        stride (int): The number of stride for the first conv layer.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): The activation config for conv layers.
+        num_convs (int): Numbers of conv layers.
+        fusion_type (str): Type of fusion operation. Default: 'add'.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 norm_cfg=None,
+                 act_cfg=None,
+                 num_convs=4,
+                 fusion_type='add',
+                 init_cfg=None):
+        super(STDCModule, self).__init__(init_cfg=init_cfg)
+        assert num_convs > 1
+        assert fusion_type in ['add', 'cat']
+        self.stride = stride
+        self.with_downsample = True if self.stride == 2 else False
+        self.fusion_type = fusion_type
+
+        self.layers = ModuleList()
+        conv_0 = ConvModule(
+            in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)
+
+        if self.with_downsample:
+            self.downsample = ConvModule(
+                out_channels // 2,
+                out_channels // 2,
+                kernel_size=3,
+                stride=2,
+                padding=1,
+                groups=out_channels // 2,
+                norm_cfg=norm_cfg,
+                act_cfg=None)
+
+            if self.fusion_type == 'add':
+                self.layers.append(nn.Sequential(conv_0, self.downsample))
+                self.skip = Sequential(
+                    ConvModule(
+                        in_channels,
+                        in_channels,
+                        kernel_size=3,
+                        stride=2,
+                        padding=1,
+                        groups=in_channels,
+                        norm_cfg=norm_cfg,
+                        act_cfg=None),
+                    ConvModule(
+                        in_channels,
+                        out_channels,
+                        1,
+                        norm_cfg=norm_cfg,
+                        act_cfg=None))
+            else:
+                self.layers.append(conv_0)
+                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+        else:
+            self.layers.append(conv_0)
+
+        for i in range(1, num_convs):
+            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
+            self.layers.append(
+                ConvModule(
+                    out_channels // 2**i,
+                    out_channels // out_factor,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+    def forward(self, inputs):
+        if self.fusion_type == 'add':
+            out = self.forward_add(inputs)
+        else:
+            out = self.forward_cat(inputs)
+        return out
+
+    def forward_add(self, inputs):
+        layer_outputs = []
+        x = inputs.clone()
+        for layer in self.layers:
+            x = layer(x)
+            layer_outputs.append(x)
+        if self.with_downsample:
+            inputs = self.skip(inputs)
+
+        return torch.cat(layer_outputs, dim=1) + inputs
+
+    def forward_cat(self, inputs):
+        x0 = self.layers[0](inputs)
+        layer_outputs = [x0]
+        for i, layer in enumerate(self.layers[1:]):
+            if i == 0:
+                if self.with_downsample:
+                    x = layer(self.downsample(x0))
+                else:
+                    x = layer(x0)
+            else:
+                x = layer(x)
+            layer_outputs.append(x)
+        if self.with_downsample:
+            layer_outputs[0] = self.skip(x0)
+        return torch.cat(layer_outputs, dim=1)
+
+
+class FeatureFusionModule(BaseModule):
+    """Feature Fusion Module. This module is different from FeatureFusionModule
+    in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter
+    channel number is calculated by given `scale_factor`, while
+    FeatureFusionModule in BiSeNetV1 only uses one ConvModule in
+    `self.conv_atten`.
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        scale_factor (int): The number of channel scale factor.
+            Default: 4.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): The activation config for conv layers.
+            Default: dict(type='ReLU').
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 scale_factor=4,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 init_cfg=None):
+        super(FeatureFusionModule, self).__init__(init_cfg=init_cfg)
+        channels = out_channels // scale_factor
+        self.conv0 = ConvModule(
+            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
+        self.attention = nn.Sequential(
+            nn.AdaptiveAvgPool2d((1, 1)),
+            ConvModule(
+                out_channels,
+                channels,
+                1,
+                norm_cfg=None,
+                bias=False,
+                act_cfg=act_cfg),
+            ConvModule(
+                channels,
+                out_channels,
+                1,
+                norm_cfg=None,
+                bias=False,
+                act_cfg=None), nn.Sigmoid())
+
+    def forward(self, spatial_inputs, context_inputs):
+        inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
+        x = self.conv0(inputs)
+        attn = self.attention(x)
+        x_attn = x * attn
+        return x_attn + x
+
+
+@BACKBONES.register_module()
+class STDCNet(BaseModule):
+    """This backbone is the implementation of `Rethinking BiSeNet For Real-time
+    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
+
+    Args:
+        stdc_type (int): The type of backbone structure,
+            `STDCNet1` and`STDCNet2` denotes two main backbones in paper,
+            whose FLOPs is 813M and 1446M, respectively.
+        in_channels (int): The num of input_channels.
+        channels (tuple[int]): The output channels for each stage.
+        bottleneck_type (str): The type of STDC Module type, the value must
+            be 'add' or 'cat'.
+        norm_cfg (dict): Config dict for normalization layer.
+        act_cfg (dict): The activation config for conv layers.
+        num_convs (int): Numbers of conv layer at each STDC Module.
+            Default: 4.
+        with_final_conv (bool): Whether add a conv layer at the Module output.
+            Default: True.
+        pretrained (str, optional): Model pretrained path. Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+
+    Example:
+        >>> import torch
+        >>> stdc_type = 'STDCNet1'
+        >>> in_channels = 3
+        >>> channels = (32, 64, 256, 512, 1024)
+        >>> bottleneck_type = 'cat'
+        >>> inputs = torch.rand(1, 3, 1024, 2048)
+        >>> self = STDCNet(stdc_type, in_channels,
+        ...                 channels, bottleneck_type).eval()
+        >>> outputs = self.forward(inputs)
+        >>> for i in range(len(outputs)):
+        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
+        outputs[0].shape = torch.Size([1, 256, 128, 256])
+        outputs[1].shape = torch.Size([1, 512, 64, 128])
+        outputs[2].shape = torch.Size([1, 1024, 32, 64])
+    """
+
+    arch_settings = {
+        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
+        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
+    }
+
+    def __init__(self,
+                 stdc_type,
+                 in_channels,
+                 channels,
+                 bottleneck_type,
+                 norm_cfg,
+                 act_cfg,
+                 num_convs=4,
+                 with_final_conv=False,
+                 pretrained=None,
+                 init_cfg=None):
+        super(STDCNet, self).__init__(init_cfg=init_cfg)
+        assert stdc_type in self.arch_settings, \
+            f'invalid structure {stdc_type} for STDCNet.'
+        assert bottleneck_type in ['add', 'cat'],\
+            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'
+
+        assert len(channels) == 5,\
+            f'invalid channels length {len(channels)} for STDCNet.'
+
+        self.in_channels = in_channels
+        self.channels = channels
+        self.stage_strides = self.arch_settings[stdc_type]
+        self.prtrained = pretrained
+        self.num_convs = num_convs
+        self.with_final_conv = with_final_conv
+
+        self.stages = ModuleList([
+            ConvModule(
+                self.in_channels,
+                self.channels[0],
+                kernel_size=3,
+                stride=2,
+                padding=1,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg),
+            ConvModule(
+                self.channels[0],
+                self.channels[1],
+                kernel_size=3,
+                stride=2,
+                padding=1,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        ])
+        # `self.num_shallow_features` is the number of shallow modules in
+        # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper.
+        # They are both not used for following modules like Attention
+        # Refinement Module and Feature Fusion Module.
+        # Thus they would be cut from `outs`. Please refer to Figure 4
+        # of original paper for more details.
+        self.num_shallow_features = len(self.stages)
+
+        for strides in self.stage_strides:
+            idx = len(self.stages) - 1
+            self.stages.append(
+                self._make_stage(self.channels[idx], self.channels[idx + 1],
+                                 strides, norm_cfg, act_cfg, bottleneck_type))
+        # After appending, `self.stages` is a ModuleList including several
+        # shallow modules and STDCModules.
+        # (len(self.stages) ==
+        # self.num_shallow_features + len(self.stage_strides))
+        if self.with_final_conv:
+            self.final_conv = ConvModule(
+                self.channels[-1],
+                max(1024, self.channels[-1]),
+                1,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+
+    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
+                    act_cfg, bottleneck_type):
+        layers = []
+        for i, stride in enumerate(strides):
+            layers.append(
+                STDCModule(
+                    in_channels if i == 0 else out_channels,
+                    out_channels,
+                    stride,
+                    norm_cfg,
+                    act_cfg,
+                    num_convs=self.num_convs,
+                    fusion_type=bottleneck_type))
+        return Sequential(*layers)
+
+    def forward(self, x):
+        outs = []
+        for stage in self.stages:
+            x = stage(x)
+            outs.append(x)
+        if self.with_final_conv:
+            outs[-1] = self.final_conv(outs[-1])
+        outs = outs[self.num_shallow_features:]
+        return tuple(outs)
+
+
+@BACKBONES.register_module()
+class STDCContextPathNet(BaseModule):
+    """STDCNet with Context Path. The `outs` below is a list of three feature
+    maps from deep to shallow, whose height and width is from small to big,
+    respectively. The biggest feature map of `outs` is outputted for
+    `STDCHead`, where Detail Loss would be calculated by Detail Ground-truth.
+    The other two feature maps are used for Attention Refinement Module,
+    respectively. Besides, the biggest feature map of `outs` and the last
+    output of Attention Refinement Module are concatenated for Feature Fusion
+    Module. Then, this fusion feature map `feat_fuse` would be outputted for
+    `decode_head`. More details please refer to Figure 4 of original paper.
+
+    Args:
+        backbone_cfg (dict): Config dict for stdc backbone.
+        last_in_channels (tuple(int)), The number of channels of last
+            two feature maps from stdc backbone. Default: (1024, 512).
+        out_channels (int): The channels of output feature maps.
+            Default: 128.
+        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
+            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
+        upsample_mode (str): Algorithm used for upsampling:
+                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+                ``'trilinear'``. Default: ``'nearest'``.
+        align_corners (str): align_corners argument of F.interpolate. It
+            must be `None` if upsample_mode is ``'nearest'``. Default: None.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+
+    Return:
+        outputs (tuple): The tuple of list of output feature map for
+            auxiliary heads and decoder head.
+    """
+
+    def __init__(self,
+                 backbone_cfg,
+                 last_in_channels=(1024, 512),
+                 out_channels=128,
+                 ffm_cfg=dict(
+                     in_channels=512, out_channels=256, scale_factor=4),
+                 upsample_mode='nearest',
+                 align_corners=None,
+                 norm_cfg=dict(type='BN'),
+                 init_cfg=None):
+        super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
+        self.backbone = build_backbone(backbone_cfg)
+        self.arms = ModuleList()
+        self.convs = ModuleList()
+        for channels in last_in_channels:
+            self.arms.append(AttentionRefinementModule(channels, out_channels))
+            self.convs.append(
+                ConvModule(
+                    out_channels,
+                    out_channels,
+                    3,
+                    padding=1,
+                    norm_cfg=norm_cfg))
+        self.conv_avg = ConvModule(
+            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)
+
+        self.ffm = FeatureFusionModule(**ffm_cfg)
+
+        self.upsample_mode = upsample_mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        outs = list(self.backbone(x))
+        avg = F.adaptive_avg_pool2d(outs[-1], 1)
+        avg_feat = self.conv_avg(avg)
+
+        feature_up = resize(
+            avg_feat,
+            size=outs[-1].shape[2:],
+            mode=self.upsample_mode,
+            align_corners=self.align_corners)
+        arms_out = []
+        for i in range(len(self.arms)):
+            x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
+            feature_up = resize(
+                x_arm,
+                size=outs[len(outs) - 1 - i - 1].shape[2:],
+                mode=self.upsample_mode,
+                align_corners=self.align_corners)
+            feature_up = self.convs[i](feature_up)
+            arms_out.append(feature_up)
+
+        feat_fuse = self.ffm(outs[0], arms_out[1])
+
+        # The `outputs` has four feature maps.
+        # `outs[0]` is outputted for `STDCHead` auxiliary head.
+        # Two feature maps of `arms_out` are outputted for auxiliary head.
+        # `feat_fuse` is outputted for decoder head.
+        outputs = [outs[0]] + list(arms_out) + [feat_fuse]
+        return tuple(outputs)