--- /dev/null
+++ b/mmaction/datasets/ava_dataset.py
@@ -0,0 +1,392 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
+import os.path as osp
+from collections import defaultdict
+from datetime import datetime
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+
+from ..core.evaluation.ava_utils import ava_eval, read_labelmap, results2csv
+from ..utils import get_root_logger
+from .base import BaseDataset
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class AVADataset(BaseDataset):
+    """AVA dataset for spatio-temporal action detection.
+
+    Based on official AVA annotation files, the dataset loads raw frames,
+    bounding boxes and proposals, and applies the specified transformations
+    to return a dict containing the frame tensors and other information.
+
+    This dataset can load information from the following files:
+
+    .. code-block:: txt
+
+        ann_file -> ava_{train, val}_{v2.1, v2.2}.csv
+        exclude_file -> ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv
+        label_file -> ava_action_list_{v2.1, v2.2}.pbtxt /
+                      ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt
+        proposal_file -> ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl
+
+    Particularly, the proposal_file is a pickle file which contains
+    ``img_key`` (in format of ``{video_id},{timestamp}``). Example of a
+    pickle file:
+
+    .. code-block:: python
+
+        {
+            ...
+            '0f39OWEqJ24,0902':
+                array([[0.011 , 0.157 , 0.655 , 0.983 , 0.998163]]),
+            '0f39OWEqJ24,0912':
+                array([[0.054 , 0.088 , 0.91  , 0.998 , 0.068273],
+                       [0.016 , 0.161 , 0.519 , 0.974 , 0.984025],
+                       [0.493 , 0.283 , 0.981 , 0.984 , 0.983621]]),
+            ...
+        }
+
+    Args:
+        ann_file (str): Path to the annotation file like
+            ``ava_{train, val}_{v2.1, v2.2}.csv``.
+        exclude_file (str): Path to the excluded timestamp file like
+            ``ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv``.
+        pipeline (list[dict | callable]): A sequence of data transforms.
+        label_file (str): Path to the label file like
+            ``ava_action_list_{v2.1, v2.2}.pbtxt`` or
+            ``ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt``.
+            Default: None.
+        filename_tmpl (str): Template for each filename.
+            Default: 'img_{:05}.jpg'.
+        start_index (int): Specify a start index for frames in consideration
+            of different filename format. However, when taking videos as
+            input, it should be set to 0, since frames loaded from videos
+            count from 0. Default: 0.
+        proposal_file (str): Path to the proposal file like
+            ``ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl``.
+            Default: None.
+        person_det_score_thr (float): The threshold of person detection
+            scores; bboxes with scores above the threshold will be used.
+            Note that 0 <= person_det_score_thr <= 1. If no proposal has a
+            detection score above the threshold, the one with the highest
+            detection score will be used. Default: 0.9.
+        num_classes (int): The number of classes of the dataset. Default: 81.
+            (AVA has 80 action classes; one extra dimension is added for
+            potential usage.)
+        custom_classes (list[int]): A subset of class ids from the origin
+            dataset. Please note that 0 should NOT be selected, and
+            ``num_classes`` should be equal to ``len(custom_classes) + 1``.
+        data_prefix (str): Path to a directory where videos are held.
+            Default: None.
+        test_mode (bool): Store True when building test or validation
+            dataset. Default: False.
+        modality (str): Modality of data. Support 'RGB', 'Flow'.
+            Default: 'RGB'.
+        num_max_proposals (int): Max number of proposals to store.
+            Default: 1000.
+        timestamp_start (int): The start point of included timestamps. The
+            default value is referred from the official website.
+            Default: 900.
+        timestamp_end (int): The end point of included timestamps. The
+            default value is referred from the official website.
+            Default: 1800.
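+
+    Example:
+        A minimal instantiation sketch. The paths below are placeholders
+        following the default AVA data layout, and the pipeline is left
+        empty for brevity:
+
+        .. code-block:: python
+
+            dataset = AVADataset(
+                ann_file='data/ava/annotations/ava_train_v2.1.csv',
+                exclude_file='data/ava/annotations/'
+                'ava_train_excluded_timestamps_v2.1.csv',
+                pipeline=[],
+                label_file='data/ava/annotations/ava_action_list_v2.1.pbtxt',
+                proposal_file='data/ava/annotations/'
+                'ava_dense_proposals_train.FAIR.recall_93.9.pkl',
+                data_prefix='data/ava/rawframes')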
+    """
+
+    _FPS = 30
+
+    def __init__(self,
+                 ann_file,
+                 exclude_file,
+                 pipeline,
+                 label_file=None,
+                 filename_tmpl='img_{:05}.jpg',
+                 start_index=0,
+                 proposal_file=None,
+                 person_det_score_thr=0.9,
+                 num_classes=81,
+                 custom_classes=None,
+                 data_prefix=None,
+                 test_mode=False,
+                 modality='RGB',
+                 num_max_proposals=1000,
+                 timestamp_start=900,
+                 timestamp_end=1800):
+        # Since this class inherits from `BaseDataset`, some arguments
+        # have to be assigned before calling `load_annotations()`.
+        self.custom_classes = custom_classes
+        if custom_classes is not None:
+            assert num_classes == len(custom_classes) + 1
+            assert 0 not in custom_classes
+            with open(label_file) as f:
+                _, class_whitelist = read_labelmap(f)
+            assert set(custom_classes).issubset(class_whitelist)
+
+            self.custom_classes = tuple([0] + custom_classes)
+        self.exclude_file = exclude_file
+        self.label_file = label_file
+        self.proposal_file = proposal_file
+        assert 0 <= person_det_score_thr <= 1, (
+            'The value of person_det_score_thr should be in [0, 1].')
+        self.person_det_score_thr = person_det_score_thr
+        self.num_classes = num_classes
+        self.filename_tmpl = filename_tmpl
+        self.num_max_proposals = num_max_proposals
+        self.timestamp_start = timestamp_start
+        self.timestamp_end = timestamp_end
+        self.logger = get_root_logger()
+        super().__init__(
+            ann_file,
+            pipeline,
+            data_prefix,
+            test_mode,
+            start_index=start_index,
+            modality=modality,
+            num_classes=num_classes)
+
+        if self.proposal_file is not None:
+            self.proposals = mmcv.load(self.proposal_file)
+        else:
+            self.proposals = None
+
+        if not test_mode:
+            valid_indexes = self.filter_exclude_file()
+            self.logger.info(
+                f'{len(valid_indexes)} out of {len(self.video_infos)} '
+                f'frames are valid.')
+            self.video_infos = [self.video_infos[i] for i in valid_indexes]
+
+    def parse_img_record(self, img_records):
+        """Merge image records of the same entity at the same time.
+
+        Args:
+            img_records (list[dict]): List of img_records (lines in AVA
+                annotations).
+
+        Returns:
+            tuple(list): A tuple consisting of lists of bboxes, action
+                labels and entity_ids.
+        """
+        bboxes, labels, entity_ids = [], [], []
+        while len(img_records) > 0:
+            img_record = img_records[0]
+            num_img_records = len(img_records)
+
+            # Collect all records that share the entity_box of the first
+            # record, then remove them from the pool.
+            selected_records = [
+                x for x in img_records
+                if np.array_equal(x['entity_box'], img_record['entity_box'])
+            ]
+
+            num_selected_records = len(selected_records)
+            img_records = [
+                x for x in img_records if
+                not np.array_equal(x['entity_box'], img_record['entity_box'])
+            ]
+
+            assert len(img_records) + num_selected_records == num_img_records
+
+            bboxes.append(img_record['entity_box'])
+            valid_labels = np.array([
+                selected_record['label']
+                for selected_record in selected_records
+            ])
+
+            # The format can be directly used by BCELossWithLogits.
+            label = np.zeros(self.num_classes, dtype=np.float32)
+            label[valid_labels] = 1.
+
+            labels.append(label)
+            entity_ids.append(img_record['entity_id'])
+
+        bboxes = np.stack(bboxes)
+        labels = np.stack(labels)
+        entity_ids = np.stack(entity_ids)
+        return bboxes, labels, entity_ids
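+    # A hypothetical sketch of the merge in `parse_img_record`: two CSV rows
+    # sharing one entity_box, with labels 12 and 17, collapse into a single
+    # bbox whose `num_classes`-dim label vector has 1. at indices 12 and 17,
+    # yielding a multi-label target per person box.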
+    def filter_exclude_file(self):
+        """Filter out records in the exclude_file."""
+        valid_indexes = []
+        if self.exclude_file is None:
+            valid_indexes = list(range(len(self.video_infos)))
+        else:
+            exclude_video_infos = [
+                x.strip().split(',') for x in open(self.exclude_file)
+            ]
+            for i, video_info in enumerate(self.video_infos):
+                valid_indexes.append(i)
+                for video_id, timestamp in exclude_video_infos:
+                    if (video_info['video_id'] == video_id
+                            and video_info['timestamp'] == int(timestamp)):
+                        valid_indexes.pop()
+                        break
+        return valid_indexes
+
+    def load_annotations(self):
+        """Load AVA annotations."""
+        video_infos = []
+        records_dict_by_img = defaultdict(list)
+        # The shot info is constant for all records, so compute it once.
+        shot_info = (0, (self.timestamp_end - self.timestamp_start) *
+                     self._FPS)
+        with open(self.ann_file, 'r') as fin:
+            for line in fin:
+                # Each line: video_id, timestamp, x1, y1, x2, y2, label,
+                # entity_id.
+                line_split = line.strip().split(',')
+
+                label = int(line_split[6])
+                if self.custom_classes is not None:
+                    if label not in self.custom_classes:
+                        continue
+                    label = self.custom_classes.index(label)
+
+                video_id = line_split[0]
+                timestamp = int(line_split[1])
+                img_key = f'{video_id},{timestamp:04d}'
+
+                entity_box = np.array(list(map(float, line_split[2:6])))
+                entity_id = int(line_split[7])
+
+                video_info = dict(
+                    video_id=video_id,
+                    timestamp=timestamp,
+                    entity_box=entity_box,
+                    label=label,
+                    entity_id=entity_id,
+                    shot_info=shot_info)
+                records_dict_by_img[img_key].append(video_info)
+
+        for img_key in records_dict_by_img:
+            video_id, timestamp = img_key.split(',')
+            bboxes, labels, entity_ids = self.parse_img_record(
+                records_dict_by_img[img_key])
+            ann = dict(
+                gt_bboxes=bboxes, gt_labels=labels, entity_ids=entity_ids)
+            frame_dir = video_id
+            if self.data_prefix is not None:
+                frame_dir = osp.join(self.data_prefix, frame_dir)
+            video_info = dict(
+                frame_dir=frame_dir,
+                video_id=video_id,
+                timestamp=int(timestamp),
+                img_key=img_key,
+                shot_info=shot_info,
+                fps=self._FPS,
+                ann=ann)
+            video_infos.append(video_info)
+
+        return video_infos
+
+    def prepare_train_frames(self, idx):
+        """Prepare the frames for training given the index."""
+        results = copy.deepcopy(self.video_infos[idx])
+        img_key = results['img_key']
+
+        results['filename_tmpl'] = self.filename_tmpl
+        results['modality'] = self.modality
+        results['start_index'] = self.start_index
+        results['timestamp_start'] = self.timestamp_start
+        results['timestamp_end'] = self.timestamp_end
+
+        if self.proposals is not None:
+            if img_key not in self.proposals:
+                # Fall back to a whole-image proposal when no detection
+                # exists for this frame.
+                results['proposals'] = np.array([[0, 0, 1, 1]])
+                results['scores'] = np.array([1])
+            else:
+                proposals = self.proposals[img_key]
+                assert proposals.shape[-1] in [4, 5]
+                if proposals.shape[-1] == 5:
+                    # Cap the threshold at the highest detection score so
+                    # that at least one proposal always survives.
+                    thr = min(self.person_det_score_thr, max(proposals[:, 4]))
+                    positive_inds = (proposals[:, 4] >= thr)
+                    proposals = proposals[positive_inds]
+                    proposals = proposals[:self.num_max_proposals]
+                    results['proposals'] = proposals[:, :4]
+                    results['scores'] = proposals[:, 4]
+                else:
+                    proposals = proposals[:self.num_max_proposals]
+                    results['proposals'] = proposals
+
+        ann = results.pop('ann')
+        # Follow the mmdet variable naming style.
+        results['gt_bboxes'] = ann['gt_bboxes']
+        results['gt_labels'] = ann['gt_labels']
+        results['entity_ids'] = ann['entity_ids']
+
+        return self.pipeline(results)
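+    # A worked example of the thresholding above (values hypothetical): with
+    # person_det_score_thr=0.9 and proposal scores [0.99, 0.45], thr becomes
+    # min(0.9, 0.99) = 0.9 and one box is kept; with scores [0.5, 0.4], thr
+    # becomes min(0.9, 0.5) = 0.5, so the best-scoring box still survives.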
+    def prepare_test_frames(self, idx):
+        """Prepare the frames for testing given the index."""
+        results = copy.deepcopy(self.video_infos[idx])
+        img_key = results['img_key']
+
+        results['filename_tmpl'] = self.filename_tmpl
+        results['modality'] = self.modality
+        results['start_index'] = self.start_index
+        results['timestamp_start'] = self.timestamp_start
+        results['timestamp_end'] = self.timestamp_end
+
+        if self.proposals is not None:
+            if img_key not in self.proposals:
+                results['proposals'] = np.array([[0, 0, 1, 1]])
+                results['scores'] = np.array([1])
+            else:
+                proposals = self.proposals[img_key]
+                assert proposals.shape[-1] in [4, 5]
+                if proposals.shape[-1] == 5:
+                    thr = min(self.person_det_score_thr, max(proposals[:, 4]))
+                    positive_inds = (proposals[:, 4] >= thr)
+                    proposals = proposals[positive_inds]
+                    proposals = proposals[:self.num_max_proposals]
+                    results['proposals'] = proposals[:, :4]
+                    results['scores'] = proposals[:, 4]
+                else:
+                    proposals = proposals[:self.num_max_proposals]
+                    results['proposals'] = proposals
+
+        ann = results.pop('ann')
+        # Follow the mmdet variable naming style.
+        results['gt_bboxes'] = ann['gt_bboxes']
+        results['gt_labels'] = ann['gt_labels']
+        results['entity_ids'] = ann['entity_ids']
+
+        return self.pipeline(results)
+
+    def dump_results(self, results, out):
+        """Dump predictions into a csv file."""
+        assert out.endswith('csv')
+        results2csv(self, results, out, self.custom_classes)
+
+    def evaluate(self,
+                 results,
+                 metrics=('mAP', ),
+                 metric_options=None,
+                 logger=None):
+        """Evaluate the prediction results and report mAP."""
+        assert len(metrics) == 1 and metrics[0] == 'mAP', (
+            'For evaluation on AVADataset, you need to use the metric '
+            '"mAP". See https://github.com/open-mmlab/mmaction2/pull/567 '
+            'for more info.')
+        time_now = datetime.now().strftime('%Y%m%d_%H%M%S')
+        temp_file = f'AVA_{time_now}_result.csv'
+        results2csv(self, results, temp_file, self.custom_classes)
+
+        ret = {}
+        for metric in metrics:
+            msg = f'Evaluating {metric} ...'
+            if logger is None:
+                msg = '\n' + msg
+            print_log(msg, logger=logger)
+
+            eval_result = ava_eval(
+                temp_file,
+                metric,
+                self.label_file,
+                self.ann_file,
+                self.exclude_file,
+                custom_classes=self.custom_classes)
+            log_msg = []
+            for k, v in eval_result.items():
+                log_msg.append(f'\n{k}\t{v: .4f}')
+            log_msg = ''.join(log_msg)
+            print_log(log_msg, logger=logger)
+            ret.update(eval_result)
+
+        os.remove(temp_file)
+
+        return ret
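+
+# A hedged usage sketch (illustrative only, not part of the dataset API):
+# after building `dataset` and collecting per-frame detection `results`,
+#     dataset.evaluate(results)
+# writes a temporary AVA_<timestamp>_result.csv, runs the official AVA
+# evaluation via `ava_eval`, removes the temporary file and returns a dict
+# of mAP values.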