tests/test_models/test_backbones/test_blocks.py

# Copyright (c) OpenMMLab. All rights reserved.
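"""Unit tests for the building blocks in ``mmseg.models.utils``:
``make_divisible``, ``InvertedResidual``, ``InvertedResidualV3`` and ``SELayer``."""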
import mmcv
import pytest
import torch

from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer,
                                make_divisible)


def test_make_divisible():
    # test with min_value = None
    assert make_divisible(10, 4) == 12
    assert make_divisible(9, 4) == 12
    assert make_divisible(1, 4) == 4

    # test with min_value = 8
    assert make_divisible(10, 4, 8) == 12
    assert make_divisible(9, 4, 8) == 12
    assert make_divisible(1, 4, 8) == 8


def test_inv_residual():
    with pytest.raises(AssertionError):
        # test stride assertion.
        InvertedResidual(32, 32, 3, 4)

    # test default config with res connection.
    # set expand_ratio = 4, stride = 1 and inp=oup.
    inv_module = InvertedResidual(32, 32, 1, 4)
    assert inv_module.use_res_connect
    assert inv_module.conv[0].kernel_size == (1, 1)
    assert inv_module.conv[0].padding == 0
    assert inv_module.conv[1].kernel_size == (3, 3)
    assert inv_module.conv[1].padding == 1
    assert inv_module.conv[0].with_norm
    assert inv_module.conv[1].with_norm
    x = torch.rand(1, 32, 64, 64)
    output = inv_module(x)
    assert output.shape == (1, 32, 64, 64)

    # test inv_residual module without res connection.
    # set expand_ratio = 4, stride = 2.
    inv_module = InvertedResidual(32, 32, 2, 4)
    assert not inv_module.use_res_connect
    assert inv_module.conv[0].kernel_size == (1, 1)
    x = torch.rand(1, 32, 64, 64)
    output = inv_module(x)
    assert output.shape == (1, 32, 32, 32)

    # test expand_ratio == 1
    inv_module = InvertedResidual(32, 32, 1, 1)
    assert inv_module.conv[0].kernel_size == (3, 3)
    x = torch.rand(1, 32, 64, 64)
    output = inv_module(x)
    assert output.shape == (1, 32, 64, 64)

    # test with checkpoint forward
    inv_module = InvertedResidual(32, 32, 1, 1, with_cp=True)
    assert inv_module.with_cp
    x = torch.rand(1, 32, 64, 64, requires_grad=True)
    output = inv_module(x)
    assert output.shape == (1, 32, 64, 64)


def test_inv_residualv3():
    with pytest.raises(AssertionError):
        # test stride assertion.
        InvertedResidualV3(32, 32, 16, stride=3)

    with pytest.raises(AssertionError):
        # test assertion.
        InvertedResidualV3(32, 32, 16, with_expand_conv=False)

    # test with se_cfg=None, with_expand_conv=False
    inv_module = InvertedResidualV3(32, 32, 32, with_expand_conv=False)
    assert inv_module.with_res_shortcut is True
    assert inv_module.with_se is False
    assert inv_module.with_expand_conv is False
    assert not hasattr(inv_module, 'expand_conv')
    assert isinstance(inv_module.depthwise_conv.conv, torch.nn.Conv2d)
    assert inv_module.depthwise_conv.conv.kernel_size == (3, 3)
    assert inv_module.depthwise_conv.conv.stride == (1, 1)
    assert inv_module.depthwise_conv.conv.padding == (1, 1)
    assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
    assert isinstance(inv_module.depthwise_conv.activate, torch.nn.ReLU)
    assert inv_module.linear_conv.conv.kernel_size == (1, 1)
    assert inv_module.linear_conv.conv.stride == (1, 1)
    assert inv_module.linear_conv.conv.padding == (0, 0)
    assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
    x = torch.rand(1, 32, 64, 64)
    output = inv_module(x)
    assert output.shape == (1, 32, 64, 64)

    # test with se_cfg and with_expand_conv
    se_cfg = dict(
        channels=16,
        ratio=4,
        act_cfg=(dict(type='ReLU'),
                 dict(type='HSigmoid', bias=3.0, divisor=6.0)))
    act_cfg = dict(type='HSwish')
    inv_module = InvertedResidualV3(
        32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg)
    assert inv_module.with_res_shortcut is False
    assert inv_module.with_se is True
    assert inv_module.with_expand_conv is True
    assert inv_module.expand_conv.conv.kernel_size == (1, 1)
    assert inv_module.expand_conv.conv.stride == (1, 1)
    assert inv_module.expand_conv.conv.padding == (0, 0)
    assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
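    # with stride=2 the depthwise conv is built as Conv2dAdaptivePadding,
    # which computes 'same' padding on the fly, so its static padding
    # stays (0, 0) even though the kernel is 3x3.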
    assert isinstance(inv_module.depthwise_conv.conv,
                      mmcv.cnn.bricks.Conv2dAdaptivePadding)
    assert inv_module.depthwise_conv.conv.kernel_size == (3, 3)
    assert inv_module.depthwise_conv.conv.stride == (2, 2)
    assert inv_module.depthwise_conv.conv.padding == (0, 0)
    assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
    assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
    assert inv_module.linear_conv.conv.kernel_size == (1, 1)
    assert inv_module.linear_conv.conv.stride == (1, 1)
    assert inv_module.linear_conv.conv.padding == (0, 0)
    assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
    x = torch.rand(1, 32, 64, 64)
    output = inv_module(x)
    assert output.shape == (1, 40, 32, 32)

    # test with checkpoint forward
    inv_module = InvertedResidualV3(
        32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg, with_cp=True)
    assert inv_module.with_cp
    x = torch.randn(2, 32, 64, 64, requires_grad=True)
    output = inv_module(x)
    assert output.shape == (2, 40, 32, 32)


def test_se_layer():
    with pytest.raises(AssertionError):
        # test act_cfg assertion.
        SELayer(32, act_cfg=(dict(type='ReLU'), ))

    # test config with channels = 16.
    se_layer = SELayer(16)
    assert se_layer.conv1.conv.kernel_size == (1, 1)
    assert se_layer.conv1.conv.stride == (1, 1)
    assert se_layer.conv1.conv.padding == (0, 0)
    assert isinstance(se_layer.conv1.activate, torch.nn.ReLU)
    assert se_layer.conv2.conv.kernel_size == (1, 1)
    assert se_layer.conv2.conv.stride == (1, 1)
    assert se_layer.conv2.conv.padding == (0, 0)
    assert isinstance(se_layer.conv2.activate, mmcv.cnn.HSigmoid)
    x = torch.rand(1, 16, 64, 64)
    output = se_layer(x)
    assert output.shape == (1, 16, 64, 64)

    # test config with channels = 16, act_cfg = dict(type='ReLU').
    se_layer = SELayer(16, act_cfg=dict(type='ReLU'))
    assert se_layer.conv1.conv.kernel_size == (1, 1)
    assert se_layer.conv1.conv.stride == (1, 1)
    assert se_layer.conv1.conv.padding == (0, 0)
    assert isinstance(se_layer.conv1.activate, torch.nn.ReLU)
    assert se_layer.conv2.conv.kernel_size == (1, 1)
    assert se_layer.conv2.conv.stride == (1, 1)
    assert se_layer.conv2.conv.padding == (0, 0)
    assert isinstance(se_layer.conv2.activate, torch.nn.ReLU)
    x = torch.rand(1, 16, 64, 64)
    output = se_layer(x)
    assert output.shape == (1, 16, 64, 64)