--- a
+++ b/ViTPose/mmpose/apis/inference_tracking.py
@@ -0,0 +1,347 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+
+from mmpose.core import OneEuroFilter, oks_iou
+
+
+def _compute_iou(bboxA, bboxB):
+    """Compute the Intersection over Union (IoU) between two boxes.
+
+    Args:
+        bboxA (list): The first bbox info (left, top, right, bottom, score).
+        bboxB (list): The second bbox info (left, top, right, bottom, score).
+
+    Returns:
+        float: The IoU value.
+    """
+
+    x1 = max(bboxA[0], bboxB[0])
+    y1 = max(bboxA[1], bboxB[1])
+    x2 = min(bboxA[2], bboxB[2])
+    y2 = min(bboxA[3], bboxB[3])
+
+    inter_area = max(0, x2 - x1) * max(0, y2 - y1)
+
+    bboxA_area = (bboxA[2] - bboxA[0]) * (bboxA[3] - bboxA[1])
+    bboxB_area = (bboxB[2] - bboxB[0]) * (bboxB[3] - bboxB[1])
+    union_area = float(bboxA_area + bboxB_area - inter_area)
+    if union_area == 0:
+        union_area = 1e-5
+        warnings.warn('union_area=0 is unexpected')
+
+    iou = inter_area / union_area
+
+    return iou
+
+
+def _track_by_iou(res, results_last, thr):
+    """Get track id using IoU tracking greedily.
+
+    Args:
+        res (dict): The bbox & pose results of the person instance.
+        results_last (list[dict]): The bbox & pose & track_id info of the
+            last frame (bbox_result, pose_result, track_id).
+        thr (float): The threshold for iou tracking.
+
+    Returns:
+        int: The matched track id, or -1 if no match is found.
+        list[dict]: The bbox & pose & track_id info of the persons
+            that have not been matched on the last frame.
+        dict: The matched person instance on the last frame.
+    """
+
+    bbox = list(res['bbox'])
+
+    max_iou_score = -1
+    max_index = -1
+    match_result = {}
+    for index, res_last in enumerate(results_last):
+        bbox_last = list(res_last['bbox'])
+
+        iou_score = _compute_iou(bbox, bbox_last)
+        if iou_score > max_iou_score:
+            max_iou_score = iou_score
+            max_index = index
+
+    if max_iou_score > thr:
+        track_id = results_last[max_index]['track_id']
+        match_result = results_last[max_index]
+        del results_last[max_index]
+    else:
+        track_id = -1
+
+    return track_id, results_last, match_result
+
+
+def _track_by_oks(res, results_last, thr):
+    """Get track id using OKS tracking greedily.
+
+    Args:
+        res (dict): The pose results of the person instance.
+        results_last (list[dict]): The pose & track_id info of the
+            last frame (pose_result, track_id).
+        thr (float): The threshold for oks tracking.
+
+    Returns:
+        int: The matched track id, or -1 if no match is found.
+        list[dict]: The pose & track_id info of the persons
+            that have not been matched on the last frame.
+        dict: The matched person instance on the last frame.
+    """
+    pose = res['keypoints'].reshape((-1))
+    area = res['area']
+    max_index = -1
+    match_result = {}
+
+    if len(results_last) == 0:
+        return -1, results_last, match_result
+
+    pose_last = np.array(
+        [res_last['keypoints'].reshape((-1)) for res_last in results_last])
+    area_last = np.array([res_last['area'] for res_last in results_last])
+
+    oks_score = oks_iou(pose, pose_last, area, area_last)
+
+    max_index = np.argmax(oks_score)
+
+    if oks_score[max_index] > thr:
+        track_id = results_last[max_index]['track_id']
+        match_result = results_last[max_index]
+        del results_last[max_index]
+    else:
+        track_id = -1
+
+    return track_id, results_last, match_result
+
+
+def _get_area(results):
+    """Get the area of each person instance on the current frame.
+
+    Args:
+        results (list[dict]): The pose results of the current frame
+            (pose_result).
+
+    Returns:
+        list[dict]: The bbox & pose info of the current frame
+            (bbox_result, pose_result, area).
+    """
+    for result in results:
+        if 'bbox' in result:
+            result['area'] = ((result['bbox'][2] - result['bbox'][0]) *
+                              (result['bbox'][3] - result['bbox'][1]))
+        else:
+            xmin = np.min(
+                result['keypoints'][:, 0][result['keypoints'][:, 0] > 0],
+                initial=1e10)
+            xmax = np.max(result['keypoints'][:, 0])
+            ymin = np.min(
+                result['keypoints'][:, 1][result['keypoints'][:, 1] > 0],
+                initial=1e10)
+            ymax = np.max(result['keypoints'][:, 1])
+            result['area'] = (xmax - xmin) * (ymax - ymin)
+            result['bbox'] = np.array([xmin, ymin, xmax, ymax])
+    return results
+
+
+def _temporal_refine(result, match_result, fps=None):
+    """Refine keypoints using the tracked person instance on the last frame.
+
+    Args:
+        result (dict): The pose results of the current frame
+            (pose_result).
+        match_result (dict): The pose results of the last frame
+            (match_result).
+    Returns:
+        np.ndarray: The person keypoints after refinement.
+    """
+    if 'one_euro' in match_result:
+        result['keypoints'][:, :2] = match_result['one_euro'](
+            result['keypoints'][:, :2])
+        result['one_euro'] = match_result['one_euro']
+    else:
+        result['one_euro'] = OneEuroFilter(result['keypoints'][:, :2], fps=fps)
+    return result['keypoints']
+
+
+def get_track_id(results,
+                 results_last,
+                 next_id,
+                 min_keypoints=3,
+                 use_oks=False,
+                 tracking_thr=0.3,
+                 use_one_euro=False,
+                 fps=None):
+    """Get track id for each person instance on the current frame.
+
+    Args:
+        results (list[dict]): The bbox & pose results of the current frame
+            (bbox_result, pose_result).
+        results_last (list[dict]): The bbox & pose & track_id info of the
+            last frame (bbox_result, pose_result, track_id).
+        next_id (int): The next track id to assign to a new person instance.
+        min_keypoints (int): Minimum number of keypoints recognized as a
+            person. Default: 3.
+        use_oks (bool): Flag to use oks tracking. Default: False.
+        tracking_thr (float): The threshold for tracking. Default: 0.3.
+        use_one_euro (bool): Option to use one-euro-filter. Default: False.
+        fps (optional): The frame rate of the video, used by the
+            one-euro filter when the input is a video.
+
+    Returns:
+        tuple:
+        - results (list[dict]): The bbox & pose & track_id info of the \
+            current frame (bbox_result, pose_result, track_id).
+        - next_id (int): The updated track id for the next new person.
+    """
+    results = _get_area(results)
+
+    if use_oks:
+        _track = _track_by_oks
+    else:
+        _track = _track_by_iou
+
+    for result in results:
+        track_id, results_last, match_result = _track(result, results_last,
+                                                      tracking_thr)
+        if track_id == -1:
+            if np.count_nonzero(result['keypoints'][:, 1]) > min_keypoints:
+                result['track_id'] = next_id
+                next_id += 1
+            else:
+                # If the number of keypoints detected is small,
+                # delete that person instance.
+                result['keypoints'][:, 1] = -10
+                result['bbox'] *= 0
+                result['track_id'] = -1
+        else:
+            result['track_id'] = track_id
+        if use_one_euro:
+            result['keypoints'] = _temporal_refine(
+                result, match_result, fps=fps)
+        del match_result
+
+    return results, next_id
+
+
+def vis_pose_tracking_result(model,
+                             img,
+                             result,
+                             radius=4,
+                             thickness=1,
+                             kpt_score_thr=0.3,
+                             dataset='TopDownCocoDataset',
+                             dataset_info=None,
+                             show=False,
+                             out_file=None):
+    """Visualize the pose tracking results on the image.
+
+    Args:
+        model (nn.Module): The loaded pose model.
+        img (str | np.ndarray): Image filename or loaded image.
+        result (list[dict]): The results to draw over `img`
+            (bbox_result, pose_result).
+        radius (int): Radius of circles.
+        thickness (int): Thickness of lines.
+        kpt_score_thr (float): The threshold to visualize the keypoints.
+        dataset_info (DatasetInfo): A class containing all dataset info.
+        show (bool): Whether to show the image. Default: False.
+        out_file (str|None): The filename of the output visualization image.
+    """
+    if hasattr(model, 'module'):
+        model = model.module
+
+    palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
+                        [230, 230, 0], [255, 153, 255], [153, 204, 255],
+                        [255, 102, 255], [255, 51, 255], [102, 178, 255],
+                        [51, 153, 255], [255, 153, 153], [255, 102, 102],
+                        [255, 51, 51], [153, 255, 153], [102, 255, 102],
+                        [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0],
+                        [255, 255, 255]])
+
+    if dataset_info is None and dataset is not None:
+        warnings.warn(
+            'dataset is deprecated. '
+            'Please set `dataset_info` in the config. '
+            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
+            DeprecationWarning)
+        # TODO: These will be removed in later versions.
+        if dataset in ('TopDownCocoDataset', 'BottomUpCocoDataset',
+                       'TopDownOCHumanDataset'):
+            kpt_num = 17
+            skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
+                        [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
+                        [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4],
+                        [3, 5], [4, 6]]
+
+        elif dataset == 'TopDownCocoWholeBodyDataset':
+            kpt_num = 133
+            skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
+                        [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
+                        [8, 10], [1, 2], [0, 1], [0, 2],
+                        [1, 3], [2, 4], [3, 5], [4, 6], [15, 17], [15, 18],
+                        [15, 19], [16, 20], [16, 21], [16, 22], [91, 92],
+                        [92, 93], [93, 94], [94, 95], [91, 96], [96, 97],
+                        [97, 98], [98, 99], [91, 100], [100, 101], [101, 102],
+                        [102, 103], [91, 104], [104, 105], [105, 106],
+                        [106, 107], [91, 108], [108, 109], [109, 110],
+                        [110, 111], [112, 113], [113, 114], [114, 115],
+                        [115, 116], [112, 117], [117, 118], [118, 119],
+                        [119, 120], [112, 121], [121, 122], [122, 123],
+                        [123, 124], [112, 125], [125, 126], [126, 127],
+                        [127, 128], [112, 129], [129, 130], [130, 131],
+                        [131, 132]]
+            radius = 1
+
+        elif dataset == 'TopDownAicDataset':
+            kpt_num = 14
+            skeleton = [[2, 1], [1, 0], [0, 13], [13, 3], [3, 4], [4, 5],
+                        [8, 7], [7, 6], [6, 9], [9, 10], [10, 11], [12, 13],
+                        [0, 6], [3, 9]]
+
+        elif dataset == 'TopDownMpiiDataset':
+            kpt_num = 16
+            skeleton = [[0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7],
+                        [7, 8], [8, 9], [8, 12], [12, 11], [11, 10], [8, 13],
+                        [13, 14], [14, 15]]
+
+        elif dataset in ('OneHand10KDataset', 'FreiHandDataset',
+                         'PanopticDataset'):
+            kpt_num = 21
+            skeleton = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7],
+                        [7, 8], [0, 9], [9, 10], [10, 11], [11, 12], [0, 13],
+                        [13, 14], [14, 15], [15, 16], [0, 17], [17, 18],
+                        [18, 19], [19, 20]]
+
+        elif dataset == 'InterHand2DDataset':
+            kpt_num = 21
+            skeleton = [[0, 1], [1, 2], [2, 3], [4, 5], [5, 6], [6, 7], [8, 9],
+                        [9, 10], [10, 11], [12, 13], [13, 14], [14, 15],
+                        [16, 17], [17, 18], [18, 19], [3, 20], [7, 20],
+                        [11, 20], [15, 20], [19, 20]]
+
+        else:
+            raise NotImplementedError()
+
+    elif dataset_info is not None:
+        kpt_num = dataset_info.keypoint_num
+        skeleton = dataset_info.skeleton
+
+    for res in result:
+        track_id = res['track_id']
+        bbox_color = palette[track_id % len(palette)]
+        pose_kpt_color = palette[[track_id % len(palette)] * kpt_num]
+        pose_link_color = palette[[track_id % len(palette)] * len(skeleton)]
+        img = model.show_result(
+            img, [res],
+            skeleton,
+            radius=radius,
+            thickness=thickness,
+            pose_kpt_color=pose_kpt_color,
+            pose_link_color=pose_link_color,
+            bbox_color=tuple(bbox_color.tolist()),
+            kpt_score_thr=kpt_score_thr,
+            show=show,
+            out_file=out_file)
+
+    return img
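
Reviewer note, not part of the patch: the sketch below shows the intended frame-by-frame driving loop for get_track_id, assuming the patched mmpose package is importable. The make_person helper is a made-up stand-in for real per-frame pose results (for example, the output of a top-down inference call); it only fills the 'bbox' and 'keypoints' fields the tracker reads.

import numpy as np

from mmpose.apis.inference_tracking import get_track_id


def make_person(x_offset):
    """Hypothetical helper: a synthetic person result with 17 COCO keypoints."""
    keypoints = np.stack([
        np.linspace(10, 90, 17) + x_offset,  # x coordinates
        np.linspace(20, 180, 17),            # y coordinates
        np.full(17, 0.9),                    # keypoint scores
    ], axis=1)
    # bbox format expected by the tracker: (left, top, right, bottom, score)
    bbox = np.array([x_offset, 20., x_offset + 100., 200., 0.99])
    return {'bbox': bbox, 'keypoints': keypoints}


pose_results_last = []
next_id = 0
for frame_idx in range(3):
    # One person drifting right; a second person appears in the last frame.
    pose_results = [make_person(5.0 * frame_idx)]
    if frame_idx == 2:
        pose_results.append(make_person(300.0))

    # IoU matching by default; use_oks=True switches to OKS matching, and
    # use_one_euro=True with fps=... enables the one-euro temporal smoothing.
    pose_results, next_id = get_track_id(
        pose_results, pose_results_last, next_id, tracking_thr=0.3)

    print(frame_idx, [person['track_id'] for person in pose_results])
    pose_results_last = pose_results

With these synthetic boxes, frame 0 assigns id 0, frame 1 keeps id 0 via IoU matching, and the new detection in frame 2 receives id 1. The per-person 'track_id' written into each result dict is what vis_pose_tracking_result later uses to pick a stable color per person.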