|
a |
|
b/ViTPose/mmpose/datasets/dataset_info.py |
|
|
1 |
# Copyright (c) OpenMMLab. All rights reserved. |
|
|
2 |
import numpy as np |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
class DatasetInfo:
    """Parsed view of a dataset meta-information dictionary.

    The constructor reads the following keys from ``dataset_info``:
    ``dataset_name``, ``paper_info``, ``keypoint_info``, ``skeleton_info``,
    ``joint_weights`` and ``sigmas``. ``keypoint_info`` and
    ``skeleton_info`` are mappings whose values are dicts; they are expanded
    into the lookup tables described in :meth:`_parse_keypoint_info` and
    :meth:`_parse_skeleton_info`.

    Args:
        dataset_info (dict): The raw meta-information dictionary.

    Raises:
        KeyError: If a required key is missing, or if a ``swap`` entry or a
            skeleton ``link`` names a keypoint that is not declared in
            ``keypoint_info``.
    """

    def __init__(self, dataset_info):
        self._dataset_info = dataset_info
        self.dataset_name = self._dataset_info['dataset_name']
        self.paper_info = self._dataset_info['paper_info']
        self.keypoint_info = self._dataset_info['keypoint_info']
        self.skeleton_info = self._dataset_info['skeleton_info']
        # ``[:, None]`` appends a trailing axis to the float32 weight array.
        self.joint_weights = np.array(
            self._dataset_info['joint_weights'], dtype=np.float32)[:, None]

        self.sigmas = np.array(self._dataset_info['sigmas'])

        # Keypoints must be parsed first: the skeleton parser resolves link
        # names through ``keypoint_name2id``.
        self._parse_keypoint_info()
        self._parse_skeleton_info()

    def _keypoint_id(self, name, context):
        """Return the id registered for keypoint ``name``.

        Args:
            name (str): Keypoint name to resolve.
            context (str): Short description of where ``name`` came from,
                used only in the error message.

        Raises:
            KeyError: If ``name`` is not a declared keypoint.
        """
        try:
            return self.keypoint_name2id[name]
        except KeyError:
            raise KeyError('{} refers to unknown keypoint {!r}'.format(
                context, name)) from None

    def _parse_skeleton_info(self):
        """Parse skeleton information.

        Sets the following attributes:

        - link_num (int): number of links.
        - skeleton (list((2,))): list of links (id).
        - skeleton_name (list((2,))): list of links (name).
        - pose_link_color (np.ndarray): the color of the link for
            visualization.

        Raises:
            KeyError: If a link names a keypoint that is not declared.
        """
        self.link_num = len(self.skeleton_info)
        self.pose_link_color = []

        self.skeleton_name = []
        self.skeleton = []
        for link_info in self.skeleton_info.values():
            link = link_info['link']
            self.skeleton_name.append(link)
            self.skeleton.append([
                self._keypoint_id(link[0], 'skeleton link'),
                self._keypoint_id(link[1], 'skeleton link'),
            ])
            # Links without an explicit color fall back to [255, 128, 0].
            self.pose_link_color.append(
                link_info.get('color', [255, 128, 0]))
        self.pose_link_color = np.array(self.pose_link_color)

    def _parse_keypoint_info(self):
        """Parse keypoint information.

        Sets the following attributes:

        - keypoint_num (int): number of keypoints.
        - keypoint_id2name (dict): mapping keypoint id to keypoint name.
        - keypoint_name2id (dict): mapping keypoint name to keypoint id.
        - upper_body_ids (list): a list of keypoints that belong to the
            upper body.
        - lower_body_ids (list): a list of keypoints that belong to the
            lower body.
        - flip_index (list): list of flip index (id)
        - flip_pairs (list((2,))): list of flip pairs (id)
        - flip_index_name (list): list of flip index (name)
        - flip_pairs_name (list((2,))): list of flip pairs (name)
        - pose_kpt_color (np.ndarray): the color of the keypoint for
            visualization.

        Raises:
            KeyError: If a ``swap`` entry names a keypoint that is not
                declared.
        """

        self.keypoint_num = len(self.keypoint_info)
        self.keypoint_id2name = {}
        self.keypoint_name2id = {}

        self.pose_kpt_color = []
        self.upper_body_ids = []
        self.lower_body_ids = []

        self.flip_index_name = []
        self.flip_pairs_name = []

        for kid, kpt in self.keypoint_info.items():

            keypoint_name = kpt['name']
            self.keypoint_id2name[kid] = keypoint_name
            self.keypoint_name2id[keypoint_name] = kid
            # Keypoints without an explicit color fall back to
            # [255, 128, 0].
            self.pose_kpt_color.append(kpt.get('color', [255, 128, 0]))

            # Any 'type' other than 'upper' / 'lower' (including a missing
            # one) puts the keypoint in neither body-part list.
            body_part = kpt.get('type', '')
            if body_part == 'upper':
                self.upper_body_ids.append(kid)
            elif body_part == 'lower':
                self.lower_body_ids.append(kid)

            swap_keypoint = kpt.get('swap', '')
            if swap_keypoint == keypoint_name or swap_keypoint == '':
                # No distinct counterpart: the keypoint maps to itself.
                self.flip_index_name.append(keypoint_name)
            else:
                self.flip_index_name.append(swap_keypoint)
                # Record each pair once: skip it when the reversed pair was
                # already added while visiting the counterpart.
                if [swap_keypoint, keypoint_name] not in self.flip_pairs_name:
                    self.flip_pairs_name.append([keypoint_name, swap_keypoint])

        # Names are resolved only after the loop so that a 'swap' entry may
        # refer to a keypoint declared later in ``keypoint_info``.
        self.flip_pairs = [[
            self._keypoint_id(pair[0], 'flip pair'),
            self._keypoint_id(pair[1], 'flip pair'),
        ] for pair in self.flip_pairs_name]
        self.flip_index = [
            self._keypoint_id(name, 'swap entry')
            for name in self.flip_index_name
        ]
        self.pose_kpt_color = np.array(self.pose_kpt_color)