[c1b1c5]: / ViTPose / tests / test_backbones / test_litehrnet.py

Download this file

144 lines (120 with data), 4.0 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmpose.models.backbones import LiteHRNet
from mmpose.models.backbones.litehrnet import LiteHRModule
from mmpose.models.backbones.resnet import Bottleneck
def is_norm(modules):
    """Return whether ``modules`` is a batch-normalisation layer.

    Any instance of a ``_BatchNorm`` subclass counts; everything else
    (including ``None`` and non-module objects) yields ``False``.
    """
    return isinstance(modules, _BatchNorm)
def all_zeros(modules):
    """Check whether a module's ``weight`` (and ``bias``, if set) are all zero.

    Args:
        modules (torch.nn.Module): Module exposing a ``weight`` parameter and
            optionally a ``bias`` parameter.

    Returns:
        bool: True if ``weight`` is all zeros and ``bias`` is absent, ``None``
        or all zeros; False otherwise.
    """
    weight_zero = torch.equal(modules.weight.data,
                              torch.zeros_like(modules.weight.data))
    # Layers built with ``bias=False`` still expose a ``bias`` attribute whose
    # value is ``None``; ``hasattr`` alone would then crash on ``.data``.
    bias = getattr(modules, 'bias', None)
    if bias is None:
        bias_zero = True
    else:
        bias_zero = torch.equal(bias.data, torch.zeros_like(bias.data))
    return weight_zero and bias_zero
def test_litehrmodule():
    """Check single-branch LiteHRModule forward shapes per module type.

    Covers the 'LITE' and 'NAIVE' module types and checks that an unknown
    module type is rejected with ``ValueError`` at construction time.
    """

    def _build(kind):
        # A fresh ``in_channels`` list per module, exactly as each original
        # constructor call received its own list literal.
        return LiteHRModule(
            num_branches=1,
            num_blocks=1,
            in_channels=[40],
            reduce_ratio=8,
            module_type=kind)

    expected_shape = torch.Size([2, 40, 56, 56])

    # 'LITE': this test feeds and reads the branch data one list level
    # deeper than in the 'NAIVE' case below.
    lite = _build('LITE')
    feat = torch.randn(2, 40, 56, 56)
    out = lite([[feat]])
    assert out[0][0].shape == expected_shape

    # 'NAIVE': a flat list of per-branch tensors in and out.
    naive = _build('NAIVE')
    feat = torch.randn(2, 40, 56, 56)
    out = naive([feat])
    assert out[0].shape == expected_shape

    # An unrecognised module type must fail during construction.
    with pytest.raises(ValueError):
        _build('none')
def test_litehrnet_backbone():
    """Check LiteHRNet forward output for both 'LITE' and 'NAIVE' stages.

    The two configurations differ only in ``module_type``. They are therefore
    generated from one template rather than maintained as two copy-pasted
    dicts. For each configuration the model is checked straight after
    construction and again after ``init_weights()`` in training mode.
    """

    def _make_extra(module_type):
        # Build a new dict per call so the configurations never share state.
        return dict(
            stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
            num_stages=3,
            stages_spec=dict(
                num_modules=(2, 4, 2),
                num_branches=(2, 3, 4),
                num_blocks=(2, 2, 2),
                module_type=(module_type, ) * 3,
                with_fuse=(True, True, True),
                reduce_ratios=(8, 8, 8),
                num_channels=(
                    (40, 80),
                    (40, 80, 160),
                    (40, 80, 160, 320),
                )),
            with_head=True)

    expected_shape = torch.Size([2, 40, 56, 56])

    for module_type in ('LITE', 'NAIVE'):
        # Both models of one iteration share this dict, as in the original.
        extra = _make_extra(module_type)

        # Forward pass straight after construction.
        model = LiteHRNet(extra, in_channels=3)
        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        assert len(feat) == 1, module_type
        assert feat[0].shape == expected_shape, module_type

        # Forward pass after explicit weight initialisation, in train mode.
        model = LiteHRNet(extra, in_channels=3)
        model.init_weights()
        # Zero-initialised residual check carried over from the HRNet test.
        # NOTE(review): if LiteHRNet contains no ``Bottleneck`` modules this
        # loop asserts nothing — confirm against the backbone implementation.
        for m in model.modules():
            if isinstance(m, Bottleneck):
                assert all_zeros(m.norm3), module_type
        model.train()
        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        assert len(feat) == 1, module_type
        assert feat[0].shape == expected_shape, module_type