--- a
+++ b/ViTPose/mmpose/apis/inference.py
@@ -0,0 +1,833 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import warnings
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.parallel import collate, scatter
+from mmcv.runner import load_checkpoint
+from PIL import Image
+
+from mmpose.core.post_processing import oks_nms
+from mmpose.datasets.dataset_info import DatasetInfo
+from mmpose.datasets.pipelines import Compose
+from mmpose.models import build_posenet
+from mmpose.utils.hooks import OutputHook
+
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
+
+
+def init_pose_model(config, checkpoint=None, device='cuda:0'):
+    """Initialize a pose model from config file.
+
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any models.
+
+    Returns:
+        nn.Module: The constructed detector.
+    """
+    if isinstance(config, str):
+        config = mmcv.Config.fromfile(config)
+    elif not isinstance(config, mmcv.Config):
+        raise TypeError('config must be a filename or Config object, '
+                        f'but got {type(config)}')
+    config.model.pretrained = None
+    model = build_posenet(config.model)
+    if checkpoint is not None:
+        # load model checkpoint
+        load_checkpoint(model, checkpoint, map_location='cpu')
+    # save the config in the model for convenience
+    model.cfg = config
+    model.to(device)
+    model.eval()
+    return model
+
+
+def _xyxy2xywh(bbox_xyxy):
+    """Transform the bbox format from x1y1x2y2 to xywh.
+
+    Args:
+        bbox_xyxy (np.ndarray): Bounding boxes (with scores), shaped (n, 4) or
+            (n, 5). (left, top, right, bottom, [score])
+
+    Returns:
+        np.ndarray: Bounding boxes (with scores),
+          shaped (n, 4) or (n, 5). (left, top, width, height, [score])
+    """
+    bbox_xywh = bbox_xyxy.copy()
+    bbox_xywh[:, 2] = bbox_xywh[:, 2] - bbox_xywh[:, 0] + 1
+    bbox_xywh[:, 3] = bbox_xywh[:, 3] - bbox_xywh[:, 1] + 1
+
+    return bbox_xywh
+
+
+def _xywh2xyxy(bbox_xywh):
+    """Transform the bbox format from xywh to x1y1x2y2.
+
+    Args:
+        bbox_xywh (ndarray): Bounding boxes (with scores),
+            shaped (n, 4) or (n, 5). (left, top, width, height, [score])
+    Returns:
+        np.ndarray: Bounding boxes (with scores), shaped (n, 4) or
+          (n, 5). (left, top, right, bottom, [score])
+    """
+    bbox_xyxy = bbox_xywh.copy()
+    bbox_xyxy[:, 2] = bbox_xyxy[:, 2] + bbox_xyxy[:, 0] - 1
+    bbox_xyxy[:, 3] = bbox_xyxy[:, 3] + bbox_xyxy[:, 1] - 1
+
+    return bbox_xyxy
+
+
+def _box2cs(cfg, box):
+    """This encodes bbox(x,y,w,h) into (center, scale)
+
+    Args:
+        x, y, w, h
+
+    Returns:
+        tuple: A tuple containing center and scale.
+
+        - np.ndarray[float32](2,): Center of the bbox (x, y).
+        - np.ndarray[float32](2,): Scale of the bbox w & h.
+    """
+
+    x, y, w, h = box[:4]
+    input_size = cfg.data_cfg['image_size']
+    aspect_ratio = input_size[0] / input_size[1]
+    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
+
+    if w > aspect_ratio * h:
+        h = w * 1.0 / aspect_ratio
+    elif w < aspect_ratio * h:
+        w = h * aspect_ratio
+
+    # pixel std is 200.0
+    scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
+    scale = scale * 1.25
+
+    return center, scale
+
+
+def _inference_single_pose_model(model,
+                                 img_or_path,
+                                 bboxes,
+                                 dataset='TopDownCocoDataset',
+                                 dataset_info=None,
+                                 return_heatmap=False):
+    """Inference human bounding boxes.
+
+    Note:
+        - num_bboxes: N
+        - num_keypoints: K
+
+    Args:
+        model (nn.Module): The loaded pose model.
+        img_or_path (str | np.ndarray): Image filename or loaded image.
+        bboxes (list | np.ndarray): All bounding boxes (with scores),
+            shaped (N, 4) or (N, 5). (left, top, width, height, [score])
+            where N is number of bounding boxes.
+        dataset (str): Dataset name. Deprecated.
+        dataset_info (DatasetInfo): A class containing all dataset info.
+        outputs (list[str] | tuple[str]): Names of layers whose output is
+            to be returned, default: None
+
+    Returns:
+        ndarray[NxKx3]: Predicted pose x, y, score.
+        heatmap[N, K, H, W]: Model output heatmap.
+    """
+
+    cfg = model.cfg
+    device = next(model.parameters()).device
+    if device.type == 'cpu':
+        device = -1
+
+    # build the data pipeline
+    test_pipeline = Compose(cfg.test_pipeline)
+
+    assert len(bboxes[0]) in [4, 5]
+
+    if dataset_info is not None:
+        dataset_name = dataset_info.dataset_name
+        flip_pairs = dataset_info.flip_pairs
+    else:
+        warnings.warn(
+            'dataset is deprecated.'
+            'Please set `dataset_info` in the config.'
+            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
+            DeprecationWarning)
+        # TODO: These will be removed in the later versions.
+        if dataset in ('TopDownCocoDataset', 'TopDownOCHumanDataset',
+                       'AnimalMacaqueDataset'):
+            flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12],
+                          [13, 14], [15, 16]]
+        elif dataset == 'TopDownCocoWholeBodyDataset':
+            body = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12],
+                    [13, 14], [15, 16]]
+            foot = [[17, 20], [18, 21], [19, 22]]
+
+            face = [[23, 39], [24, 38], [25, 37], [26, 36], [27, 35], [28, 34],
+                    [29, 33], [30, 32], [40, 49], [41, 48], [42, 47], [43, 46],
+                    [44, 45], [54, 58], [55, 57], [59, 68], [60, 67], [61, 66],
+                    [62, 65], [63, 70], [64, 69], [71, 77], [72, 76], [73, 75],
+                    [78, 82], [79, 81], [83, 87], [84, 86], [88, 90]]
+
+            hand = [[91, 112], [92, 113], [93, 114], [94, 115], [95, 116],
+                    [96, 117], [97, 118], [98, 119], [99, 120], [100, 121],
+                    [101, 122], [102, 123], [103, 124], [104, 125], [105, 126],
+                    [106, 127], [107, 128], [108, 129], [109, 130], [110, 131],
+                    [111, 132]]
+            flip_pairs = body + foot + face + hand
+        elif dataset == 'TopDownAicDataset':
+            flip_pairs = [[0, 3], [1, 4], [2, 5], [6, 9], [7, 10], [8, 11]]
+        elif dataset == 'TopDownMpiiDataset':
+            flip_pairs = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
+        elif dataset == 'TopDownMpiiTrbDataset':
+            flip_pairs = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11],
+                          [14, 15], [16, 22], [28, 34], [17, 23], [29, 35],
+                          [18, 24], [30, 36], [19, 25], [31, 37], [20, 26],
+                          [32, 38], [21, 27], [33, 39]]
+        elif dataset in ('OneHand10KDataset', 'FreiHandDataset',
+                         'PanopticDataset', 'InterHand2DDataset'):
+            flip_pairs = []
+        elif dataset in 'Face300WDataset':
+            flip_pairs = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11],
+                          [6, 10], [7, 9], [17, 26], [18, 25], [19, 24],
+                          [20, 23], [21, 22], [31, 35], [32, 34], [36, 45],
+                          [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
+                          [48, 54], [49, 53], [50, 52], [61, 63], [60, 64],
+                          [67, 65], [58, 56], [59, 55]]
+
+        elif dataset in 'FaceAFLWDataset':
+            flip_pairs = [[0, 5], [1, 4], [2, 3], [6, 11], [7, 10], [8, 9],
+                          [12, 14], [15, 17]]
+
+        elif dataset in 'FaceCOFWDataset':
+            flip_pairs = [[0, 1], [4, 6], [2, 3], [5, 7], [8, 9], [10, 11],
+                          [12, 14], [16, 17], [13, 15], [18, 19], [22, 23]]
+
+        elif dataset in 'FaceWFLWDataset':
+            flip_pairs = [[0, 32], [1, 31], [2, 30], [3, 29], [4, 28], [5, 27],
+                          [6, 26], [7, 25], [8, 24], [9, 23], [10, 22],
+                          [11, 21], [12, 20], [13, 19], [14, 18], [15, 17],
+                          [33, 46], [34, 45], [35, 44], [36, 43], [37, 42],
+                          [38, 50], [39, 49], [40, 48], [41, 47], [60, 72],
+                          [61, 71], [62, 70], [63, 69], [64, 68], [65, 75],
+                          [66, 74], [67, 73], [55, 59], [56, 58], [76, 82],
+                          [77, 81], [78, 80], [87, 83], [86, 84], [88, 92],
+                          [89, 91], [95, 93], [96, 97]]
+
+        elif dataset in 'AnimalFlyDataset':
+            flip_pairs = [[1, 2], [6, 18], [7, 19], [8, 20], [9, 21], [10, 22],
+                          [11, 23], [12, 24], [13, 25], [14, 26], [15, 27],
+                          [16, 28], [17, 29], [30, 31]]
+        elif dataset in 'AnimalHorse10Dataset':
+            flip_pairs = []
+
+        elif dataset in 'AnimalLocustDataset':
+            flip_pairs = [[5, 20], [6, 21], [7, 22], [8, 23], [9, 24],
+                          [10, 25], [11, 26], [12, 27], [13, 28], [14, 29],
+                          [15, 30], [16, 31], [17, 32], [18, 33], [19, 34]]
+
+        elif dataset in 'AnimalZebraDataset':
+            flip_pairs = [[3, 4], [5, 6]]
+
+        elif dataset in 'AnimalPoseDataset':
+            flip_pairs = [[0, 1], [2, 3], [8, 9], [10, 11], [12, 13], [14, 15],
+                          [16, 17], [18, 19]]
+        else:
+            raise NotImplementedError()
+        dataset_name = dataset
+
+    batch_data = []
+    for bbox in bboxes:
+        center, scale = _box2cs(cfg, bbox)
+
+        # prepare data
+        data = {
+            'center':
+            center,
+            'scale':
+            scale,
+            'bbox_score':
+            bbox[4] if len(bbox) == 5 else 1,
+            'bbox_id':
+            0,  # need to be assigned if batch_size > 1
+            'dataset':
+            dataset_name,
+            'joints_3d':
+            np.zeros((cfg.data_cfg.num_joints, 3), dtype=np.float32),
+            'joints_3d_visible':
+            np.zeros((cfg.data_cfg.num_joints, 3), dtype=np.float32),
+            'rotation':
+            0,
+            'ann_info': {
+                'image_size': np.array(cfg.data_cfg['image_size']),
+                'num_joints': cfg.data_cfg['num_joints'],
+                'flip_pairs': flip_pairs
+            }
+        }
+        if isinstance(img_or_path, np.ndarray):
+            data['img'] = img_or_path
+        else:
+            data['image_file'] = img_or_path
+
+        data = test_pipeline(data)
+        batch_data.append(data)
+
+    batch_data = collate(batch_data, samples_per_gpu=len(batch_data))
+    batch_data = scatter(batch_data, [device])[0]
+
+    # forward the model
+    with torch.no_grad():
+        result = model(
+            img=batch_data['img'],
+            img_metas=batch_data['img_metas'],
+            return_loss=False,
+            return_heatmap=return_heatmap)
+
+    return result['preds'], result['output_heatmap']
+
+
+def inference_top_down_pose_model(model,
+                                  img_or_path,
+                                  person_results=None,
+                                  bbox_thr=None,
+                                  format='xywh',
+                                  dataset='TopDownCocoDataset',
+                                  dataset_info=None,
+                                  return_heatmap=False,
+                                  outputs=None):
+    """Inference a single image with a list of person bounding boxes.
+
+    Note:
+        - num_people: P
+        - num_keypoints: K
+        - bbox height: H
+        - bbox width: W
+
+    Args:
+        model (nn.Module): The loaded pose model.
+        img_or_path (str| np.ndarray): Image filename or loaded image.
+        person_results (list(dict), optional): a list of detected persons that
+            contains ``bbox`` and/or ``track_id``:
+
+            - ``bbox`` (4, ) or (5, ): The person bounding box, which contains
+                4 box coordinates (and score).
+            - ``track_id`` (int): The unique id for each human instance. If
+                not provided, a dummy person result with a bbox covering
+                the entire image will be used. Default: None.
+        bbox_thr (float | None): Threshold for bounding boxes. Only bboxes
+            with higher scores will be fed into the pose detector.
+            If bbox_thr is None, all boxes will be used.
+        format (str): bbox format ('xyxy' | 'xywh'). Default: 'xywh'.
+
+            - `xyxy` means (left, top, right, bottom),
+            - `xywh` means (left, top, width, height).
+        dataset (str): Dataset name, e.g. 'TopDownCocoDataset'.
+            It is deprecated. Please use dataset_info instead.
+        dataset_info (DatasetInfo): A class containing all dataset info.
+        return_heatmap (bool) : Flag to return heatmap, default: False
+        outputs (list(str) | tuple(str)) : Names of layers whose outputs
+            need to be returned. Default: None.
+
+    Returns:
+        tuple:
+        - pose_results (list[dict]): The bbox & pose info. \
+            Each item in the list is a dictionary, \
+            containing the bbox: (left, top, right, bottom, [score]) \
+            and the pose (ndarray[Kx3]): x, y, score.
+        - returned_outputs (list[dict[np.ndarray[N, K, H, W] | \
+            torch.Tensor[N, K, H, W]]]): \
+            Output feature maps from layers specified in `outputs`. \
+            Includes 'heatmap' if `return_heatmap` is True.
+    """
+    # get dataset info
+    if (dataset_info is None and hasattr(model, 'cfg')
+            and 'dataset_info' in model.cfg):
+        dataset_info = DatasetInfo(model.cfg.dataset_info)
+    if dataset_info is None:
+        warnings.warn(
+            'dataset is deprecated.'
+            'Please set `dataset_info` in the config.'
+            'Check https://github.com/open-mmlab/mmpose/pull/663'
+            ' for details.', DeprecationWarning)
+
+    # only two kinds of bbox format is supported.
+    assert format in ['xyxy', 'xywh']
+
+    pose_results = []
+    returned_outputs = []
+
+    if person_results is None:
+        # create dummy person results
+        if isinstance(img_or_path, str):
+            width, height = Image.open(img_or_path).size
+        else:
+            height, width = img_or_path.shape[:2]
+        person_results = [{'bbox': np.array([0, 0, width, height])}]
+
+    if len(person_results) == 0:
+        return pose_results, returned_outputs
+
+    # Change for-loop preprocess each bbox to preprocess all bboxes at once.
+    bboxes = np.array([box['bbox'] for box in person_results])
+
+    # Select bboxes by score threshold
+    if bbox_thr is not None:
+        assert bboxes.shape[1] == 5
+        valid_idx = np.where(bboxes[:, 4] > bbox_thr)[0]
+        bboxes = bboxes[valid_idx]
+        person_results = [person_results[i] for i in valid_idx]
+
+    if format == 'xyxy':
+        bboxes_xyxy = bboxes
+        bboxes_xywh = _xyxy2xywh(bboxes)
+    else:
+        # format is already 'xywh'
+        bboxes_xywh = bboxes
+        bboxes_xyxy = _xywh2xyxy(bboxes)
+
+    # if bbox_thr remove all bounding box
+    if len(bboxes_xywh) == 0:
+        return [], []
+
+    with OutputHook(model, outputs=outputs, as_tensor=False) as h:
+        # poses is results['pred'] # N x 17x 3
+        poses, heatmap = _inference_single_pose_model(
+            model,
+            img_or_path,
+            bboxes_xywh,
+            dataset=dataset,
+            dataset_info=dataset_info,
+            return_heatmap=return_heatmap)
+
+        if return_heatmap:
+            h.layer_outputs['heatmap'] = heatmap
+
+        returned_outputs.append(h.layer_outputs)
+
+    assert len(poses) == len(person_results), print(
+        len(poses), len(person_results), len(bboxes_xyxy))
+    for pose, person_result, bbox_xyxy in zip(poses, person_results,
+                                              bboxes_xyxy):
+        pose_result = person_result.copy()
+        pose_result['keypoints'] = pose
+        pose_result['bbox'] = bbox_xyxy
+        pose_results.append(pose_result)
+
+    return pose_results, returned_outputs
+
+
+def inference_bottom_up_pose_model(model,
+                                   img_or_path,
+                                   dataset='BottomUpCocoDataset',
+                                   dataset_info=None,
+                                   pose_nms_thr=0.9,
+                                   return_heatmap=False,
+                                   outputs=None):
+    """Inference a single image with a bottom-up pose model.
+
+    Note:
+        - num_people: P
+        - num_keypoints: K
+        - bbox height: H
+        - bbox width: W
+
+    Args:
+        model (nn.Module): The loaded pose model.
+        img_or_path (str| np.ndarray): Image filename or loaded image.
+        dataset (str): Dataset name, e.g. 'BottomUpCocoDataset'.
+            It is deprecated. Please use dataset_info instead.
+        dataset_info (DatasetInfo): A class containing all dataset info.
+        pose_nms_thr (float): retain oks overlap < pose_nms_thr, default: 0.9.
+        return_heatmap (bool) : Flag to return heatmap, default: False.
+        outputs (list(str) | tuple(str)) : Names of layers whose outputs
+            need to be returned, default: None.
+
+    Returns:
+        tuple:
+        - pose_results (list[np.ndarray]): The predicted pose info. \
+            The length of the list is the number of people (P). \
+            Each item in the list is a ndarray, containing each \
+            person's pose (np.ndarray[Kx3]): x, y, score.
+        - returned_outputs (list[dict[np.ndarray[N, K, H, W] | \
+            torch.Tensor[N, K, H, W]]]): \
+            Output feature maps from layers specified in `outputs`. \
+            Includes 'heatmap' if `return_heatmap` is True.
+    """
+    # get dataset info
+    if (dataset_info is None and hasattr(model, 'cfg')
+            and 'dataset_info' in model.cfg):
+        dataset_info = DatasetInfo(model.cfg.dataset_info)
+
+    if dataset_info is not None:
+        dataset_name = dataset_info.dataset_name
+        flip_index = dataset_info.flip_index
+        sigmas = getattr(dataset_info, 'sigmas', None)
+    else:
+        warnings.warn(
+            'dataset is deprecated.'
+            'Please set `dataset_info` in the config.'
+            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
+            DeprecationWarning)
+        assert (dataset == 'BottomUpCocoDataset')
+        dataset_name = dataset
+        flip_index = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
+        sigmas = None
+
+    pose_results = []
+    returned_outputs = []
+
+    cfg = model.cfg
+    device = next(model.parameters()).device
+    if device.type == 'cpu':
+        device = -1
+
+    # build the data pipeline
+    test_pipeline = Compose(cfg.test_pipeline)
+
+    # prepare data
+    data = {
+        'dataset': dataset_name,
+        'ann_info': {
+            'image_size': np.array(cfg.data_cfg['image_size']),
+            'num_joints': cfg.data_cfg['num_joints'],
+            'flip_index': flip_index,
+        }
+    }
+    if isinstance(img_or_path, np.ndarray):
+        data['img'] = img_or_path
+    else:
+        data['image_file'] = img_or_path
+
+    data = test_pipeline(data)
+    data = collate([data], samples_per_gpu=1)
+    data = scatter(data, [device])[0]
+
+    with OutputHook(model, outputs=outputs, as_tensor=False) as h:
+        # forward the model
+        with torch.no_grad():
+            result = model(
+                img=data['img'],
+                img_metas=data['img_metas'],
+                return_loss=False,
+                return_heatmap=return_heatmap)
+
+        if return_heatmap:
+            h.layer_outputs['heatmap'] = result['output_heatmap']
+
+        returned_outputs.append(h.layer_outputs)
+
+        for idx, pred in enumerate(result['preds']):
+            area = (np.max(pred[:, 0]) - np.min(pred[:, 0])) * (
+                np.max(pred[:, 1]) - np.min(pred[:, 1]))
+            pose_results.append({
+                'keypoints': pred[:, :3],
+                'score': result['scores'][idx],
+                'area': area,
+            })
+
+        # pose nms
+        score_per_joint = cfg.model.test_cfg.get('score_per_joint', False)
+        keep = oks_nms(
+            pose_results,
+            pose_nms_thr,
+            sigmas,
+            score_per_joint=score_per_joint)
+        pose_results = [pose_results[_keep] for _keep in keep]
+
+    return pose_results, returned_outputs
+
+
+def vis_pose_result(model,
+                    img,
+                    result,
+                    radius=4,
+                    thickness=1,
+                    kpt_score_thr=0.3,
+                    bbox_color='green',
+                    dataset='TopDownCocoDataset',
+                    dataset_info=None,
+                    show=False,
+                    out_file=None):
+    """Visualize the detection results on the image.
+
+    Args:
+        model (nn.Module): The loaded detector.
+        img (str | np.ndarray): Image filename or loaded image.
+        result (list[dict]): The results to draw over `img`
+                (bbox_result, pose_result).
+        radius (int): Radius of circles.
+        thickness (int): Thickness of lines.
+        kpt_score_thr (float): The threshold to visualize the keypoints.
+        skeleton (list[tuple()]): Default None.
+        show (bool):  Whether to show the image. Default True.
+        out_file (str|None): The filename of the output visualization image.
+    """
+
+    # get dataset info
+    if (dataset_info is None and hasattr(model, 'cfg')
+            and 'dataset_info' in model.cfg):
+        dataset_info = DatasetInfo(model.cfg.dataset_info)
+
+    if dataset_info is not None:
+        skeleton = dataset_info.skeleton
+        pose_kpt_color = dataset_info.pose_kpt_color
+        pose_link_color = dataset_info.pose_link_color
+    else:
+        warnings.warn(
+            'dataset is deprecated.'
+            'Please set `dataset_info` in the config.'
+            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
+            DeprecationWarning)
+        # TODO: These will be removed in the later versions.
+        palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
+                            [230, 230, 0], [255, 153, 255], [153, 204, 255],
+                            [255, 102, 255], [255, 51, 255], [102, 178, 255],
+                            [51, 153, 255], [255, 153, 153], [255, 102, 102],
+                            [255, 51, 51], [153, 255, 153], [102, 255, 102],
+                            [51, 255, 51], [0, 255, 0], [0, 0, 255],
+                            [255, 0, 0], [255, 255, 255]])
+
+        if dataset in ('TopDownCocoDataset', 'BottomUpCocoDataset',
+                       'TopDownOCHumanDataset', 'AnimalMacaqueDataset'):
+            # show the results
+            skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
+                        [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
+                        [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4],
+                        [3, 5], [4, 6]]
+
+            pose_link_color = palette[[
+                0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16
+            ]]
+            pose_kpt_color = palette[[
+                16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0
+            ]]
+
+        elif dataset == 'TopDownCocoWholeBodyDataset':
+            # show the results
+            skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
+                        [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
+                        [8, 10], [1, 2], [0, 1], [0, 2],
+                        [1, 3], [2, 4], [3, 5], [4, 6], [15, 17], [15, 18],
+                        [15, 19], [16, 20], [16, 21], [16, 22], [91, 92],
+                        [92, 93], [93, 94], [94, 95], [91, 96], [96, 97],
+                        [97, 98], [98, 99], [91, 100], [100, 101], [101, 102],
+                        [102, 103], [91, 104], [104, 105], [105, 106],
+                        [106, 107], [91, 108], [108, 109], [109, 110],
+                        [110, 111], [112, 113], [113, 114], [114, 115],
+                        [115, 116], [112, 117], [117, 118], [118, 119],
+                        [119, 120], [112, 121], [121, 122], [122, 123],
+                        [123, 124], [112, 125], [125, 126], [126, 127],
+                        [127, 128], [112, 129], [129, 130], [130, 131],
+                        [131, 132]]
+
+            pose_link_color = palette[[
+                0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16
+            ] + [16, 16, 16, 16, 16, 16] + [
+                0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16,
+                16
+            ] + [
+                0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16,
+                16
+            ]]
+            pose_kpt_color = palette[
+                [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] +
+                [0, 0, 0, 0, 0, 0] + [19] * (68 + 42)]
+
+        elif dataset == 'TopDownAicDataset':
+            skeleton = [[2, 1], [1, 0], [0, 13], [13, 3], [3, 4], [4, 5],
+                        [8, 7], [7, 6], [6, 9], [9, 10], [10, 11], [12, 13],
+                        [0, 6], [3, 9]]
+
+            pose_link_color = palette[[
+                9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 0, 7, 7
+            ]]
+            pose_kpt_color = palette[[
+                9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 0, 0
+            ]]
+
+        elif dataset == 'TopDownMpiiDataset':
+            skeleton = [[0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7],
+                        [7, 8], [8, 9], [8, 12], [12, 11], [11, 10], [8, 13],
+                        [13, 14], [14, 15]]
+
+            pose_link_color = palette[[
+                16, 16, 16, 16, 16, 16, 7, 7, 0, 9, 9, 9, 9, 9, 9
+            ]]
+            pose_kpt_color = palette[[
+                16, 16, 16, 16, 16, 16, 7, 7, 0, 0, 9, 9, 9, 9, 9, 9
+            ]]
+
+        elif dataset == 'TopDownMpiiTrbDataset':
+            skeleton = [[12, 13], [13, 0], [13, 1], [0, 2], [1, 3], [2, 4],
+                        [3, 5], [0, 6], [1, 7], [6, 7], [6, 8], [7,
+                                                                 9], [8, 10],
+                        [9, 11], [14, 15], [16, 17], [18, 19], [20, 21],
+                        [22, 23], [24, 25], [26, 27], [28, 29], [30, 31],
+                        [32, 33], [34, 35], [36, 37], [38, 39]]
+
+            pose_link_color = palette[[16] * 14 + [19] * 13]
+            pose_kpt_color = palette[[16] * 14 + [0] * 26]
+
+        elif dataset in ('OneHand10KDataset', 'FreiHandDataset',
+                         'PanopticDataset'):
+            skeleton = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7],
+                        [7, 8], [0, 9], [9, 10], [10, 11], [11, 12], [0, 13],
+                        [13, 14], [14, 15], [15, 16], [0, 17], [17, 18],
+                        [18, 19], [19, 20]]
+
+            pose_link_color = palette[[
+                0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16,
+                16
+            ]]
+            pose_kpt_color = palette[[
+                0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16,
+                16, 16
+            ]]
+
+        elif dataset == 'InterHand2DDataset':
+            skeleton = [[0, 1], [1, 2], [2, 3], [4, 5], [5, 6], [6, 7], [8, 9],
+                        [9, 10], [10, 11], [12, 13], [13, 14], [14, 15],
+                        [16, 17], [17, 18], [18, 19], [3, 20], [7, 20],
+                        [11, 20], [15, 20], [19, 20]]
+
+            pose_link_color = palette[[
+                0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12,
+                16
+            ]]
+            pose_kpt_color = palette[[
+                0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16,
+                16, 0
+            ]]
+
+        elif dataset == 'Face300WDataset':
+            # show the results
+            skeleton = []
+
+            pose_link_color = palette[[]]
+            pose_kpt_color = palette[[19] * 68]
+            kpt_score_thr = 0
+
+        elif dataset == 'FaceAFLWDataset':
+            # show the results
+            skeleton = []
+
+            pose_link_color = palette[[]]
+            pose_kpt_color = palette[[19] * 19]
+            kpt_score_thr = 0
+
+        elif dataset == 'FaceCOFWDataset':
+            # show the results
+            skeleton = []
+
+            pose_link_color = palette[[]]
+            pose_kpt_color = palette[[19] * 29]
+            kpt_score_thr = 0
+
+        elif dataset == 'FaceWFLWDataset':
+            # show the results
+            skeleton = []
+
+            pose_link_color = palette[[]]
+            pose_kpt_color = palette[[19] * 98]
+            kpt_score_thr = 0
+
+        elif dataset == 'AnimalHorse10Dataset':
+            skeleton = [[0, 1], [1, 12], [12, 16], [16, 21], [21, 17],
+                        [17, 11], [11, 10], [10, 8], [8, 9], [9, 12], [2, 3],
+                        [3, 4], [5, 6], [6, 7], [13, 14], [14, 15], [18, 19],
+                        [19, 20]]
+
+            pose_link_color = palette[[4] * 10 + [6] * 2 + [6] * 2 + [7] * 2 +
+                                      [7] * 2]
+            pose_kpt_color = palette[[
+                4, 4, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 4, 7, 7, 7, 4, 4, 7, 7, 7,
+                4
+            ]]
+
+        elif dataset == 'AnimalFlyDataset':
+            skeleton = [[1, 0], [2, 0], [3, 0], [4, 3], [5, 4], [7, 6], [8, 7],
+                        [9, 8], [11, 10], [12, 11], [13, 12], [15, 14],
+                        [16, 15], [17, 16], [19, 18], [20, 19], [21, 20],
+                        [23, 22], [24, 23], [25, 24], [27, 26], [28, 27],
+                        [29, 28], [30, 3], [31, 3]]
+
+            pose_link_color = palette[[0] * 25]
+            pose_kpt_color = palette[[0] * 32]
+
+        elif dataset == 'AnimalLocustDataset':
+            skeleton = [[1, 0], [2, 1], [3, 2], [4, 3], [6, 5], [7, 6], [9, 8],
+                        [10, 9], [11, 10], [13, 12], [14, 13], [15, 14],
+                        [17, 16], [18, 17], [19, 18], [21, 20], [22, 21],
+                        [24, 23], [25, 24], [26, 25], [28, 27], [29, 28],
+                        [30, 29], [32, 31], [33, 32], [34, 33]]
+
+            pose_link_color = palette[[0] * 26]
+            pose_kpt_color = palette[[0] * 35]
+
+        elif dataset == 'AnimalZebraDataset':
+            skeleton = [[1, 0], [2, 1], [3, 2], [4, 2], [5, 7], [6, 7], [7, 2],
+                        [8, 7]]
+
+            pose_link_color = palette[[0] * 8]
+            pose_kpt_color = palette[[0] * 9]
+
+        elif dataset in 'AnimalPoseDataset':
+            skeleton = [[0, 1], [0, 2], [1, 3], [0, 4], [1, 4], [4, 5], [5, 7],
+                        [6, 7], [5, 8], [8, 12], [12, 16], [5, 9], [9, 13],
+                        [13, 17], [6, 10], [10, 14], [14, 18], [6, 11],
+                        [11, 15], [15, 19]]
+
+            pose_link_color = palette[[0] * 20]
+            pose_kpt_color = palette[[0] * 20]
+        else:
+            NotImplementedError()
+
+    if hasattr(model, 'module'):
+        model = model.module
+
+    img = model.show_result(
+        img,
+        result,
+        skeleton,
+        radius=radius,
+        thickness=thickness,
+        pose_kpt_color=pose_kpt_color,
+        pose_link_color=pose_link_color,
+        kpt_score_thr=kpt_score_thr,
+        bbox_color=bbox_color,
+        show=show,
+        out_file=out_file)
+
+    return img
+
+
+def process_mmdet_results(mmdet_results, cat_id=1):
+    """Process mmdet results, and return a list of bboxes.
+
+    Args:
+        mmdet_results (list|tuple): mmdet results.
+        cat_id (int): category id (default: 1 for human)
+
+    Returns:
+        person_results (list): a list of detected bounding boxes
+    """
+    if isinstance(mmdet_results, tuple):
+        det_results = mmdet_results[0]
+    else:
+        det_results = mmdet_results
+
+    bboxes = det_results[cat_id - 1]
+
+    person_results = []
+    for bbox in bboxes:
+        person = {}
+        person['bbox'] = bbox
+        person_results.append(person)
+
+    return person_results