[f1e01c]: / tests / test_models / test_backbones / test_hrnet.py

Download this file

145 lines (123 with data), 4.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.models.backbones.hrnet import HRModule, HRNet
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
@pytest.mark.parametrize('block', [BasicBlock, Bottleneck])
def test_hrmodule(block):
    """Check HRModule forward outputs in multi-scale and single-scale modes.

    A two-branch HRModule is built with ``block`` as its residual unit. The
    inputs are a 64x64 and a 32x32 feature map whose channel counts are the
    base widths scaled by ``block.expansion``.
    """
    num_channels = (32, 64)
    in_channels = [c * block.expansion for c in num_channels]

    def build_and_forward(**kwargs):
        # Build a fresh module and fresh inputs for every case so the checks
        # stay independent; extra kwargs are forwarded to HRModule.
        hrmodule = HRModule(
            num_branches=2,
            blocks=block,
            in_channels=in_channels,
            num_blocks=(4, 4),
            num_channels=num_channels,
            **kwargs,
        )
        feats = [
            torch.randn(1, in_channels[0], 64, 64),
            torch.randn(1, in_channels[1], 32, 32)
        ]
        return hrmodule(feats)

    # Test multiscale forward (constructor default): one output per branch,
    # each keeping its input shape.
    feats = build_and_forward()
    assert len(feats) == 2
    assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64])
    assert feats[1].shape == torch.Size([1, in_channels[1], 32, 32])

    # Test single scale forward: only the 64x64 branch output is returned.
    feats = build_and_forward(multiscale_output=False)
    assert len(feats) == 1
    assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64])
def _assert_frozen(module):
    """Assert that ``module`` is fully frozen.

    Every ``_BatchNorm`` submodule (including ``module`` itself) must be in
    eval mode, and no parameter may require gradients.
    """
    for mod in module.modules():
        if isinstance(mod, _BatchNorm):
            assert mod.training is False
    for param in module.parameters():
        assert param.requires_grad is False


def test_hrnet_backbone():
    """Check HRNet config validation, forward shapes and stage freezing."""
    # only have 3 stages
    extra = dict(
        stage1=dict(
            num_modules=1,
            num_branches=1,
            block='BOTTLENECK',
            num_blocks=(4, ),
            num_channels=(64, )),
        stage2=dict(
            num_modules=1,
            num_branches=2,
            block='BASIC',
            num_blocks=(4, 4),
            num_channels=(32, 64)),
        stage3=dict(
            num_modules=4,
            num_branches=3,
            block='BASIC',
            num_blocks=(4, 4, 4),
            num_channels=(32, 64, 128)))
    with pytest.raises(AssertionError):
        # HRNet currently only supports 4 stages
        HRNet(extra=extra)

    extra['stage4'] = dict(
        num_modules=3,
        num_branches=3,  # should be 4
        block='BASIC',
        num_blocks=(4, 4, 4, 4),
        num_channels=(32, 64, 128, 256))
    with pytest.raises(AssertionError):
        # len(num_blocks) should equal num_branches
        HRNet(extra=extra)

    extra['stage4']['num_branches'] = 4

    # Test hrnetv2p_w32: each branch carries its stage4 width, and the
    # resolution halves per branch starting from 1/4 of the 64x64 input.
    model = HRNet(extra=extra)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 64, 64)
    feats = model(imgs)
    assert len(feats) == 4
    assert feats[0].shape == torch.Size([1, 32, 16, 16])
    assert feats[1].shape == torch.Size([1, 64, 8, 8])
    assert feats[2].shape == torch.Size([1, 128, 4, 4])
    assert feats[3].shape == torch.Size([1, 256, 2, 2])

    # Test single scale output
    model = HRNet(extra=extra, multiscale_output=False)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 64, 64)
    feats = model(imgs)
    assert len(feats) == 1
    assert feats[0].shape == torch.Size([1, 32, 16, 16])

    # Test HRNet with two stages frozen; model.train() must not re-enable them
    frozen_stages = 2
    model = HRNet(extra=extra, frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    assert model.norm1.training is False
    for layer in (model.conv1, model.norm1):
        for param in layer.parameters():
            assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        # Stage 1 is stored as ``layer1``; later stages as ``stage{i}``.
        _assert_frozen(getattr(model, 'layer1' if i == 1 else f'stage{i}'))
        # Stage 4 has no outgoing transition; skip it explicitly instead of
        # re-checking the previous stage's transition.
        if i != 4:
            _assert_frozen(getattr(model, f'transition{i}'))