[c1b1c5]: / ViTPose / tests / test_backbones / test_hourglass.py

Download this file

78 lines (61 with data), 2.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmpose.models import HourglassAENet, HourglassNet
def test_hourglass_backbone():
    """Check HourglassNet argument validation and output feature shapes.

    Invalid constructor arguments are expected to raise ``AssertionError``.
    For valid configurations, a forward pass on a 1x3x256x256 input is
    expected to return one ``[1, 256, 64, 64]`` feature map per stack.
    """
    with pytest.raises(AssertionError):
        # HourglassNet's num_stacks should be larger than 0
        HourglassNet(num_stacks=0)

    with pytest.raises(AssertionError):
        # len(stage_channels) should equal len(stage_blocks)
        HourglassNet(
            stage_channels=[256, 256, 384, 384, 384],
            stage_blocks=[2, 2, 2, 2, 2, 4])

    with pytest.raises(AssertionError):
        # len(stage_channels) should be larger than downsample_times
        HourglassNet(
            downsample_times=5,
            stage_channels=[256, 256, 384, 384, 384],
            stage_blocks=[2, 2, 2, 2, 2])

    # num_stacks=1 is HourglassNet-52, num_stacks=2 is HourglassNet-104
    for num_stacks in (1, 2):
        model = HourglassNet(num_stacks=num_stacks)
        model.init_weights()
        model.train()

        imgs = torch.randn(1, 3, 256, 256)
        feat = model(imgs)

        # One output feature map per stack, all with the same shape.
        assert len(feat) == num_stacks
        for stack_feat in feat:
            assert stack_feat.shape == torch.Size([1, 256, 64, 64])
def test_hourglass_ae_backbone():
    """Check HourglassAENet argument validation and output feature shapes.

    Invalid constructor arguments are expected to raise ``AssertionError``.
    For valid configurations, a forward pass on a 1x3x256x256 input is
    expected to return one ``[1, 34, 64, 64]`` feature map per stack.
    """
    with pytest.raises(AssertionError):
        # HourglassAENet's num_stacks should be larger than 0
        HourglassAENet(num_stacks=0)

    with pytest.raises(AssertionError):
        # len(stage_channels) should be larger than downsample_times
        HourglassAENet(
            downsample_times=5, stage_channels=[256, 256, 384, 384, 384])

    # Exercise both the single-stack and the two-stack configuration.
    for num_stacks in (1, 2):
        model = HourglassAENet(num_stacks=num_stacks)
        model.init_weights()
        model.train()

        imgs = torch.randn(1, 3, 256, 256)
        feat = model(imgs)

        # One output feature map per stack, all with the same shape.
        assert len(feat) == num_stacks
        for stack_feat in feat:
            assert stack_feat.shape == torch.Size([1, 34, 64, 64])