[c1b1c5]: / ViTPose / mmpose / datasets / dataset_info.py

Download this file

105 lines (86 with data), 4.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
class DatasetInfo:
    """Parsed, lookup-friendly view of a dataset meta-information dict.

    Args:
        dataset_info (dict): Dataset meta information. The keys read here are
            ``dataset_name``, ``paper_info``, ``keypoint_info``,
            ``skeleton_info``, ``joint_weights`` and ``sigmas``.
            ``keypoint_info`` maps a keypoint id to a dict with a ``name`` and
            optional ``color``, ``type`` (``'upper'``/``'lower'``) and
            ``swap`` (name of the mirrored keypoint) entries.
            ``skeleton_info`` maps a link id to a dict with a ``link`` pair of
            keypoint names and an optional ``color``.

    Raises:
        KeyError: If one of the top-level keys above is missing, if a keypoint
            entry lacks ``name`` or a skeleton entry lacks ``link``, or if a
            ``swap``/``link`` entry names a keypoint not in ``keypoint_info``.
    """

    # Color used for keypoints and links that do not specify a ``color``.
    _DEFAULT_COLOR = [255, 128, 0]

    def __init__(self, dataset_info):
        self._dataset_info = dataset_info
        self.dataset_name = self._dataset_info['dataset_name']
        self.paper_info = self._dataset_info['paper_info']
        self.keypoint_info = self._dataset_info['keypoint_info']
        self.skeleton_info = self._dataset_info['skeleton_info']
        # float32 column vector: shape (K, 1) for a flat list of K weights.
        self.joint_weights = np.array(
            self._dataset_info['joint_weights'], dtype=np.float32)[:, None]
        self.sigmas = np.array(self._dataset_info['sigmas'])

        # Order matters: skeleton parsing resolves link names through
        # ``keypoint_name2id``, which keypoint parsing builds.
        self._parse_keypoint_info()
        self._parse_skeleton_info()

    def _parse_skeleton_info(self):
        """Parse skeleton information.

        Sets the following attributes:

        - link_num (int): number of links.
        - skeleton (list[list[int]]): each link as a pair of keypoint ids.
        - skeleton_name (list): each link's ``link`` entry as given in
          ``skeleton_info`` (a pair of keypoint names).
        - pose_link_color (np.ndarray): per-link color for visualization;
          links without a ``color`` entry get ``[255, 128, 0]``.

        Raises:
            KeyError: If a link names a keypoint that is not defined.
        """
        name2id = self.keypoint_name2id
        self.link_num = len(self.skeleton_info)
        self.skeleton_name = []
        self.skeleton = []
        link_colors = []
        for link_cfg in self.skeleton_info.values():
            link = link_cfg['link']
            self.skeleton_name.append(link)
            self.skeleton.append([name2id[link[0]], name2id[link[1]]])
            link_colors.append(link_cfg.get('color', self._DEFAULT_COLOR))
        self.pose_link_color = np.array(link_colors)

    def _parse_keypoint_info(self):
        """Parse keypoint information.

        Sets the following attributes:

        - keypoint_num (int): number of keypoints.
        - keypoint_id2name (dict): mapping keypoint id to keypoint name.
        - keypoint_name2id (dict): mapping keypoint name to keypoint id.
        - upper_body_ids (list): ids of keypoints whose ``type`` is
          ``'upper'``.
        - lower_body_ids (list): ids of keypoints whose ``type`` is
          ``'lower'``.
        - flip_index (list): for each keypoint, the id it maps to when
          flipped.
        - flip_pairs (list[list[int]]): mirrored keypoint pairs (ids).
        - flip_index_name (list): like ``flip_index`` but with names.
        - flip_pairs_name (list[list[str]]): like ``flip_pairs`` but with
          names.
        - pose_kpt_color (np.ndarray): per-keypoint color for visualization;
          keypoints without a ``color`` entry get ``[255, 128, 0]``.

        Raises:
            KeyError: If a ``swap`` entry names a keypoint that is not
                defined.
        """
        self.keypoint_num = len(self.keypoint_info)
        self.keypoint_id2name = {}
        self.keypoint_name2id = {}
        self.upper_body_ids = []
        self.lower_body_ids = []
        self.flip_index_name = []
        self.flip_pairs_name = []
        kpt_colors = []
        # (name, swap) pairs already appended to ``flip_pairs_name``; a set
        # makes the mirrored-pair check O(1) instead of a list scan.
        recorded_pairs = set()

        for kid, kpt in self.keypoint_info.items():
            keypoint_name = kpt['name']
            self.keypoint_id2name[kid] = keypoint_name
            self.keypoint_name2id[keypoint_name] = kid
            kpt_colors.append(kpt.get('color', self._DEFAULT_COLOR))

            kpt_type = kpt.get('type', '')
            if kpt_type == 'upper':
                self.upper_body_ids.append(kid)
            elif kpt_type == 'lower':
                self.lower_body_ids.append(kid)
            # Any other (or missing) type belongs to neither group.

            swap_keypoint = kpt.get('swap', '')
            if swap_keypoint == keypoint_name or swap_keypoint == '':
                # No counterpart: the keypoint maps onto itself when flipped.
                self.flip_index_name.append(keypoint_name)
            else:
                self.flip_index_name.append(swap_keypoint)
                # Skip the pair if its mirror image was already recorded,
                # so each left/right pair is listed once.
                if (swap_keypoint, keypoint_name) not in recorded_pairs:
                    self.flip_pairs_name.append([keypoint_name, swap_keypoint])
                    recorded_pairs.add((keypoint_name, swap_keypoint))

        name2id = self.keypoint_name2id
        self.flip_pairs = [[name2id[first], name2id[second]]
                           for first, second in self.flip_pairs_name]
        self.flip_index = [name2id[name] for name in self.flip_index_name]
        self.pose_kpt_color = np.array(kpt_colors)