# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.backbones import STDCContextPathNet
from mmseg.models.backbones.stdc import (AttentionRefinementModule,
FeatureFusionModule, STDCModule,
STDCNet)
def test_stdc_context_path_net():
    """Forward-pass smoke tests for ``STDCContextPathNet``.

    Checks that a standard input produces the expected 4-tuple of feature
    maps (one for the decode head, three for auxiliary heads), and that an
    odd, non-divisible input shape still runs without error.
    """
    # Test STDCContextPathNet Standard Forward
    model = STDCContextPathNet(
        backbone_cfg=dict(
            type='STDCNet',
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU'),
            with_final_conv=True),
        last_in_channels=(1024, 512),
        out_channels=128,
        ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4))
    model.init_weights()
    model.train()
    batch_size = 2
    imgs = torch.randn(batch_size, 3, 256, 512)
    feat = model(imgs)

    assert len(feat) == 4
    # output for segment Head
    assert feat[0].shape == torch.Size([batch_size, 256, 32, 64])
    # for auxiliary head 1
    assert feat[1].shape == torch.Size([batch_size, 128, 16, 32])
    # for auxiliary head 2
    assert feat[2].shape == torch.Size([batch_size, 128, 32, 64])
    # for auxiliary head 3
    assert feat[3].shape == torch.Size([batch_size, 256, 32, 64])

    # Test input with rare shape (odd H/W not divisible by the stride),
    # using the 'add' bottleneck and no final conv to cover that branch.
    batch_size = 2
    imgs = torch.randn(batch_size, 3, 527, 279)
    model = STDCContextPathNet(
        backbone_cfg=dict(
            type='STDCNet',
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='add',
            num_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU'),
            with_final_conv=False),
        last_in_channels=(1024, 512),
        out_channels=128,
        ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4))
    model.init_weights()
    model.train()
    feat = model(imgs)
    assert len(feat) == 4
def test_stdcnet():
    """Constructor-validation tests for ``STDCNet``.

    Each invalid configuration must raise ``AssertionError``:
    an unknown ``stdc_type``, an unknown ``bottleneck_type``, and a
    ``channels`` tuple of the wrong length.
    """
    with pytest.raises(AssertionError):
        # STDC backbone constraints: only STDCNet1/STDCNet2 are valid.
        STDCNet(
            stdc_type='STDCNet3',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU'),
            with_final_conv=False)

    with pytest.raises(AssertionError):
        # STDC bottleneck type constraints: only 'cat' or 'add' are valid.
        STDCNet(
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='dog',
            num_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU'),
            with_final_conv=False)

    with pytest.raises(AssertionError):
        # STDC channels length constraints: exactly 5 stages are expected.
        STDCNet(
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(16, 32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU'),
            with_final_conv=False)
def test_feature_fusion_module():
    """Check FFM layer wiring and output shape.

    Verifies the channel counts of the fusion conv and the attention
    branch, then runs a forward pass on two same-shaped inputs.
    """
    x_ffm = FeatureFusionModule(in_channels=64, out_channels=32)
    # Fusion conv consumes the two concatenated inputs (32 + 32 = 64).
    assert x_ffm.conv0.in_channels == 64
    assert x_ffm.attention[1].in_channels == 32
    # Attention bottleneck squeezes 32 -> 8 -> 32 channels.
    assert x_ffm.attention[2].in_channels == 8
    assert x_ffm.attention[2].out_channels == 32

    x1 = torch.randn(2, 32, 32, 64)
    x2 = torch.randn(2, 32, 32, 64)
    x_out = x_ffm(x1, x2)
    # Fusion keeps the spatial size and emits `out_channels` channels.
    assert x_out.shape == torch.Size([2, 32, 32, 64])
def test_attention_refinement_module():
    """Check ARM layer wiring and output shape.

    Verifies the input/output channel counts of the refinement convs and
    that the forward pass preserves spatial size while reducing channels.
    """
    x_arm = AttentionRefinementModule(128, 32)
    assert x_arm.conv_layer.in_channels == 128
    assert x_arm.atten_conv_layer[1].conv.out_channels == 32

    x = torch.randn(2, 128, 32, 64)
    x_out = x_arm(x)
    # 128 -> 32 channels, spatial size unchanged.
    assert x_out.shape == torch.Size([2, 32, 32, 64])
def test_stdc_module():
    """Check STDCModule layer wiring and output shape.

    With ``stride=4`` the last conv narrows to ``out_channels // 8``
    channels; the forward pass must keep the input's full shape since
    in_channels == out_channels and the spatial size is preserved here.
    """
    x_stdc = STDCModule(in_channels=32, out_channels=32, stride=4)
    assert x_stdc.layers[0].conv.in_channels == 32
    # Final split carries out_channels / 8 = 4 channels.
    assert x_stdc.layers[3].conv.out_channels == 4

    x = torch.randn(2, 32, 32, 64)
    x_out = x_stdc(x)
    assert x_out.shape == torch.Size([2, 32, 32, 64])