--- a +++ b/tests/test_models/test_neck.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import pytest +import torch + +from mmaction.models import TPN +from .base import generate_backbone_demo_inputs + + +def test_tpn(): + """Test TPN backbone.""" + + tpn_cfg = dict( + in_channels=(1024, 2048), + out_channels=1024, + spatial_modulation_cfg=dict( + in_channels=(1024, 2048), out_channels=2048), + temporal_modulation_cfg=dict(downsample_scales=(8, 8)), + upsample_cfg=dict(scale_factor=(1, 1, 1)), + downsample_cfg=dict(downsample_scale=(1, 1, 1)), + level_fusion_cfg=dict( + in_channels=(1024, 1024), + mid_channels=(1024, 1024), + out_channels=2048, + downsample_scales=((1, 1, 1), (1, 1, 1))), + aux_head_cfg=dict(out_channels=400, loss_weight=0.5)) + + with pytest.raises(AssertionError): + tpn_cfg_ = copy.deepcopy(tpn_cfg) + tpn_cfg_['in_channels'] = list(tpn_cfg_['in_channels']) + TPN(**tpn_cfg_) + + with pytest.raises(AssertionError): + tpn_cfg_ = copy.deepcopy(tpn_cfg) + tpn_cfg_['out_channels'] = float(tpn_cfg_['out_channels']) + TPN(**tpn_cfg_) + + with pytest.raises(AssertionError): + tpn_cfg_ = copy.deepcopy(tpn_cfg) + tpn_cfg_['downsample_cfg']['downsample_position'] = 'unsupport' + TPN(**tpn_cfg_) + + for k in tpn_cfg: + if not k.endswith('_cfg'): + continue + tpn_cfg_ = copy.deepcopy(tpn_cfg) + tpn_cfg_[k] = list() + with pytest.raises(AssertionError): + TPN(**tpn_cfg_) + + with pytest.raises(ValueError): + tpn_cfg_ = copy.deepcopy(tpn_cfg) + tpn_cfg_['flow_type'] = 'unsupport' + TPN(**tpn_cfg_) + + target_shape = (32, 1) + target = generate_backbone_demo_inputs(target_shape).long().squeeze() + x0_shape = (32, 1024, 1, 4, 4) + x1_shape = (32, 2048, 1, 2, 2) + x0 = generate_backbone_demo_inputs(x0_shape) + x1 = generate_backbone_demo_inputs(x1_shape) + x = [x0, x1] + + # ResNetTPN with 'cascade' flow_type + tpn_cfg_ = copy.deepcopy(tpn_cfg) + tpn_cascade = TPN(**tpn_cfg_) + feat, loss_aux = tpn_cascade(x, target) + assert feat.shape == torch.Size([32, 2048, 1, 2, 2]) + assert len(loss_aux) == 1 + + # ResNetTPN with 'parallel' flow_type + tpn_cfg_ = copy.deepcopy(tpn_cfg) + tpn_parallel = TPN(flow_type='parallel', **tpn_cfg_) + feat, loss_aux = tpn_parallel(x, target) + assert feat.shape == torch.Size([32, 2048, 1, 2, 2]) + assert len(loss_aux) == 1 + + # ResNetTPN with 'cascade' flow_type and target is None + feat, loss_aux = tpn_cascade(x, None) + assert feat.shape == torch.Size([32, 2048, 1, 2, 2]) + assert len(loss_aux) == 0 + + # ResNetTPN with 'parallel' flow_type and target is None + feat, loss_aux = tpn_parallel(x, None) + assert feat.shape == torch.Size([32, 2048, 1, 2, 2]) + assert len(loss_aux) == 0