--- /dev/null
+++ b/tests/test_models/test_backbones.py
@@ -0,0 +1,930 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import pytest
+import torch
+import torch.nn as nn
+from mmcv.utils import _BatchNorm
+
+from mmaction.models import (C3D, STGCN, X3D, MobileNetV2TSM, ResNet2Plus1d,
+                             ResNet3dCSN, ResNet3dSlowFast, ResNet3dSlowOnly,
+                             ResNetAudio, ResNetTIN, ResNetTSM, TANet,
+                             TimeSformer)
+from mmaction.models.backbones.resnet_tsm import NL3DWrapper
+from .base import check_norm_state, generate_backbone_demo_inputs
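+
+# For reference, the helpers imported from `.base` are assumed to behave
+# roughly like the sketch below (a simplified, hypothetical version that
+# additionally assumes `import numpy as np`, not the canonical code):
+#
+#   def generate_backbone_demo_inputs(input_shape=(1, 3, 64, 64)):
+#       """Return a random float tensor of the given shape."""
+#       return torch.FloatTensor(np.random.random(input_shape))
+#
+#   def check_norm_state(modules, train_state):
+#       """Return True if every norm layer matches ``train_state``."""
+#       return all(m.training == train_state
+#                  for m in modules if isinstance(m, _BatchNorm))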
+
+
+def test_x3d_backbone():
+    """Test x3d backbone."""
+    with pytest.raises(AssertionError):
+        # In X3D: 1 <= num_stages <= 4
+        X3D(gamma_w=1.0, gamma_b=2.25, gamma_d=2.2, num_stages=0)
+
+    with pytest.raises(AssertionError):
+        # In X3D: 1 <= num_stages <= 4
+        X3D(gamma_w=1.0, gamma_b=2.25, gamma_d=2.2, num_stages=5)
+
+    with pytest.raises(AssertionError):
+        # len(spatial_strides) == num_stages
+        X3D(gamma_w=1.0,
+            gamma_b=2.25,
+            gamma_d=2.2,
+            spatial_strides=(1, 2),
+            num_stages=4)
+
+    with pytest.raises(AssertionError):
+        # se_style in ['half', 'all']
+        X3D(gamma_w=1.0, gamma_b=2.25, gamma_d=2.2, se_style=None)
+
+    with pytest.raises(AssertionError):
+        # se_ratio should be None or > 0
+        X3D(gamma_w=1.0,
+            gamma_b=2.25,
+            gamma_d=2.2,
+            se_style='half',
+            se_ratio=0)
+
+    # x3d_s, no pretrained, norm_eval True
+    x3d_s = X3D(gamma_w=1.0, gamma_b=2.25, gamma_d=2.2, norm_eval=True)
+    x3d_s.init_weights()
+    x3d_s.train()
+    assert check_norm_state(x3d_s.modules(), False)
+
+    # x3d_l, no pretrained, norm_eval True
+    x3d_l = X3D(gamma_w=1.0, gamma_b=2.25, gamma_d=5.0, norm_eval=True)
+    x3d_l.init_weights()
+    x3d_l.train()
+    assert check_norm_state(x3d_l.modules(), False)
+
+    # x3d_s, no pretrained, norm_eval False
+    x3d_s = X3D(gamma_w=1.0, gamma_b=2.25, gamma_d=2.2, norm_eval=False)
+    x3d_s.init_weights()
+    x3d_s.train()
+    assert check_norm_state(x3d_s.modules(), True)
+
+    # x3d_l, no pretrained, norm_eval False
+    x3d_l = X3D(gamma_w=1.0, gamma_b=2.25, gamma_d=5.0, norm_eval=False)
+    x3d_l.init_weights()
+    x3d_l.train()
+    assert check_norm_state(x3d_l.modules(), True)
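+    # Note: norm_eval=True keeps all BN layers in eval mode even after
+    # .train() is called, hence check_norm_state(..., False) above, while
+    # norm_eval=False lets them follow the module's training flag.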
+
+    # x3d_s, no pretrained, frozen_stages, norm_eval False
+    frozen_stages = 1
+    x3d_s_frozen = X3D(
+        gamma_w=1.0,
+        gamma_b=2.25,
+        gamma_d=2.2,
+        norm_eval=False,
+        frozen_stages=frozen_stages)
+
+    x3d_s_frozen.init_weights()
+    x3d_s_frozen.train()
+    assert x3d_s_frozen.conv1_t.bn.training is False
+    for param in x3d_s_frozen.conv1_s.parameters():
+        assert param.requires_grad is False
+    for param in x3d_s_frozen.conv1_t.parameters():
+        assert param.requires_grad is False
+
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(x3d_s_frozen, f'layer{i}')
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+    # test zero_init_residual, which is True by default
+    for m in x3d_s_frozen.modules():
+        if hasattr(m, 'conv3'):
+            assert torch.equal(m.conv3.bn.weight,
+                               torch.zeros_like(m.conv3.bn.weight))
+            assert torch.equal(m.conv3.bn.bias,
+                               torch.zeros_like(m.conv3.bn.bias))
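+    # With zero_init_residual, the last BN of each residual branch
+    # (m.conv3.bn here) is initialized to zero so every block starts out as
+    # an identity mapping, which is what the all-zero weight/bias asserts
+    # above verify.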
+
+    # x3d_s inference
+    input_shape = (1, 3, 13, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+    # parrots 3dconv is only implemented on gpu
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            x3d_s_frozen = x3d_s_frozen.cuda()
+            imgs_gpu = imgs.cuda()
+            feat = x3d_s_frozen(imgs_gpu)
+            assert feat.shape == torch.Size([1, 432, 13, 2, 2])
+    else:
+        feat = x3d_s_frozen(imgs)
+        assert feat.shape == torch.Size([1, 432, 13, 2, 2])
+
+    # x3d_s inference with the x3d_m input size
+    input_shape = (1, 3, 16, 96, 96)
+    imgs = generate_backbone_demo_inputs(input_shape)
+    # parrots 3dconv is only implemented on gpu
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            x3d_s_frozen = x3d_s_frozen.cuda()
+            imgs_gpu = imgs.cuda()
+            feat = x3d_s_frozen(imgs_gpu)
+            assert feat.shape == torch.Size([1, 432, 16, 3, 3])
+    else:
+        feat = x3d_s_frozen(imgs)
+        assert feat.shape == torch.Size([1, 432, 16, 3, 3])
+
+
+def test_resnet2plus1d_backbone():
+    """Test r2+1d backbone."""
+    with pytest.raises(AssertionError):
+        # r2+1d does not support inflation
+        ResNet2Plus1d(50, None, pretrained2d=True)
+
+    with pytest.raises(AssertionError):
+        # r2+1d requires conv(2+1)d module
+        ResNet2Plus1d(
+            50, None, pretrained2d=False, conv_cfg=dict(type='Conv3d'))
+
+    frozen_stages = 1
+    r2plus1d_34_frozen = ResNet2Plus1d(
+        34,
+        None,
+        conv_cfg=dict(type='Conv2plus1d'),
+        pretrained2d=False,
+        frozen_stages=frozen_stages,
+        conv1_kernel=(3, 7, 7),
+        conv1_stride_t=1,
+        pool1_stride_t=1,
+        inflate=(1, 1, 1, 1),
+        spatial_strides=(1, 2, 2, 2),
+        temporal_strides=(1, 2, 2, 2))
+    r2plus1d_34_frozen.init_weights()
+    r2plus1d_34_frozen.train()
+    assert r2plus1d_34_frozen.conv1.conv.bn_s.training is False
+    assert r2plus1d_34_frozen.conv1.bn.training is False
+    for param in r2plus1d_34_frozen.conv1.parameters():
+        assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(r2plus1d_34_frozen, f'layer{i}')
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+    input_shape = (1, 3, 8, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+    # parrots 3dconv is only implemented on gpu
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            r2plus1d_34_frozen = r2plus1d_34_frozen.cuda()
+            imgs_gpu = imgs.cuda()
+            feat = r2plus1d_34_frozen(imgs_gpu)
+            assert feat.shape == torch.Size([1, 512, 1, 2, 2])
+    else:
+        feat = r2plus1d_34_frozen(imgs)
+        assert feat.shape == torch.Size([1, 512, 1, 2, 2])
+
+    r2plus1d_50_frozen = ResNet2Plus1d(
+        50,
+        None,
+        conv_cfg=dict(type='Conv2plus1d'),
+        pretrained2d=False,
+        conv1_kernel=(3, 7, 7),
+        conv1_stride_t=1,
+        pool1_stride_t=1,
+        inflate=(1, 1, 1, 1),
+        spatial_strides=(1, 2, 2, 2),
+        temporal_strides=(1, 2, 2, 2),
+        frozen_stages=frozen_stages)
+    r2plus1d_50_frozen.init_weights()
+
+    r2plus1d_50_frozen.train()
+    assert r2plus1d_50_frozen.conv1.conv.bn_s.training is False
+    assert r2plus1d_50_frozen.conv1.bn.training is False
+    for param in r2plus1d_50_frozen.conv1.parameters():
+        assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(r2plus1d_50_frozen, f'layer{i}')
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+    input_shape = (1, 3, 8, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+
+    # parrots 3dconv is only implemented on gpu
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            r2plus1d_50_frozen = r2plus1d_50_frozen.cuda()
+            imgs_gpu = imgs.cuda()
+            feat = r2plus1d_50_frozen(imgs_gpu)
+            assert feat.shape == torch.Size([1, 2048, 1, 2, 2])
+    else:
+        feat = r2plus1d_50_frozen(imgs)
+        assert feat.shape == torch.Size([1, 2048, 1, 2, 2])
+
+
+def test_resnet_tsm_backbone():
+    """Test resnet_tsm backbone."""
+    with pytest.raises(NotImplementedError):
+        # shift_place must be block or blockres
+        resnet_tsm_50_block = ResNetTSM(50, shift_place='Block')
+        resnet_tsm_50_block.init_weights()
+
+    from mmaction.models.backbones.resnet import Bottleneck
+    from mmaction.models.backbones.resnet_tsm import TemporalShift
+
+    input_shape = (8, 3, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+
+    # resnet_tsm with depth 50
+    resnet_tsm_50 = ResNetTSM(50)
+    resnet_tsm_50.init_weights()
+    for layer_name in resnet_tsm_50.res_layers:
+        layer = getattr(resnet_tsm_50, layer_name)
+        blocks = list(layer.children())
+        for block in blocks:
+            assert isinstance(block.conv1.conv, TemporalShift)
+            assert block.conv1.conv.num_segments == resnet_tsm_50.num_segments
+            assert block.conv1.conv.shift_div == resnet_tsm_50.shift_div
+            assert isinstance(block.conv1.conv.net, nn.Conv2d)
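+    # For reference, TemporalShift wraps the conv and shifts 1 / shift_div
+    # of the channels across adjacent segments before the conv runs;
+    # roughly (a sketch, not the exact implementation):
+    #
+    #   n, c, h, w = x.size()
+    #   x = x.view(n // num_segments, num_segments, c, h, w)
+    #   fold = c // shift_div
+    #   out = torch.zeros_like(x)
+    #   out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
+    #   out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # shift right
+    #   out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # no shift
+    #   x = out.view(n, c, h, w)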
+
+    # resnet_tsm with depth 50, no pretrained, shift_place is block
+    resnet_tsm_50_block = ResNetTSM(50, shift_place='block')
+    resnet_tsm_50_block.init_weights()
+    for layer_name in resnet_tsm_50_block.res_layers:
+        layer = getattr(resnet_tsm_50_block, layer_name)
+        blocks = list(layer.children())
+        for block in blocks:
+            assert isinstance(block, TemporalShift)
+            assert block.num_segments == resnet_tsm_50_block.num_segments
+            assert block.shift_div == resnet_tsm_50_block.shift_div
+            assert isinstance(block.net, Bottleneck)
+
+    # resnet_tsm with depth 50, no pretrained, use temporal_pool
+    resnet_tsm_50_temporal_pool = ResNetTSM(50, temporal_pool=True)
+    resnet_tsm_50_temporal_pool.init_weights()
+    for layer_name in resnet_tsm_50_temporal_pool.res_layers:
+        layer = getattr(resnet_tsm_50_temporal_pool, layer_name)
+        blocks = list(layer.children())
+
+        if layer_name == 'layer2':
+            assert len(blocks) == 2
+            assert isinstance(blocks[1], nn.MaxPool3d)
+            blocks = copy.deepcopy(blocks[0])
+
+        for block in blocks:
+            assert isinstance(block.conv1.conv, TemporalShift)
+            if layer_name == 'layer1':
+                assert block.conv1.conv.num_segments == \
+                       resnet_tsm_50_temporal_pool.num_segments
+            else:
+                assert block.conv1.conv.num_segments == \
+                       resnet_tsm_50_temporal_pool.num_segments // 2
+            assert block.conv1.conv.shift_div == resnet_tsm_50_temporal_pool.shift_div  # noqa: E501
+            assert isinstance(block.conv1.conv.net, nn.Conv2d)
+
+    # resnet_tsm with non-local module
+    non_local_cfg = dict(
+        sub_sample=True,
+        use_scale=False,
+        norm_cfg=dict(type='BN3d', requires_grad=True),
+        mode='embedded_gaussian')
+    non_local = ((0, 0, 0), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 0, 0))
+    resnet_tsm_nonlocal = ResNetTSM(
+        50, non_local=non_local, non_local_cfg=non_local_cfg)
+    resnet_tsm_nonlocal.init_weights()
+    for layer_name in ['layer2', 'layer3']:
+        layer = getattr(resnet_tsm_nonlocal, layer_name)
+        for i, _ in enumerate(layer):
+            if i % 2 == 0:
+                assert isinstance(layer[i], NL3DWrapper)
+
+    resnet_tsm_50_full = ResNetTSM(
+        50,
+        non_local=non_local,
+        non_local_cfg=non_local_cfg,
+        temporal_pool=True)
+    resnet_tsm_50_full.init_weights()
+
+    # TSM forward
+    feat = resnet_tsm_50(imgs)
+    assert feat.shape == torch.Size([8, 2048, 2, 2])
+
+    # TSM with non-local forward
+    feat = resnet_tsm_nonlocal(imgs)
+    assert feat.shape == torch.Size([8, 2048, 2, 2])
+
+    # TSM with temporal pool forward
+    feat = resnet_tsm_50_temporal_pool(imgs)
+    assert feat.shape == torch.Size([4, 2048, 2, 2])
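+    # temporal_pool halves the segment dimension after layer2, so the 8
+    # input segments collapse into a batch of 4 feature maps.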
+
+    # TSM with temporal pool + non-local forward
+    input_shape = (16, 3, 32, 32)
+    imgs = generate_backbone_demo_inputs(input_shape)
+    feat = resnet_tsm_50_full(imgs)
+    assert feat.shape == torch.Size([8, 2048, 1, 1])
+
+
+def test_mobilenetv2_tsm_backbone():
+    """Test mobilenetv2_tsm backbone."""
+    from mmaction.models.backbones.resnet_tsm import TemporalShift
+    from mmaction.models.backbones.mobilenet_v2 import InvertedResidual
+    from mmcv.cnn import ConvModule
+
+    input_shape = (8, 3, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+
+    # mobilenetv2_tsm with widen_factor = 1.0
+    mobilenetv2_tsm = MobileNetV2TSM()
+    mobilenetv2_tsm.init_weights()
+    for cur_module in mobilenetv2_tsm.modules():
+        if isinstance(cur_module, InvertedResidual) and \
+            len(cur_module.conv) == 3 and \
+                cur_module.use_res_connect:
+            assert isinstance(cur_module.conv[0], TemporalShift)
+            assert cur_module.conv[0].num_segments == \
+                mobilenetv2_tsm.num_segments
+            assert cur_module.conv[0].shift_div == mobilenetv2_tsm.shift_div
+            assert isinstance(cur_module.conv[0].net, ConvModule)
+
+    # TSM-MobileNetV2 with widen_factor = 1.0 forward
+    feat = mobilenetv2_tsm(imgs)
+    assert feat.shape == torch.Size([8, 1280, 2, 2])
+
+    # mobilenetv2_tsm with widen_factor = 0.5 forward
+    mobilenetv2_tsm_05 = MobileNetV2TSM(widen_factor=0.5)
+    mobilenetv2_tsm_05.init_weights()
+    feat = mobilenetv2_tsm_05(imgs)
+    assert feat.shape == torch.Size([8, 1280, 2, 2])
+
+    # mobilenetv2_tsm with widen_factor = 1.5 forward
+    mobilenetv2_tsm_15 = MobileNetV2TSM(widen_factor=1.5)
+    mobilenetv2_tsm_15.init_weights()
+    feat = mobilenetv2_tsm_15(imgs)
+    assert feat.shape == torch.Size([8, 1920, 2, 2])
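+    # Note: the final conv appears to keep max(1280, int(1280 * widen_factor))
+    # channels, so widen_factor=0.5 still ends at 1280 while 1.5 gives 1920.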
+
+
+def test_slowfast_backbone():
+    """Test SlowFast backbone."""
+    with pytest.raises(TypeError):
+        # cfg should be a dict
+        ResNet3dSlowFast(None, slow_pathway=['foo', 'bar'])
+    with pytest.raises(TypeError):
+        # pretrained should be a str
+        sf_50 = ResNet3dSlowFast(dict(foo='bar'))
+        sf_50.init_weights()
+    with pytest.raises(KeyError):
+        # pathway type should be implemented
+        ResNet3dSlowFast(None, slow_pathway=dict(type='resnext'))
+
+    # test slowfast with slow inflated
+    sf_50_inflate = ResNet3dSlowFast(
+        None,
+        slow_pathway=dict(
+            type='resnet3d',
+            depth=50,
+            pretrained='torchvision://resnet50',
+            pretrained2d=True,
+            lateral=True,
+            conv1_kernel=(1, 7, 7),
+            dilations=(1, 1, 1, 1),
+            conv1_stride_t=1,
+            pool1_stride_t=1,
+            inflate=(0, 0, 1, 1)))
+    sf_50_inflate.init_weights()
+    sf_50_inflate.train()
+
+    # test slowfast with no lateral connection
+    sf_50_wo_lateral = ResNet3dSlowFast(
+        None,
+        slow_pathway=dict(
+            type='resnet3d',
+            depth=50,
+            pretrained=None,
+            lateral=False,
+            conv1_kernel=(1, 7, 7),
+            dilations=(1, 1, 1, 1),
+            conv1_stride_t=1,
+            pool1_stride_t=1,
+            inflate=(0, 0, 1, 1)))
+    sf_50_wo_lateral.init_weights()
+    sf_50_wo_lateral.train()
+
+    # slowfast w/o lateral connection inference test
+    input_shape = (1, 3, 8, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+    # parrots 3dconv is only implemented on gpu
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            sf_50_wo_lateral = sf_50_wo_lateral.cuda()
+            imgs_gpu = imgs.cuda()
+            feat = sf_50_wo_lateral(imgs_gpu)
+    else:
+        feat = sf_50_wo_lateral(imgs)
+
+    assert isinstance(feat, tuple)
+    assert feat[0].shape == torch.Size([1, 2048, 1, 2, 2])
+    assert feat[1].shape == torch.Size([1, 256, 8, 2, 2])
+
+    # test slowfast with frozen stages config
+    frozen_slow = 3
+    sf_50 = ResNet3dSlowFast(
+        None,
+        slow_pathway=dict(
+            type='resnet3d',
+            depth=50,
+            pretrained=None,
+            pretrained2d=True,
+            lateral=True,
+            conv1_kernel=(1, 7, 7),
+            dilations=(1, 1, 1, 1),
+            conv1_stride_t=1,
+            pool1_stride_t=1,
+            inflate=(0, 0, 1, 1),
+            frozen_stages=frozen_slow))
+    sf_50.init_weights()
+    sf_50.train()
+
+    for stage in range(1, sf_50.slow_path.num_stages):
+        lateral_name = sf_50.slow_path.lateral_connections[stage - 1]
+        conv_lateral = getattr(sf_50.slow_path, lateral_name)
+        for mod in conv_lateral.modules():
+            if isinstance(mod, _BatchNorm):
+                if stage <= frozen_slow:
+                    assert mod.training is False
+                else:
+                    assert mod.training is True
+        for param in conv_lateral.parameters():
+            if stage <= frozen_slow:
+                assert param.requires_grad is False
+            else:
+                assert param.requires_grad is True
+
+    # test slowfast with normal config
+    sf_50 = ResNet3dSlowFast(None)
+    sf_50.init_weights()
+    sf_50.train()
+
+    # slowfast inference test
+    input_shape = (1, 3, 8, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+    # parrots 3dconv is only implemented on gpu
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            sf_50 = sf_50.cuda()
+            imgs_gpu = imgs.cuda()
+            feat = sf_50(imgs_gpu)
+    else:
+        feat = sf_50(imgs)
+
+    assert isinstance(feat, tuple)
+    assert feat[0].shape == torch.Size([1, 2048, 1, 2, 2])
+    assert feat[1].shape == torch.Size([1, 256, 8, 2, 2])
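+    # With the default resample_rate=8, speed_ratio=8 and channel_ratio=8,
+    # the slow path sees 8 // 8 = 1 frame at 2048 channels while the fast
+    # path keeps all 8 frames at 2048 // 8 = 256 channels, matching the two
+    # shapes asserted above.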
+
+
+def test_slowonly_backbone():
+    """Test SlowOnly backbone."""
+    with pytest.raises(AssertionError):
+        # SlowOnly should contain no lateral connection
+        ResNet3dSlowOnly(50, None, lateral=True)
+
+    # test SlowOnly for PoseC3D
+    so_50 = ResNet3dSlowOnly(
+        depth=50,
+        pretrained=None,
+        in_channels=17,
+        base_channels=32,
+        num_stages=3,
+        out_indices=(2, ),
+        stage_blocks=(4, 6, 3),
+        conv1_stride_s=1,
+        pool1_stride_s=1,
+        inflate=(0, 1, 1),
+        spatial_strides=(2, 2, 2),
+        temporal_strides=(1, 1, 2),
+        dilations=(1, 1, 1))
+    so_50.init_weights()
+    so_50.train()
+
+    # test SlowOnly with normal config
+    so_50 = ResNet3dSlowOnly(50, None)
+    so_50.init_weights()
+    so_50.train()
+
+    # SlowOnly inference test
+    input_shape = (1, 3, 8, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+    # parrots 3dconv is only implemented on gpu
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            so_50 = so_50.cuda()
+            imgs_gpu = imgs.cuda()
+            feat = so_50(imgs_gpu)
+    else:
+        feat = so_50(imgs)
+    assert feat.shape == torch.Size([1, 2048, 8, 2, 2])
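+    # conv1_stride_t=1 and pool1_stride_t=1 leave the temporal axis
+    # untouched, so all 8 input frames survive into the output shape.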
+
+
+def test_resnet_csn_backbone():
+    """Test resnet_csn backbone."""
+    with pytest.raises(ValueError):
+        # Bottleneck mode must be "ip" or "ir"
+        ResNet3dCSN(152, None, bottleneck_mode='id')
+
+    input_shape = (2, 3, 6, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+
+    resnet3d_csn_frozen = ResNet3dCSN(
+        152, None, bn_frozen=True, norm_eval=True)
+    resnet3d_csn_frozen.train()
+    for m in resnet3d_csn_frozen.modules():
+        if isinstance(m, _BatchNorm):
+            for param in m.parameters():
+                assert param.requires_grad is False
+
+    # Interaction-preserved channel-separated bottleneck block
+    resnet3d_csn_ip = ResNet3dCSN(152, None, bottleneck_mode='ip')
+    resnet3d_csn_ip.init_weights()
+    resnet3d_csn_ip.train()
+    for i, layer_name in enumerate(resnet3d_csn_ip.res_layers):
+        layers = getattr(resnet3d_csn_ip, layer_name)
+        num_blocks = resnet3d_csn_ip.stage_blocks[i]
+        assert len(layers) == num_blocks
+        for layer in layers:
+            assert isinstance(layer.conv2, nn.Sequential)
+            assert len(layer.conv2) == 2
+            assert layer.conv2[1].groups == layer.planes
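+    # 'ip' keeps a pointwise 1x1x1 conv in front of the depthwise 3x3x3
+    # conv, hence len(layer.conv2) == 2; the 'ir' variant below drops the
+    # pointwise conv and keeps only the depthwise one.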
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            resnet3d_csn_ip = resnet3d_csn_ip.cuda()
+            imgs_gpu = imgs.cuda()
+            feat = resnet3d_csn_ip(imgs_gpu)
+            assert feat.shape == torch.Size([2, 2048, 1, 2, 2])
+    else:
+        feat = resnet3d_csn_ip(imgs)
+        assert feat.shape == torch.Size([2, 2048, 1, 2, 2])
+
+    # Interaction-reduced channel-separated bottleneck block
+    resnet3d_csn_ir = ResNet3dCSN(152, None, bottleneck_mode='ir')
+    resnet3d_csn_ir.init_weights()
+    resnet3d_csn_ir.train()
+    for i, layer_name in enumerate(resnet3d_csn_ir.res_layers):
+        layers = getattr(resnet3d_csn_ir, layer_name)
+        num_blocks = resnet3d_csn_ir.stage_blocks[i]
+        assert len(layers) == num_blocks
+        for layer in layers:
+            assert isinstance(layer.conv2, nn.Sequential)
+            assert len(layer.conv2) == 1
+            assert layer.conv2[0].groups == layer.planes
+    if torch.__version__ == 'parrots':
+        if torch.cuda.is_available():
+            resnet3d_csn_ir = resnet3d_csn_ir.cuda()
+            imgs_gpu = imgs.cuda()
+            feat = resnet3d_csn_ir(imgs_gpu)
+            assert feat.shape == torch.Size([2, 2048, 1, 2, 2])
+    else:
+        feat = resnet3d_csn_ir(imgs)
+        assert feat.shape == torch.Size([2, 2048, 1, 2, 2])
+
+    # Set training status = False
+    resnet3d_csn_ip = ResNet3dCSN(152, None, bottleneck_mode='ip')
+    resnet3d_csn_ip.init_weights()
+    resnet3d_csn_ip.train(False)
+    for module in resnet3d_csn_ip.children():
+        assert module.training is False
+
+
+def test_tanet_backbone():
+    """Test tanet backbone."""
+    with pytest.raises(NotImplementedError):
+        # TA-Blocks are currently only built on Bottleneck blocks
+        tanet_18 = TANet(18, 8)
+        tanet_18.init_weights()
+
+    from mmaction.models.backbones.resnet import Bottleneck
+    from mmaction.models.backbones.tanet import TABlock
+
+    # tanet with depth 50
+    tanet_50 = TANet(50, 8)
+    tanet_50.init_weights()
+
+    for layer_name in tanet_50.res_layers:
+        layer = getattr(tanet_50, layer_name)
+        blocks = list(layer.children())
+        for block in blocks:
+            assert isinstance(block, TABlock)
+            assert isinstance(block.block, Bottleneck)
+            assert block.tam.num_segments == block.num_segments
+            assert block.tam.in_channels == block.block.conv1.out_channels
+
+    input_shape = (8, 3, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+    feat = tanet_50(imgs)
+    assert feat.shape == torch.Size([8, 2048, 2, 2])
+
+    input_shape = (16, 3, 32, 32)
+    imgs = generate_backbone_demo_inputs(input_shape)
+    feat = tanet_50(imgs)
+    assert feat.shape == torch.Size([16, 2048, 1, 1])
+
+
+def test_timesformer_backbone():
+    """Test timesformer backbone."""
+    input_shape = (1, 3, 8, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+
+    # divided_space_time
+    timesformer = TimeSformer(
+        8, 64, 16, embed_dims=768, attention_type='divided_space_time')
+    timesformer.init_weights()
+    from mmaction.models.common import (DividedSpatialAttentionWithNorm,
+                                        DividedTemporalAttentionWithNorm,
+                                        FFNWithNorm)
+    assert isinstance(timesformer.transformer_layers.layers[0].attentions[0],
+                      DividedTemporalAttentionWithNorm)
+    assert isinstance(timesformer.transformer_layers.layers[11].attentions[1],
+                      DividedSpatialAttentionWithNorm)
+    assert isinstance(timesformer.transformer_layers.layers[0].ffns[0],
+                      FFNWithNorm)
+    assert hasattr(timesformer, 'time_embed')
+    assert timesformer.patch_embed.num_patches == 16
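+    # num_patches = (img_size // patch_size) ** 2, here (64 // 16) ** 2 = 16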
+
+    cls_tokens = timesformer(imgs)
+    assert cls_tokens.shape == torch.Size([1, 768])
+
+    # space_only
+    timesformer = TimeSformer(
+        8, 64, 16, embed_dims=512, num_heads=8, attention_type='space_only')
+    timesformer.init_weights()
+
+    assert not hasattr(timesformer, 'time_embed')
+    assert timesformer.patch_embed.num_patches == 16
+
+    cls_tokens = timesformer(imgs)
+    assert cls_tokens.shape == torch.Size([1, 512])
+
+    # joint_space_time
+    input_shape = (1, 3, 2, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+    timesformer = TimeSformer(
+        2,
+        64,
+        8,
+        embed_dims=256,
+        num_heads=8,
+        attention_type='joint_space_time')
+    timesformer.init_weights()
+
+    assert hasattr(timesformer, 'time_embed')
+    assert timesformer.patch_embed.num_patches == 64
+
+    cls_tokens = timesformer(imgs)
+    assert cls_tokens.shape == torch.Size([1, 256])
+
+    with pytest.raises(AssertionError):
+        # unsupported attention type
+        timesformer = TimeSformer(
+            8, 64, 16, attention_type='wrong_attention_type')
+
+    with pytest.raises(AssertionError):
+        # Wrong transformer_layers type
+        timesformer = TimeSformer(8, 64, 16, transformer_layers='wrong_type')
+
+
+def test_c3d_backbone():
+    """Test c3d backbone."""
+    input_shape = (1, 3, 16, 112, 112)
+    imgs = generate_backbone_demo_inputs(input_shape)
+
+    # c3d inference test
+    c3d = C3D()
+    c3d.init_weights()
+    c3d.train()
+    feat = c3d(imgs)
+    assert feat.shape == torch.Size([1, 4096])
+
+    # c3d with bn inference test
+    c3d_bn = C3D(norm_cfg=dict(type='BN3d'))
+    c3d_bn.init_weights()
+    c3d_bn.train()
+    feat = c3d_bn(imgs)
+    assert feat.shape == torch.Size([1, 4096])
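+    # C3D flattens its conv features and passes them through fully
+    # connected layers, so the backbone output is a 4096-d vector rather
+    # than a spatial feature map.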
+
+
+def test_resnet_audio_backbone():
+    """Test ResNetAudio backbone."""
+    input_shape = (1, 1, 16, 16)
+    spec = generate_backbone_demo_inputs(input_shape)
+    # inference
+    audioonly = ResNetAudio(50, None)
+    audioonly.init_weights()
+    audioonly.train()
+    feat = audioonly(spec)
+    assert feat.shape == torch.Size([1, 1024, 2, 2])
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason='requires CUDA support')
+def test_resnet_tin_backbone():
+    """Test resnet_tin backbone."""
+    with pytest.raises(AssertionError):
+        # num_segments should be positive
+        resnet_tin = ResNetTIN(50, num_segments=-1)
+        resnet_tin.init_weights()
+
+    from mmaction.models.backbones.resnet_tin import (CombineNet,
+                                                      TemporalInterlace)
+
+    # resnet_tin with normal config
+    resnet_tin = ResNetTIN(50)
+    resnet_tin.init_weights()
+    for layer_name in resnet_tin.res_layers:
+        layer = getattr(resnet_tin, layer_name)
+        blocks = list(layer.children())
+        for block in blocks:
+            assert isinstance(block.conv1.conv, CombineNet)
+            assert isinstance(block.conv1.conv.net1, TemporalInterlace)
+            assert (
+                block.conv1.conv.net1.num_segments == resnet_tin.num_segments)
+            assert block.conv1.conv.net1.shift_div == resnet_tin.shift_div
+
+    # resnet_tin with partial batchnorm
+    resnet_tin_pbn = ResNetTIN(50, partial_bn=True)
+    resnet_tin_pbn.train()
+    count_bn = 0
+    for m in resnet_tin_pbn.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            count_bn += 1
+            if count_bn >= 2:
+                assert m.training is False
+                assert m.weight.requires_grad is False
+                assert m.bias.requires_grad is False
+            else:
+                assert m.training is True
+                assert m.weight.requires_grad is True
+                assert m.bias.requires_grad is True
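+    # partial_bn freezes every BN layer except the first one: only the
+    # first BatchNorm2d stays in training mode with trainable affine
+    # parameters, as checked above.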
+
+    input_shape = (8, 3, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape).cuda()
+    resnet_tin = resnet_tin.cuda()
+
+    # resnet_tin with normal cfg inference
+    feat = resnet_tin(imgs)
+    assert feat.shape == torch.Size([8, 2048, 2, 2])
+
+
+def test_stgcn_backbone():
+    """Test STGCN backbone."""
+    # test coco layout, spatial strategy
+    input_shape = (1, 3, 300, 17, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='coco', strategy='spatial'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 17])
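+    # The person axis (M=2) is folded into the batch, giving N * M = 2;
+    # the temporal axis is downsampled 300 -> 75 by the strided st-gcn
+    # blocks, and the joint axis (V=17 for coco) is preserved.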
+
+    # test openpose layout, spatial strategy
+    input_shape = (1, 3, 300, 18, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='openpose', strategy='spatial'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 18])
+
+    # test ntu-rgb+d layout, spatial strategy
+    input_shape = (1, 3, 300, 25, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='ntu-rgb+d', strategy='spatial'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 25])
+
+    # test ntu_edge layout, spatial strategy
+    input_shape = (1, 3, 300, 24, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='ntu_edge', strategy='spatial'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 24])
+
+    # test coco layout, uniform strategy
+    input_shape = (1, 3, 300, 17, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='coco', strategy='uniform'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 17])
+
+    # test openpose layout, uniform strategy
+    input_shape = (1, 3, 300, 18, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='openpose', strategy='uniform'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 18])
+
+    # test ntu-rgb+d layout, uniform strategy
+    input_shape = (1, 3, 300, 25, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='ntu-rgb+d', strategy='uniform'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 25])
+
+    # test ntu_edge layout, uniform strategy
+    input_shape = (1, 3, 300, 24, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='ntu_edge', strategy='uniform'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 24])
+
+    # test coco layout, distance strategy
+    input_shape = (1, 3, 300, 17, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='coco', strategy='distance'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 17])
+
+    # test openpose layout, distance strategy
+    input_shape = (1, 3, 300, 18, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='openpose', strategy='distance'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 18])
+
+    # test ntu-rgb+d layout, distance strategy
+    input_shape = (1, 3, 300, 25, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='ntu-rgb+d', strategy='distance'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 25])
+
+    # test ntu_edge layout, distance strategy
+    input_shape = (1, 3, 300, 24, 2)
+    skeletons = generate_backbone_demo_inputs(input_shape)
+
+    stgcn = STGCN(
+        in_channels=3,
+        edge_importance_weighting=True,
+        graph_cfg=dict(layout='ntu_edge', strategy='distance'))
+    stgcn.init_weights()
+    stgcn.train()
+    feat = stgcn(skeletons)
+    assert feat.shape == torch.Size([2, 256, 75, 24])