[c1b1c5]: / ViTPose / tests / test_post_processing / test_group.py

Download this file

73 lines (67 with data), 2.5 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmpose.core.post_processing.group import HeatmapParser
def test_group():
    """Exercise ``HeatmapParser`` NMS and grouping on synthetic inputs.

    Builds a parser for 17 joints with ``max_num_people=1``, checks that
    NMS zeroes a weaker neighbour of a peak, then parses hand-placed
    heatmap peaks and tag values with ``tag_per_joint`` first enabled and
    then disabled, checking the first returned coordinate each time.
    """
    cfg = {
        'num_joints': 17,
        'detection_threshold': 0.1,
        'tag_threshold': 1,
        'use_detection_val': True,
        'ignore_too_much': False,
        'nms_kernel': 5,
        'nms_padding': 2,
        'tag_per_joint': True,
        'max_num_people': 1,
    }
    parser = HeatmapParser(cfg)

    # A 0.8 response directly beside a 1.0 peak must be suppressed by NMS.
    nms_input = torch.zeros(1, 1, 5, 5)
    nms_input[0, 0, 3, 3] = 1
    nms_input[0, 0, 3, 2] = 0.8
    assert parser.nms(nms_input)[0, 0, 3, 2] == 0

    # Each entry is (joint, row, col, value); the same value is written
    # into both the heatmap and the tag map at that location.
    peaks = (
        (0, 10, 10, 0.8),
        (1, 12, 12, 0.9),
        (4, 8, 8, 0.8),
        (8, 6, 6, 0.9),
    )
    fake_heatmap = torch.zeros(1, 17, 32, 32)
    fake_tag = torch.zeros(1, 17, 32, 32, 1)
    for joint, row, col, value in peaks:
        fake_heatmap[0, joint, row, col] = value
        fake_tag[0, joint, row, col] = value

    # The two positional booleans are presumably parse()'s adjust/refine
    # flags -- confirm against HeatmapParser.parse. The expected 10.25
    # (vs. 10. further down) suggests a sub-pixel offset from the first.
    poses, pose_scores = parser.parse(fake_heatmap, fake_tag, True, True)
    assert poses[0][0, 0, 0] == 10.25
    # 0.2 == (0.8 + 0.9 + 0.8 + 0.9) / 17, i.e. presumably the mean peak
    # value over all 17 joints.
    assert abs(pose_scores[0] - 0.2) < 0.001

    # NOTE(review): fake_tag still has 17 channels although tag_per_joint
    # is now False -- confirm this input shape is intended.
    cfg.update(tag_per_joint=False)
    parser = HeatmapParser(cfg)
    for flag in (False, True):
        poses, _ = parser.parse(fake_heatmap, fake_tag, False, flag)
        assert poses[0][0, 0, 0] == 10.
def test_group_score_per_joint():
    """Check ``HeatmapParser`` with ``score_per_joint`` enabled.

    Repeats the NMS suppression check, then parses four hand-placed
    heatmap/tag peaks and asserts that the first returned score has one
    entry per joint (17, matching ``num_joints``).
    """
    num_joints = 17
    cfg = dict(
        num_joints=num_joints,
        detection_threshold=0.1,
        tag_threshold=1,
        use_detection_val=True,
        ignore_too_much=False,
        nms_kernel=5,
        nms_padding=2,
        tag_per_joint=True,
        max_num_people=1,
        score_per_joint=True,
    )
    parser = HeatmapParser(cfg)

    nms_input = torch.zeros(1, 1, 5, 5)
    nms_input[0, 0, 3, 3] = 1
    nms_input[0, 0, 3, 2] = 0.8
    # The weaker neighbour of the peak at (3, 3) must be zeroed.
    assert parser.nms(nms_input)[0, 0, 3, 2] == 0

    fake_heatmap = torch.zeros(1, num_joints, 32, 32)
    fake_tag = torch.zeros(1, num_joints, 32, 32, 1)
    # Each entry: (joint index, row, col, value written to both maps).
    for joint, row, col, value in [(0, 10, 10, 0.8), (1, 12, 12, 0.9),
                                   (4, 8, 8, 0.8), (8, 6, 6, 0.9)]:
        fake_tag[0, joint, row, col] = value
        fake_heatmap[0, joint, row, col] = value

    # Only the scores are inspected; the grouped poses are discarded.
    _, scores = parser.parse(fake_heatmap, fake_tag, True, True)
    assert len(scores[0]) == num_joints