
--- a
+++ b/ViTPose/mmpose/apis/inference_3d.py
@@ -0,0 +1,791 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+import torch
+from mmcv.parallel import collate, scatter
+
+from mmpose.datasets.pipelines import Compose
+from .inference import _box2cs, _xywh2xyxy, _xyxy2xywh
+
+
+def extract_pose_sequence(pose_results, frame_idx, causal, seq_len, step=1):
+    """Extract the target frame from 2D pose results, and pad the sequence to a
+    fixed length.
+
+    Args:
+        pose_results (list[list[dict]]): Multi-frame pose detection results
+            stored in a nested list. Each element of the outer list is the
+            pose detection results of a single frame, and each element of the
+            inner list is the pose information of one person, which contains:
+
+                - keypoints (ndarray[K, 2 or 3]): x, y, [score]
+                - track_id (int): unique id of each person, required \
+                    when ``with_track_id==True``.
+                - bbox ((4, ) or (5, )): left, top, right, bottom, [score]
+
+        frame_idx (int): The index of the frame in the original video.
+        causal (bool): If True, the target frame is the last frame in
+            a sequence. Otherwise, the target frame is in the middle of
+            a sequence.
+        seq_len (int): The number of frames in the input sequence.
+        step (int): Step size to extract frames from the video.
+
+    Returns:
+        list[list[dict]]: Multi-frame pose detection results stored \
+            in a nested list with a length of seq_len.
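+
+    Example:
+        A minimal illustration, assuming ``pose_results`` holds at least one
+        frame of 2D results in the format described above. With
+        ``causal=True`` and ``seq_len=3``, the sequence for frame 0 is
+        left-padded by replicating the first frame::
+
+            >>> seq = extract_pose_sequence(
+            ...     pose_results, frame_idx=0, causal=True, seq_len=3, step=1)
+            >>> # seq == [pose_results[0]] * 3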
+    """
+
+    if causal:
+        frames_left = seq_len - 1
+        frames_right = 0
+    else:
+        frames_left = (seq_len - 1) // 2
+        frames_right = frames_left
+    num_frames = len(pose_results)
+
+    # get the padded sequence
+    pad_left = max(0, frames_left - frame_idx // step)
+    pad_right = max(0, frames_right - (num_frames - 1 - frame_idx) // step)
+    start = max(frame_idx % step, frame_idx - frames_left * step)
+    end = min(num_frames - (num_frames - 1 - frame_idx) % step,
+              frame_idx + frames_right * step + 1)
+    pose_results_seq = [pose_results[0]] * pad_left + \
+        pose_results[start:end:step] + [pose_results[-1]] * pad_right
+    return pose_results_seq
+
+
+def _gather_pose_lifter_inputs(pose_results,
+                               bbox_center,
+                               bbox_scale,
+                               norm_pose_2d=False):
+    """Gather input data (keypoints and track_id) for pose lifter model.
+
+    Note:
+        - The temporal length of the pose detection results: T
+        - The number of the person instances: N
+        - The number of the keypoints: K
+        - The channel number of each keypoint: C
+
+    Args:
+        pose_results (List[List[Dict]]): Multi-frame pose detection results
+            stored in a nested list. Each element of the outer list is the
+            pose detection results of a single frame, and each element of the
+            inner list is the pose information of one person, which contains:
+
+                - keypoints (ndarray[K, 2 or 3]): x, y, [score]
+                - track_id (int): unique id of each person, required when
+                    ``with_track_id==True``
+                - bbox ((4, ) or (5, )): left, top, right, bottom, [score]
+
+        bbox_center (ndarray[1, 2]): x, y. The average center coordinate of the
+            bboxes in the dataset.
+        bbox_scale (int|float): The average scale of the bboxes in the dataset.
+        norm_pose_2d (bool): If True, scale the bbox (along with the 2D
+            pose) to bbox_scale, and move the bbox (along with the 2D pose) to
+            bbox_center. Default: False.
+
+    Returns:
+        list[list[dict]]: Multi-frame pose detection results
+            stored in a nested list. Each element of the outer list is the
+            pose detection results of a single frame, and each element of the
+            inner list is the pose information of one person, which contains:
+
+                - keypoints (ndarray[K, 2 or 3]): x, y, [score]
+                - track_id (int): unique id of each person, required when
+                    ``with_track_id==True``
+    """
+    sequence_inputs = []
+    for frame in pose_results:
+        frame_inputs = []
+        for res in frame:
+            inputs = dict()
+
+            if norm_pose_2d:
+                bbox = res['bbox']
+                center = np.array([[(bbox[0] + bbox[2]) / 2,
+                                    (bbox[1] + bbox[3]) / 2]])
+                scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+                inputs['keypoints'] = (res['keypoints'][:, :2] - center) \
+                    / scale * bbox_scale + bbox_center
+            else:
+                inputs['keypoints'] = res['keypoints'][:, :2]
+
+            if res['keypoints'].shape[1] == 3:
+                inputs['keypoints'] = np.concatenate(
+                    [inputs['keypoints'], res['keypoints'][:, 2:]], axis=1)
+
+            if 'track_id' in res:
+                inputs['track_id'] = res['track_id']
+            frame_inputs.append(inputs)
+        sequence_inputs.append(frame_inputs)
+    return sequence_inputs
+
+
+def _collate_pose_sequence(pose_results, with_track_id=True, target_frame=-1):
+    """Reorganize multi-frame pose detection results into individual pose
+    sequences.
+
+    Note:
+        - The temporal length of the pose detection results: T
+        - The number of the person instances: N
+        - The number of the keypoints: K
+        - The channel number of each keypoint: C
+
+    Args:
+        pose_results (List[List[Dict]]): Multi-frame pose detection results
+            stored in a nested list. Each element of the outer list is the
+            pose detection results of a single frame, and each element of the
+            inner list is the pose information of one person, which contains:
+
+                - keypoints (ndarray[K, 2 or 3]): x, y, [score]
+                - track_id (int): unique id of each person, required when
+                    ``with_track_id==True``
+
+        with_track_id (bool): If True, the element in pose_results is expected
+            to contain "track_id", which will be used to gather the pose
+            sequence of a person from multiple frames. Otherwise, the pose
+            results in each frame are expected to have a consistent number and
+            order of identities. Default is True.
+        target_frame (int): The index of the target frame. Default: -1.
+
+    Returns:
+        list[dict]: Pose sequences collated by instance. Each element
+            contains the non-keypoint information of one person copied from
+            the target frame, and the person's "keypoints" stacked over time
+            with shape (T, K, C).
+    """
+    T = len(pose_results)
+    assert T > 0
+
+    target_frame = (T + target_frame) % T  # convert negative index to positive
+
+    N = len(pose_results[target_frame])  # use identities in the target frame
+    if N == 0:
+        return []
+
+    K, C = pose_results[target_frame][0]['keypoints'].shape
+
+    track_ids = None
+    if with_track_id:
+        track_ids = [res['track_id'] for res in pose_results[target_frame]]
+
+    pose_sequences = []
+    for idx in range(N):
+        pose_seq = dict()
+        # gather static information
+        for k, v in pose_results[target_frame][idx].items():
+            if k != 'keypoints':
+                pose_seq[k] = v
+        # gather keypoints
+        if not with_track_id:
+            pose_seq['keypoints'] = np.stack(
+                [frame[idx]['keypoints'] for frame in pose_results])
+        else:
+            keypoints = np.zeros((T, K, C), dtype=np.float32)
+            keypoints[target_frame] = pose_results[target_frame][idx][
+                'keypoints']
+            # find the left most frame containing track_ids[idx]
+            for frame_idx in range(target_frame - 1, -1, -1):
+                contains_idx = False
+                for res in pose_results[frame_idx]:
+                    if res['track_id'] == track_ids[idx]:
+                        keypoints[frame_idx] = res['keypoints']
+                        contains_idx = True
+                        break
+                if not contains_idx:
+                    # replicate the left most frame
+                    keypoints[:frame_idx + 1] = keypoints[frame_idx + 1]
+                    break
+            # find the right most frame containing track_ids[idx]
+            for frame_idx in range(target_frame + 1, T):
+                contains_idx = False
+                for res in pose_results[frame_idx]:
+                    if res['track_id'] == track_ids[idx]:
+                        keypoints[frame_idx] = res['keypoints']
+                        contains_idx = True
+                        break
+                if not contains_idx:
+                    # replicate the right most frame
+                    keypoints[frame_idx:] = keypoints[frame_idx - 1]
+                    break
+            pose_seq['keypoints'] = keypoints
+        pose_sequences.append(pose_seq)
+
+    return pose_sequences
+
+
+def inference_pose_lifter_model(model,
+                                pose_results_2d,
+                                dataset=None,
+                                dataset_info=None,
+                                with_track_id=True,
+                                image_size=None,
+                                norm_pose_2d=False):
+    """Inference 3D pose from 2D pose sequences using a pose lifter model.
+
+    Args:
+        model (nn.Module): The loaded pose lifter model
+        pose_results_2d (list[list[dict]]): The 2D pose sequences stored in a
+            nested list. Each element of the outer list is the 2D pose results
+            of a single frame, and each element of the inner list is the 2D
+            pose of one person, which contains:
+
+            - "keypoints" (ndarray[K, 2 or 3]): x, y, [score]
+            - "track_id" (int)
+        dataset (str): Dataset name, e.g. 'Body3DH36MDataset'.
+        dataset_info (DatasetInfo): A class containing all dataset info.
+        with_track_id (bool): If True, the element in pose_results_2d is
+            expected to contain "track_id", which will be used to gather the
+            pose sequence of a person from multiple frames. Otherwise, the
+            pose results in each frame are expected to have a consistent
+            number and order of identities. Default is True.
+        image_size (tuple|list): image width, image height. If None, image size
+            will not be contained in dict ``data``.
+        norm_pose_2d (bool): If True, scale the bbox (along with the 2D
+            pose) to the average bbox scale of the dataset, and move the bbox
+            (along with the 2D pose) to the average bbox center of the dataset.
+
+    Returns:
+        list[dict]: 3D pose inference results. Each element is the result of \
+            an instance, which contains:
+
+            - "keypoints_3d" (ndarray[K, 3]): predicted 3D keypoints
+            - "keypoints" (ndarray[K, 2 or 3]): from the last frame in \
+                ``pose_results_2d``.
+            - "track_id" (int): from the last frame in ``pose_results_2d``. \
+                If there is no valid instance, an empty list will be \
+                returned.
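+
+    Example:
+        A minimal usage sketch; ``lifter_config``, ``lifter_checkpoint``,
+        ``pose_results_2d`` and ``dataset_info`` are placeholders prepared by
+        the caller, not part of this module::
+
+            >>> from mmpose.apis import init_pose_model
+            >>> lift_model = init_pose_model(lifter_config, lifter_checkpoint)
+            >>> results_3d = inference_pose_lifter_model(
+            ...     lift_model,
+            ...     pose_results_2d,
+            ...     dataset_info=dataset_info,
+            ...     with_track_id=True,
+            ...     image_size=(1920, 1080),
+            ...     norm_pose_2d=True)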
+    """
+    cfg = model.cfg
+    test_pipeline = Compose(cfg.test_pipeline)
+
+    device = next(model.parameters()).device
+    if device.type == 'cpu':
+        device = -1
+
+    if dataset_info is not None:
+        flip_pairs = dataset_info.flip_pairs
+        assert 'stats_info' in dataset_info._dataset_info
+        bbox_center = dataset_info._dataset_info['stats_info']['bbox_center']
+        bbox_scale = dataset_info._dataset_info['stats_info']['bbox_scale']
+    else:
+        warnings.warn(
+            'dataset is deprecated. '
+            'Please set `dataset_info` in the config. '
+            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
+            DeprecationWarning)
+        # TODO: These will be removed in the later versions.
+        if dataset == 'Body3DH36MDataset':
+            flip_pairs = [[1, 4], [2, 5], [3, 6], [11, 14], [12, 15], [13, 16]]
+            bbox_center = np.array([[528, 427]], dtype=np.float32)
+            bbox_scale = 400
+        else:
+            raise NotImplementedError()
+
+    target_idx = -1 if model.causal else len(pose_results_2d) // 2
+    pose_lifter_inputs = _gather_pose_lifter_inputs(pose_results_2d,
+                                                    bbox_center, bbox_scale,
+                                                    norm_pose_2d)
+    pose_sequences_2d = _collate_pose_sequence(pose_lifter_inputs,
+                                               with_track_id, target_idx)
+
+    if not pose_sequences_2d:
+        return []
+
+    batch_data = []
+    for seq in pose_sequences_2d:
+        pose_2d = seq['keypoints'].astype(np.float32)
+        T, K, C = pose_2d.shape
+
+        input_2d = pose_2d[..., :2]
+        if C > 2:
+            input_2d_visible = pose_2d[..., 2:3]
+        else:
+            input_2d_visible = np.ones((T, K, 1), dtype=np.float32)
+
+        # TODO: Will be removed in the later versions
+        # Dummy 3D input
+        # This is for compatibility with configs in mmpose<=v0.14.0, where a
+        # 3D input is required to generate denormalization parameters. This
+        # part will be removed in the future.
+        target = np.zeros((K, 3), dtype=np.float32)
+        target_visible = np.ones((K, 1), dtype=np.float32)
+
+        # Dummy image path
+        # This is for compatibility with configs in mmpose<=v0.14.0, where
+        # target_image_path is required. This part will be removed in the
+        # future.
+        target_image_path = None
+
+        data = {
+            'input_2d': input_2d,
+            'input_2d_visible': input_2d_visible,
+            'target': target,
+            'target_visible': target_visible,
+            'target_image_path': target_image_path,
+            'ann_info': {
+                'num_joints': K,
+                'flip_pairs': flip_pairs
+            }
+        }
+
+        if image_size is not None:
+            assert len(image_size) == 2
+            data['image_width'] = image_size[0]
+            data['image_height'] = image_size[1]
+
+        data = test_pipeline(data)
+        batch_data.append(data)
+
+    batch_data = collate(batch_data, samples_per_gpu=len(batch_data))
+    batch_data = scatter(batch_data, target_gpus=[device])[0]
+
+    with torch.no_grad():
+        result = model(
+            input=batch_data['input'],
+            metas=batch_data['metas'],
+            return_loss=False)
+
+    poses_3d = result['preds']
+    if poses_3d.shape[-1] != 4:
+        assert poses_3d.shape[-1] == 3
+        dummy_score = np.ones(
+            poses_3d.shape[:-1] + (1, ), dtype=poses_3d.dtype)
+        poses_3d = np.concatenate((poses_3d, dummy_score), axis=-1)
+    pose_results = []
+    for pose_2d, pose_3d in zip(pose_sequences_2d, poses_3d):
+        pose_result = pose_2d.copy()
+        pose_result['keypoints_3d'] = pose_3d
+        pose_results.append(pose_result)
+
+    return pose_results
+
+
+def vis_3d_pose_result(model,
+                       result,
+                       img=None,
+                       dataset='Body3DH36MDataset',
+                       dataset_info=None,
+                       kpt_score_thr=0.3,
+                       radius=8,
+                       thickness=2,
+                       num_instances=-1,
+                       show=False,
+                       out_file=None):
+    """Visualize the 3D pose estimation results.
+
+    Args:
+        model (nn.Module): The loaded model.
+        result (list[dict]): The 3D pose estimation results.
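+
+    Example:
+        A minimal sketch; ``lift_model``, ``results_3d`` and ``dataset_info``
+        follow from the ``inference_pose_lifter_model`` example above, and
+        the file names are placeholders::
+
+            >>> vis_3d_pose_result(
+            ...     lift_model,
+            ...     result=results_3d,
+            ...     img='frame_000000.jpg',
+            ...     dataset_info=dataset_info,
+            ...     out_file='vis_000000.jpg')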
+    """
+
+    if dataset_info is not None:
+        skeleton = dataset_info.skeleton
+        pose_kpt_color = dataset_info.pose_kpt_color
+        pose_link_color = dataset_info.pose_link_color
+    else:
+        warnings.warn(
+            'dataset is deprecated. '
+            'Please set `dataset_info` in the config. '
+            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
+            DeprecationWarning)
+        # TODO: These will be removed in the later versions.
+        palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
+                            [230, 230, 0], [255, 153, 255], [153, 204, 255],
+                            [255, 102, 255], [255, 51, 255], [102, 178, 255],
+                            [51, 153, 255], [255, 153, 153], [255, 102, 102],
+                            [255, 51, 51], [153, 255, 153], [102, 255, 102],
+                            [51, 255, 51], [0, 255, 0], [0, 0, 255],
+                            [255, 0, 0], [255, 255, 255]])
+
+        if dataset == 'Body3DH36MDataset':
+            skeleton = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7],
+                        [7, 8], [8, 9], [9, 10], [8, 11], [11, 12], [12, 13],
+                        [8, 14], [14, 15], [15, 16]]
+
+            pose_kpt_color = palette[[
+                9, 0, 0, 0, 16, 16, 16, 9, 9, 9, 9, 16, 16, 16, 0, 0, 0
+            ]]
+            pose_link_color = palette[[
+                0, 0, 0, 16, 16, 16, 9, 9, 9, 9, 16, 16, 16, 0, 0, 0
+            ]]
+
+        elif dataset == 'InterHand3DDataset':
+            skeleton = [[0, 1], [1, 2], [2, 3], [3, 20], [4, 5], [5, 6],
+                        [6, 7], [7, 20], [8, 9], [9, 10], [10, 11], [11, 20],
+                        [12, 13], [13, 14], [14, 15], [15, 20], [16, 17],
+                        [17, 18], [18, 19], [19, 20], [21, 22], [22, 23],
+                        [23, 24], [24, 41], [25, 26], [26, 27], [27, 28],
+                        [28, 41], [29, 30], [30, 31], [31, 32], [32, 41],
+                        [33, 34], [34, 35], [35, 36], [36, 41], [37, 38],
+                        [38, 39], [39, 40], [40, 41]]
+
+            pose_kpt_color = [[14, 128, 250], [14, 128, 250], [14, 128, 250],
+                              [14, 128, 250], [80, 127, 255], [80, 127, 255],
+                              [80, 127, 255], [80, 127, 255], [71, 99, 255],
+                              [71, 99, 255], [71, 99, 255], [71, 99, 255],
+                              [0, 36, 255], [0, 36, 255], [0, 36, 255],
+                              [0, 36, 255], [0, 0, 230], [0, 0, 230],
+                              [0, 0, 230], [0, 0, 230], [0, 0, 139],
+                              [237, 149, 100], [237, 149, 100],
+                              [237, 149, 100], [237, 149, 100], [230, 128, 77],
+                              [230, 128, 77], [230, 128, 77], [230, 128, 77],
+                              [255, 144, 30], [255, 144, 30], [255, 144, 30],
+                              [255, 144, 30], [153, 51, 0], [153, 51, 0],
+                              [153, 51, 0], [153, 51, 0], [255, 51, 13],
+                              [255, 51, 13], [255, 51, 13], [255, 51, 13],
+                              [103, 37, 8]]
+
+            pose_link_color = [[14, 128, 250], [14, 128, 250], [14, 128, 250],
+                               [14, 128, 250], [80, 127, 255], [80, 127, 255],
+                               [80, 127, 255], [80, 127, 255], [71, 99, 255],
+                               [71, 99, 255], [71, 99, 255], [71, 99, 255],
+                               [0, 36, 255], [0, 36, 255], [0, 36, 255],
+                               [0, 36, 255], [0, 0, 230], [0, 0, 230],
+                               [0, 0, 230], [0, 0, 230], [237, 149, 100],
+                               [237, 149, 100], [237, 149, 100],
+                               [237, 149, 100], [230, 128, 77], [230, 128, 77],
+                               [230, 128, 77], [230, 128, 77], [255, 144, 30],
+                               [255, 144, 30], [255, 144, 30], [255, 144, 30],
+                               [153, 51, 0], [153, 51, 0], [153, 51, 0],
+                               [153, 51, 0], [255, 51, 13], [255, 51, 13],
+                               [255, 51, 13], [255, 51, 13]]
+        else:
+            raise NotImplementedError
+
+    if hasattr(model, 'module'):
+        model = model.module
+
+    img = model.show_result(
+        result,
+        img,
+        skeleton,
+        radius=radius,
+        thickness=thickness,
+        pose_kpt_color=pose_kpt_color,
+        pose_link_color=pose_link_color,
+        num_instances=num_instances,
+        show=show,
+        out_file=out_file)
+
+    return img
+
+
+def inference_interhand_3d_model(model,
+                                 img_or_path,
+                                 det_results,
+                                 bbox_thr=None,
+                                 format='xywh',
+                                 dataset='InterHand3DDataset'):
+    """Inference a single image with a list of hand bounding boxes.
+
+    Note:
+        - num_bboxes: N
+        - num_keypoints: K
+
+    Args:
+        model (nn.Module): The loaded pose model.
+        img_or_path (str | np.ndarray): Image filename or loaded image.
+        det_results (list[dict]): The 2D bbox sequences stored in a list.
+            Each element of the list is the bbox of one person, stored as an
+            ndarray of shape (4,) or (5,) containing 4 box coordinates
+            (and an optional score).
+        bbox_thr (float | None): Threshold for bounding boxes. Only bboxes
+            with higher scores will be fed into the pose detector. If
+            bbox_thr is None, all boxes will be used.
+        dataset (str): Dataset name.
+        format (str): bbox format ('xyxy' | 'xywh'). Default: 'xywh'.
+            'xyxy' means (left, top, right, bottom),
+            'xywh' means (left, top, width, height).
+
+    Returns:
+        list[dict]: 3D pose inference results. Each element is the result \
+            of an instance, which contains the predicted 3D keypoints with \
+            shape (ndarray[K,3]). If there is no valid instance, an \
+            empty list will be returned.
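+
+    Example:
+        A minimal usage sketch; the config, checkpoint and image paths are
+        placeholders, not part of this module::
+
+            >>> import numpy as np
+            >>> from mmpose.apis import init_pose_model
+            >>> hand_model = init_pose_model(interhand3d_config,
+            ...                              interhand3d_checkpoint)
+            >>> det_results = [{'bbox': np.array([50, 50, 250, 250, 0.9])}]
+            >>> hand_results = inference_interhand_3d_model(
+            ...     hand_model, 'hand.jpg', det_results, format='xyxy')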
+    """
+
+    assert format in ['xyxy', 'xywh']
+
+    pose_results = []
+
+    if len(det_results) == 0:
+        return pose_results
+
+    # Preprocess all bboxes at once instead of looping over each one.
+    bboxes = np.array([box['bbox'] for box in det_results])
+
+    # Select bboxes by score threshold
+    if bbox_thr is not None:
+        assert bboxes.shape[1] == 5
+        valid_idx = np.where(bboxes[:, 4] > bbox_thr)[0]
+        bboxes = bboxes[valid_idx]
+        det_results = [det_results[i] for i in valid_idx]
+
+    if format == 'xyxy':
+        bboxes_xyxy = bboxes
+        bboxes_xywh = _xyxy2xywh(bboxes)
+    else:
+        # format is already 'xywh'
+        bboxes_xywh = bboxes
+        bboxes_xyxy = _xywh2xyxy(bboxes)
+
+    # return early if bbox_thr filtered out all bounding boxes
+    if len(bboxes_xywh) == 0:
+        return []
+
+    cfg = model.cfg
+    device = next(model.parameters()).device
+    if device.type == 'cpu':
+        device = -1
+
+    # build the data pipeline
+    test_pipeline = Compose(cfg.test_pipeline)
+
+    assert len(bboxes[0]) in [4, 5]
+
+    if dataset == 'InterHand3DDataset':
+        flip_pairs = [[i, 21 + i] for i in range(21)]
+    else:
+        raise NotImplementedError()
+
+    batch_data = []
+    for bbox in bboxes_xywh:
+        center, scale = _box2cs(cfg, bbox)
+
+        # prepare data
+        data = {
+            'center':
+            center,
+            'scale':
+            scale,
+            'bbox_score':
+            bbox[4] if len(bbox) == 5 else 1,
+            'bbox_id':
+            0,  # need to be assigned if batch_size > 1
+            'dataset':
+            dataset,
+            'joints_3d':
+            np.zeros((cfg.data_cfg.num_joints, 3), dtype=np.float32),
+            'joints_3d_visible':
+            np.zeros((cfg.data_cfg.num_joints, 3), dtype=np.float32),
+            'rotation':
+            0,
+            'ann_info': {
+                'image_size': np.array(cfg.data_cfg['image_size']),
+                'num_joints': cfg.data_cfg['num_joints'],
+                'flip_pairs': flip_pairs,
+                'heatmap3d_depth_bound': cfg.data_cfg['heatmap3d_depth_bound'],
+                'heatmap_size_root': cfg.data_cfg['heatmap_size_root'],
+                'root_depth_bound': cfg.data_cfg['root_depth_bound']
+            }
+        }
+
+        if isinstance(img_or_path, np.ndarray):
+            data['img'] = img_or_path
+        else:
+            data['image_file'] = img_or_path
+
+        data = test_pipeline(data)
+        batch_data.append(data)
+
+    batch_data = collate(batch_data, samples_per_gpu=len(batch_data))
+    batch_data = scatter(batch_data, [device])[0]
+
+    # forward the model
+    with torch.no_grad():
+        result = model(
+            img=batch_data['img'],
+            img_metas=batch_data['img_metas'],
+            return_loss=False)
+
+    poses_3d = result['preds']
+    rel_root_depth = result['rel_root_depth']
+    hand_type = result['hand_type']
+    if poses_3d.shape[-1] != 4:
+        assert poses_3d.shape[-1] == 3
+        dummy_score = np.ones(
+            poses_3d.shape[:-1] + (1, ), dtype=poses_3d.dtype)
+        poses_3d = np.concatenate((poses_3d, dummy_score), axis=-1)
+
+    # add relative root depth to left hand joints
+    poses_3d[:, 21:, 2] += rel_root_depth
+
+    # set joint scores according to hand type
+    poses_3d[:, :21, 3] *= hand_type[:, [0]]
+    poses_3d[:, 21:, 3] *= hand_type[:, [1]]
+
+    pose_results = []
+    for pose_3d, person_res, bbox_xyxy in zip(poses_3d, det_results,
+                                              bboxes_xyxy):
+        pose_res = person_res.copy()
+        pose_res['keypoints_3d'] = pose_3d
+        pose_res['bbox'] = bbox_xyxy
+        pose_results.append(pose_res)
+
+    return pose_results
+
+
+def inference_mesh_model(model,
+                         img_or_path,
+                         det_results,
+                         bbox_thr=None,
+                         format='xywh',
+                         dataset='MeshH36MDataset'):
+    """Inference a single image with a list of bounding boxes.
+
+    Note:
+        - num_bboxes: N
+        - num_keypoints: K
+        - num_vertices: V
+        - num_faces: F
+
+    Args:
+        model (nn.Module): The loaded pose model.
+        img_or_path (str | np.ndarray): Image filename or loaded image.
+        det_results (list[dict]): The 2D bbox sequences stored in a list.
+            Each element of the list is the bbox of one person.
+            "bbox" (ndarray[4 or 5]): The person bounding box,
+            which contains 4 box coordinates (and score).
+        bbox_thr (float | None): Threshold for bounding boxes.
+            Only bboxes with higher scores will be fed into the pose
+            detector. If bbox_thr is None, all boxes will be used.
+        format (str): bbox format ('xyxy' | 'xywh'). Default: 'xywh'.
+
+            - 'xyxy' means (left, top, right, bottom),
+            - 'xywh' means (left, top, width, height).
+        dataset (str): Dataset name.
+
+    Returns:
+        list[dict]: 3D pose inference results. Each element \
+            is the result of an instance, which contains:
+
+            - 'bbox' (ndarray[4]): instance bounding bbox
+            - 'center' (ndarray[2]): bbox center
+            - 'scale' (ndarray[2]): bbox scale
+            - 'keypoints_3d' (ndarray[K,3]): predicted 3D keypoints
+            - 'camera' (ndarray[3]): camera parameters
+            - 'vertices' (ndarray[V, 3]): predicted 3D vertices
+            - 'faces' (ndarray[F, 3]): mesh faces
+
+            If there is no valid instance, an empty list
+            will be returned.
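+
+    Example:
+        A minimal usage sketch; the config, checkpoint and image paths are
+        placeholders, not part of this module::
+
+            >>> import numpy as np
+            >>> from mmpose.apis import init_pose_model
+            >>> mesh_model = init_pose_model(mesh_config, mesh_checkpoint)
+            >>> det_results = [{'bbox': np.array([50, 50, 250, 400, 0.9])}]
+            >>> mesh_results = inference_mesh_model(
+            ...     mesh_model, 'person.jpg', det_results, format='xyxy')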
+    """
+
+    assert format in ['xyxy', 'xywh']
+
+    pose_results = []
+
+    if len(det_results) == 0:
+        return pose_results
+
+    # Preprocess all bboxes at once instead of looping over each one.
+    bboxes = np.array([box['bbox'] for box in det_results])
+
+    # Select bboxes by score threshold
+    if bbox_thr is not None:
+        assert bboxes.shape[1] == 5
+        valid_idx = np.where(bboxes[:, 4] > bbox_thr)[0]
+        bboxes = bboxes[valid_idx]
+        det_results = [det_results[i] for i in valid_idx]
+
+    if format == 'xyxy':
+        bboxes_xyxy = bboxes
+        bboxes_xywh = _xyxy2xywh(bboxes)
+    else:
+        # format is already 'xywh'
+        bboxes_xywh = bboxes
+        bboxes_xyxy = _xywh2xyxy(bboxes)
+
+    # return early if bbox_thr filtered out all bounding boxes
+    if len(bboxes_xywh) == 0:
+        return []
+
+    cfg = model.cfg
+    device = next(model.parameters()).device
+    if device.type == 'cpu':
+        device = -1
+
+    # build the data pipeline
+    test_pipeline = Compose(cfg.test_pipeline)
+
+    assert len(bboxes[0]) in [4, 5]
+
+    if dataset == 'MeshH36MDataset':
+        flip_pairs = [[0, 5], [1, 4], [2, 3], [6, 11], [7, 10], [8, 9],
+                      [20, 21], [22, 23]]
+    else:
+        raise NotImplementedError()
+
+    batch_data = []
+    for bbox in bboxes_xywh:
+        center, scale = _box2cs(cfg, bbox)
+
+        # prepare data
+        data = {
+            'image_file':
+            img_or_path,
+            'center':
+            center,
+            'scale':
+            scale,
+            'rotation':
+            0,
+            'bbox_score':
+            bbox[4] if len(bbox) == 5 else 1,
+            'dataset':
+            dataset,
+            'joints_2d':
+            np.zeros((cfg.data_cfg.num_joints, 2), dtype=np.float32),
+            'joints_2d_visible':
+            np.zeros((cfg.data_cfg.num_joints, 1), dtype=np.float32),
+            'joints_3d':
+            np.zeros((cfg.data_cfg.num_joints, 3), dtype=np.float32),
+            'joints_3d_visible':
+            np.zeros((cfg.data_cfg.num_joints, 3), dtype=np.float32),
+            'pose':
+            np.zeros(72, dtype=np.float32),
+            'beta':
+            np.zeros(10, dtype=np.float32),
+            'has_smpl':
+            0,
+            'ann_info': {
+                'image_size': np.array(cfg.data_cfg['image_size']),
+                'num_joints': cfg.data_cfg['num_joints'],
+                'flip_pairs': flip_pairs,
+            }
+        }
+
+        data = test_pipeline(data)
+        batch_data.append(data)
+
+    batch_data = collate(batch_data, samples_per_gpu=len(batch_data))
+    batch_data = scatter(batch_data, target_gpus=[device])[0]
+
+    # forward the model
+    with torch.no_grad():
+        preds = model(
+            img=batch_data['img'],
+            img_metas=batch_data['img_metas'],
+            return_loss=False,
+            return_vertices=True,
+            return_faces=True)
+
+    for idx in range(len(det_results)):
+        pose_res = det_results[idx].copy()
+        pose_res['bbox'] = bboxes_xyxy[idx]
+        pose_res['center'] = batch_data['img_metas'][idx]['center']
+        pose_res['scale'] = batch_data['img_metas'][idx]['scale']
+        pose_res['keypoints_3d'] = preds['keypoints_3d'][idx]
+        pose_res['camera'] = preds['camera'][idx]
+        pose_res['vertices'] = preds['vertices'][idx]
+        pose_res['faces'] = preds['faces']
+        pose_results.append(pose_res)
+    return pose_results
+
+
+def vis_3d_mesh_result(model, result, img=None, show=False, out_file=None):
+    """Visualize the 3D mesh estimation results.
+
+    Args:
+        model (nn.Module): The loaded model.
+        result (list[dict]): 3D mesh estimation results.
+    """
+    if hasattr(model, 'module'):
+        model = model.module
+
+    img = model.show_result(result, img, show=show, out_file=out_file)
+
+    return img