[c1b1c5]: / ViTPose / tests / test_backbones / test_rsn.py

Download this file

36 lines (30 with data), 1.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmpose.models import RSN
def test_rsn_backbone():
    """Check RSN's constructor validation and the shapes of its outputs.

    The first part asserts that invalid configurations are rejected with
    ``AssertionError``. The second part builds a 2-stage, 2-unit RSN, runs a
    single random 511x511 image through it in training mode and checks that
    every stage returns one 256-channel feature map per unit, at the
    resolutions listed in ``expected_shapes``.
    """
    # Each invalid configuration must be rejected by RSN's argument checks.
    with pytest.raises(AssertionError):
        # RSN's num_stages should be larger than 0
        RSN(num_stages=0)
    with pytest.raises(AssertionError):
        # RSN's num_steps should be larger than 1
        RSN(num_steps=1)
    with pytest.raises(AssertionError):
        # RSN's num_units should be larger than 1
        RSN(num_units=1)
    with pytest.raises(AssertionError):
        # len(num_blocks) should equal num_units
        RSN(num_units=2, num_blocks=[2, 2, 2])

    # Test RSN's outputs. The stage/unit counts are named once so the
    # structural assertions below stay in sync with the constructor call.
    num_stages = 2
    num_units = 2
    model = RSN(
        num_stages=num_stages, num_units=num_units, num_blocks=[2] * num_units)
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 511, 511)
    feat = model(imgs)

    # Expected shape of each unit's feature map, in the order the units
    # appear within a stage's output (identical for every stage).
    expected_shapes = [
        torch.Size([1, 256, 64, 64]),
        torch.Size([1, 256, 128, 128]),
    ]
    assert len(feat) == num_stages
    for stage_feat in feat:
        # Check the length first so zip() below cannot silently truncate.
        assert len(stage_feat) == num_units
        for unit_feat, expected in zip(stage_feat, expected_shapes):
            assert unit_feat.shape == expected