[f1e01c]: / tests / test_models / test_backbones / test_erfnet.py

Download this file

147 lines (133 with data), 5.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.backbones import ERFNet
from mmseg.models.backbones.erfnet import (DownsamplerBlock, NonBottleneck1d,
UpsamplerBlock)
def test_erfnet_backbone():
    """Check ERFNet's forward outputs and its constructor's config validation.

    A valid model is first run on a 256x512 batch, where the first output
    must have shape ``(batch_size, 16, 128, 256)``. It is then run on an
    odd-sized 527x279 batch, where only the batch dimension is checked.
    Each negative case then changes exactly one field of the valid
    baseline config and expects ``ERFNet`` to raise ``AssertionError``.
    """
    # Valid baseline configuration; every negative case below overrides a
    # single field of it so that each case targets one consistency rule.
    base_cfg = dict(
        in_channels=3,
        enc_downsample_channels=(16, 64, 128),
        enc_stage_non_bottlenecks=(5, 8),
        enc_non_bottleneck_dilations=(2, 4, 8, 16),
        enc_non_bottleneck_channels=(64, 128),
        dec_upsample_channels=(64, 16),
        dec_stages_non_bottleneck=(2, 2),
        dec_non_bottleneck_channels=(64, 16),
        dropout_ratio=0.1,
    )

    # Test ERFNet Standard Forward.
    model = ERFNet(**base_cfg)
    model.init_weights()
    model.train()
    batch_size = 2
    imgs = torch.randn(batch_size, 3, 256, 512)
    output = model(imgs)
    # output for segment Head
    assert output[0].shape == torch.Size([batch_size, 16, 128, 256])

    # Test input with rare shape (only the batch dimension is checked).
    imgs = torch.randn(batch_size, 3, 527, 279)
    output = model(imgs)
    assert len(output[0]) == batch_size

    invalid_overrides = [
        # Number of encoder downsample blocks vs. decoder upsample blocks.
        dict(dec_upsample_channels=(128, 64, 16)),
        # Number of encoder downsample blocks vs. encoder Non-bottleneck
        # stages.
        dict(enc_stage_non_bottlenecks=(5, 8, 10)),
        # Number of encoder downsample blocks vs. channels of encoder
        # Non-bottleneck blocks.
        dict(enc_non_bottleneck_channels=(64, 128, 256)),
        # NOTE(review): assumes ERFNet requires the last encoder stage's
        # Non-bottleneck count to be a multiple of the number of dilations
        # (7 is not a multiple of 4). The previous value (5, 8, 3) only
        # re-tripped the stage-count check above. Confirm against
        # ERFNet.__init__.
        dict(enc_stage_non_bottlenecks=(5, 7)),
        # Number of decoder upsample blocks vs. decoder Non-bottleneck stages.
        dict(dec_stages_non_bottleneck=(2, 2, 3)),
        # Number of decoder Non-bottleneck stages vs. their channels.
        dict(dec_non_bottleneck_channels=(64, 16, 8)),
    ]
    for overrides in invalid_overrides:
        with pytest.raises(AssertionError):
            ERFNet(**dict(base_cfg, **overrides))


def test_erfnet_downsampler_block():
    """Check the layer configuration of a 16 -> 64 ``DownsamplerBlock``."""
    in_ch, out_ch = 16, 64
    block = DownsamplerBlock(in_ch, out_ch)

    # The conv branch only produces ``out_ch - in_ch`` (= 48) channels;
    # presumably the remainder comes from the pooled input (confirm in
    # DownsamplerBlock.forward).
    assert block.conv.in_channels == in_ch
    assert block.conv.out_channels == out_ch - in_ch

    # Batch norm is sized for the full ``out_ch``-wide output.
    assert len(block.bn.weight) == out_ch

    # Pooling uses kernel size 2 and stride 2.
    assert block.pool.kernel_size == 2
    assert block.pool.stride == 2


def test_erfnet_non_bottleneck_1d():
    """Check the conv channel widths and dropout rate of ``NonBottleneck1d``."""
    channels = 16
    # Positional args presumably are (channels, drop_rate, dilation);
    # TODO confirm against the NonBottleneck1d signature.
    nb1d = NonBottleneck1d(channels, 0, 1)
    layers = nb1d.convs_layers

    # Entries 0, 2, 5 and 7 of ``convs_layers`` are expected to keep the
    # channel count unchanged (input width == output width == ``channels``).
    for idx in (0, 2, 5, 7):
        assert layers[idx].in_channels == channels
        assert layers[idx].out_channels == channels

    # Entry 9 exposes a dropout probability ``p`` equal to the rate passed in.
    assert layers[9].p == 0


def test_erfnet_upsampler_block():
    """Check the layer configuration of a 64 -> 16 ``UpsamplerBlock``."""
    in_ch, out_ch = 64, 16
    block = UpsamplerBlock(in_ch, out_ch)

    # The conv maps in_ch -> out_ch, and batch norm is sized for out_ch.
    assert block.conv.in_channels == in_ch
    assert block.conv.out_channels == out_ch
    assert len(block.bn.weight) == out_ch