--- a +++ b/mmaction/localization/ssn_utils.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import groupby + +import numpy as np + +from ..core import average_precision_at_temporal_iou +from . import temporal_iou + + +def load_localize_proposal_file(filename): + """Load the proposal file and split it into many parts which contain one + video's information separately. + + Args: + filename(str): Path to the proposal file. + + Returns: + list: List of all videos' information. + """ + lines = list(open(filename)) + + # Split the proposal file into many parts which contain one video's + # information separately. + groups = groupby(lines, lambda x: x.startswith('#')) + + video_infos = [[x.strip() for x in list(g)] for k, g in groups if not k] + + def parse_group(video_info): + """Parse the video's information. + + Template information of a video in a standard file: + # index + video_id + num_frames + fps + num_gts + label, start_frame, end_frame + label, start_frame, end_frame + ... + num_proposals + label, best_iou, overlap_self, start_frame, end_frame + label, best_iou, overlap_self, start_frame, end_frame + ... + + Example of a standard annotation file: + + .. code-block:: txt + + # 0 + video_validation_0000202 + 5666 + 1 + 3 + 8 130 185 + 8 832 1136 + 8 1303 1381 + 5 + 8 0.0620 0.0620 790 5671 + 8 0.1656 0.1656 790 2619 + 8 0.0833 0.0833 3945 5671 + 8 0.0960 0.0960 4173 5671 + 8 0.0614 0.0614 3327 5671 + + Args: + video_info (list): Information of the video. + + Returns: + tuple[str, int, list, list]: + video_id (str): Name of the video. + num_frames (int): Number of frames in the video. + gt_boxes (list): List of the information of gt boxes. + proposal_boxes (list): List of the information of + proposal boxes. + """ + offset = 0 + video_id = video_info[offset] + offset += 1 + + num_frames = int(float(video_info[1]) * float(video_info[2])) + num_gts = int(video_info[3]) + offset = 4 + + gt_boxes = [x.split() for x in video_info[offset:offset + num_gts]] + offset += num_gts + num_proposals = int(video_info[offset]) + offset += 1 + proposal_boxes = [ + x.split() for x in video_info[offset:offset + num_proposals] + ] + + return video_id, num_frames, gt_boxes, proposal_boxes + + return [parse_group(video_info) for video_info in video_infos] + + +def perform_regression(detections): + """Perform regression on detection results. + + Args: + detections (list): Detection results before regression. + + Returns: + list: Detection results after regression. + """ + starts = detections[:, 0] + ends = detections[:, 1] + centers = (starts + ends) / 2 + durations = ends - starts + + new_centers = centers + durations * detections[:, 3] + new_durations = durations * np.exp(detections[:, 4]) + + new_detections = np.concatenate( + (np.clip(new_centers - new_durations / 2, 0, + 1)[:, None], np.clip(new_centers + new_durations / 2, 0, + 1)[:, None], detections[:, 2:]), + axis=1) + return new_detections + + +def temporal_nms(detections, threshold): + """Parse the video's information. + + Args: + detections (list): Detection results before NMS. + threshold (float): Threshold of NMS. + + Returns: + list: Detection results after NMS. + """ + starts = detections[:, 0] + ends = detections[:, 1] + scores = detections[:, 2] + + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ious = temporal_iou(starts[order[1:]], ends[order[1:]], starts[i], + ends[i]) + idxs = np.where(ious <= threshold)[0] + order = order[idxs + 1] + + return detections[keep, :] + + +def eval_ap(detections, gt_by_cls, iou_range): + """Evaluate average precisions. + + Args: + detections (dict): Results of detections. + gt_by_cls (dict): Information of groudtruth. + iou_range (list): Ranges of iou. + + Returns: + list: Average precision values of classes at ious. + """ + ap_values = np.zeros((len(detections), len(iou_range))) + + for iou_idx, min_overlap in enumerate(iou_range): + for class_idx, _ in enumerate(detections): + ap = average_precision_at_temporal_iou(gt_by_cls[class_idx], + detections[class_idx], + [min_overlap]) + ap_values[class_idx, iou_idx] = ap + + return ap_values