[36ab12]: / ViTPose / tests / test_models / test_mesh_head.py

Download this file

77 lines (65 with data), 2.5 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmpose.models import HMRMeshHead
from mmpose.models.misc.discriminator import SMPLDiscriminator
def test_mesh_hmr_head():
    """Test HMRMeshHead output shapes and SMPLDiscriminator on its outputs.

    Checks that the head returns ``(smpl_rotmat, smpl_shape, camera)`` with
    the expected shapes, both with default settings and with an explicit
    mean-parameter file and ``n_iter``. It then feeds those outputs to an
    SMPLDiscriminator and checks its constructor argument validation.
    """
    # HMR head with default settings (no mean-parameter file, default n_iter).
    head = HMRMeshHead(in_channels=512)
    head.init_weights()
    input_shape = (1, 512, 8, 8)
    inputs = _demo_inputs(input_shape)
    smpl_rotmat, smpl_shape, camera = head(inputs)
    assert smpl_rotmat.shape == torch.Size([1, 24, 3, 3])
    assert smpl_shape.shape == torch.Size([1, 10])
    assert camera.shape == torch.Size([1, 3])

    # HMR head with assigned mean parameters and n_iter.
    # NOTE(review): the .npz path is resolved against the current working
    # directory, so this presumably requires running pytest from the repo
    # root — confirm.
    head = HMRMeshHead(
        in_channels=512,
        smpl_mean_params='tests/data/smpl/smpl_mean_params.npz',
        n_iter=3)
    head.init_weights()
    inputs = _demo_inputs(input_shape)
    smpl_rotmat, smpl_shape, camera = head(inputs)
    assert smpl_rotmat.shape == torch.Size([1, 24, 3, 3])
    assert smpl_shape.shape == torch.Size([1, 10])
    assert camera.shape == torch.Size([1, 3])

    # Discriminator with SMPL pose parameters in rotation-matrix form.
    # pred_theta is ordered (camera, pose, shape).
    disc = SMPLDiscriminator(
        beta_channel=(10, 10, 5, 1),
        per_joint_channel=(9, 32, 32, 16, 1),
        full_pose_channel=(23 * 16, 256, 1))
    pred_theta = (camera, smpl_rotmat, smpl_shape)
    pred_score = disc(pred_theta)
    assert pred_score.shape[1] == 25

    # Discriminator with SMPL pose parameters in axis-angle form:
    # 72 values = 24 joints x 3 components.
    pred_theta = (camera, camera.new_zeros([1, 72]), smpl_shape)
    pred_score = disc(pred_theta)
    assert pred_score.shape[1] == 25

    # beta_channel passed as a list instead of a tuple is rejected.
    with pytest.raises(TypeError):
        _ = SMPLDiscriminator(
            beta_channel=[10, 10, 5, 1],
            per_joint_channel=(9, 32, 32, 16, 1),
            full_pose_channel=(23 * 16, 256, 1))

    # A single-element beta_channel is rejected.
    with pytest.raises(ValueError):
        _ = SMPLDiscriminator(
            beta_channel=(10, ),
            per_joint_channel=(9, 32, 32, 16, 1),
            full_pose_channel=(23 * 16, 256, 1))
def _demo_inputs(input_shape=(1, 3, 64, 64)):
"""Create a superset of inputs needed to run mesh head.
Args:
input_shape (tuple): input batch dimensions.
Default: (1, 3, 64, 64).
Returns:
Random input tensor with the size of input_shape.
"""
inps = np.random.random(input_shape)
inps = torch.FloatTensor(inps)
return inps