# ViTPose/mmpose/datasets/dataset_info.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import numpy as np
3
4
5
class DatasetInfo:
    """Parsed view of a dataset meta-information dictionary.

    The constructor reads the following keys from ``dataset_info``:

    - ``'dataset_name'`` and ``'paper_info'``: stored unchanged.
    - ``'keypoint_info'``: mapping from keypoint id to a dict with a
      mandatory ``'name'`` and optional ``'color'``, ``'type'``
      (``'upper'`` or ``'lower'``) and ``'swap'`` (name of the mirrored
      keypoint) entries.
    - ``'skeleton_info'``: mapping from link id to a dict with a mandatory
      ``'link'`` (pair of keypoint names) and an optional ``'color'``.
    - ``'joint_weights'``: sequence converted to a float32 column vector.
    - ``'sigmas'``: sequence converted to a NumPy array.

    Entries without a ``'color'`` fall back to ``[255, 128, 0]``.

    Args:
        dataset_info (dict): the meta-information described above.

    Raises:
        KeyError: if a mandatory key is missing, or if a ``'swap'`` or
            ``'link'`` entry names a keypoint that is not defined in
            ``'keypoint_info'``.
    """

    def __init__(self, dataset_info):
        self._dataset_info = dataset_info
        self.dataset_name = dataset_info['dataset_name']
        self.paper_info = dataset_info['paper_info']
        self.keypoint_info = dataset_info['keypoint_info']
        self.skeleton_info = dataset_info['skeleton_info']
        # ``[:, None]`` turns the 1-D weights into a column vector.
        self.joint_weights = np.array(
            dataset_info['joint_weights'], dtype=np.float32)[:, None]
        self.sigmas = np.array(dataset_info['sigmas'])

        # Keypoints must be parsed first: the skeleton parser resolves link
        # names through ``keypoint_name2id``.
        self._parse_keypoint_info()
        self._parse_skeleton_info()

    def _keypoint_id(self, name, context):
        """Return the id of keypoint ``name``.

        Args:
            name (str): keypoint name to resolve.
            context (str): description of where ``name`` came from; only
                used in the error message.

        Raises:
            KeyError: if ``name`` is not a known keypoint name.
        """
        try:
            return self.keypoint_name2id[name]
        except KeyError:
            raise KeyError('{} refers to unknown keypoint {!r}'.format(
                context, name)) from None

    def _parse_skeleton_info(self):
        """Parse skeleton information.

        Sets the following attributes:

        - link_num (int): number of links.
        - skeleton (list((2,))): list of links (id).
        - skeleton_name (list((2,))): list of links (name).
        - pose_link_color (np.ndarray): the color of the link for
            visualization.
        """
        self.link_num = len(self.skeleton_info)
        self.skeleton_name = []
        self.skeleton = []
        link_colors = []

        for link_info in self.skeleton_info.values():
            link = link_info['link']
            self.skeleton_name.append(link)
            self.skeleton.append([
                self._keypoint_id(link[0], "skeleton_info 'link'"),
                self._keypoint_id(link[1], "skeleton_info 'link'"),
            ])
            link_colors.append(link_info.get('color', [255, 128, 0]))

        self.pose_link_color = np.array(link_colors)

    def _parse_keypoint_info(self):
        """Parse keypoint information.

        Sets the following attributes:

        - keypoint_num (int): number of keypoints.
        - keypoint_id2name (dict): mapping keypoint id to keypoint name.
        - keypoint_name2id (dict): mapping keypoint name to keypoint id.
        - upper_body_ids (list): a list of keypoints that belong to the
            upper body.
        - lower_body_ids (list): a list of keypoints that belong to the
            lower body.
        - flip_index (list): list of flip index (id)
        - flip_pairs (list((2,))): list of flip pairs (id)
        - flip_index_name (list): list of flip index (name)
        - flip_pairs_name (list((2,))): list of flip pairs (name)
        - pose_kpt_color (np.ndarray): the color of the keypoint for
            visualization.
        """
        self.keypoint_num = len(self.keypoint_info)
        self.keypoint_id2name = {}
        self.keypoint_name2id = {}

        self.upper_body_ids = []
        self.lower_body_ids = []

        self.flip_index_name = []
        self.flip_pairs_name = []

        kpt_colors = []

        for kid, kpt in self.keypoint_info.items():
            keypoint_name = kpt['name']
            self.keypoint_id2name[kid] = keypoint_name
            self.keypoint_name2id[keypoint_name] = kid
            kpt_colors.append(kpt.get('color', [255, 128, 0]))

            # Any value other than 'upper'/'lower' leaves the keypoint
            # unassigned to a body half.
            body_part = kpt.get('type', '')
            if body_part == 'upper':
                self.upper_body_ids.append(kid)
            elif body_part == 'lower':
                self.lower_body_ids.append(kid)

            swap_keypoint = kpt.get('swap', '')
            if swap_keypoint == '' or swap_keypoint == keypoint_name:
                # No mirrored counterpart: the keypoint maps to itself.
                self.flip_index_name.append(keypoint_name)
            else:
                self.flip_index_name.append(swap_keypoint)
                # Record each pair once: skip it when the reversed pair was
                # already added while visiting the counterpart.
                if [swap_keypoint, keypoint_name] not in self.flip_pairs_name:
                    self.flip_pairs_name.append([keypoint_name, swap_keypoint])

        # Names are resolved only after the loop, so a 'swap' may refer to a
        # keypoint that is defined later in ``keypoint_info``.
        self.flip_pairs = [[
            self._keypoint_id(pair[0], "keypoint_info 'swap'"),
            self._keypoint_id(pair[1], "keypoint_info 'swap'"),
        ] for pair in self.flip_pairs_name]
        self.flip_index = [
            self._keypoint_id(name, "keypoint_info 'swap'")
            for name in self.flip_index_name
        ]
        self.pose_kpt_color = np.array(kpt_colors)