--- a
+++ b/mmaction/datasets/activitynet_dataset.py
@@ -0,0 +1,276 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
+import os.path as osp
+import warnings
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+
+from ..core import average_recall_at_avg_proposals
+from .base import BaseDataset
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class ActivityNetDataset(BaseDataset):
+    """ActivityNet dataset for temporal action localization.
+
+    The dataset loads raw features and applies the specified transforms to
+    return a dict containing the frame tensors and other information.
+
+    The ann_file is a json file with multiple objects, where each key is a
+    video name and its value records the video's total frames, duration in
+    seconds, annotations, feature frames (frames covered by features), fps
+    and rfps. Example of an annotation file:
+
+    .. code-block:: JSON
+
+        {
+            "v_--1DO2V4K74":  {
+                "duration_second": 211.53,
+                "duration_frame": 6337,
+                "annotations": [
+                    {
+                        "segment": [
+                            30.025882995319815,
+                            205.2318595943838
+                        ],
+                        "label": "Rock climbing"
+                    }
+                ],
+                "feature_frame": 6336,
+                "fps": 30.0,
+                "rfps": 29.9579255898
+            },
+            "v_--6bJUbfpnQ": {
+                "duration_second": 26.75,
+                "duration_frame": 647,
+                "annotations": [
+                    {
+                        "segment": [
+                            2.578755070202808,
+                            24.914101404056165
+                        ],
+                        "label": "Drinking beer"
+                    }
+                ],
+                "feature_frame": 624,
+                "fps": 24.0,
+                "rfps": 24.1869158879
+            },
+            ...
+        }
+
+
+    Args:
+        ann_file (str): Path to the annotation file.
+        pipeline (list[dict | callable]): A sequence of data transforms.
+        data_prefix (str | None): Path to a directory where videos are held.
+            Default: None.
+        test_mode (bool): Set True when building the test or validation
+            dataset. Default: False.
+    """
+
+    def __init__(self, ann_file, pipeline, data_prefix=None, test_mode=False):
+        super().__init__(ann_file, pipeline, data_prefix, test_mode)
+
+    def load_annotations(self):
+        """Load the annotation according to ann_file into video_infos."""
+        video_infos = []
+        anno_database = mmcv.load(self.ann_file)
+        for video_name in anno_database:
+            video_info = anno_database[video_name]
+            video_info['video_name'] = video_name
+            video_infos.append(video_info)
+        return video_infos
+
+    def prepare_test_frames(self, idx):
+        """Prepare the frames for testing given the index."""
+        results = copy.deepcopy(self.video_infos[idx])
+        results['data_prefix'] = self.data_prefix
+        return self.pipeline(results)
+
+    def prepare_train_frames(self, idx):
+        """Prepare the frames for training given the index."""
+        results = copy.deepcopy(self.video_infos[idx])
+        results['data_prefix'] = self.data_prefix
+        return self.pipeline(results)
+
+    def __len__(self):
+        """Get the size of the dataset."""
+        return len(self.video_infos)
+
+    def _import_ground_truth(self):
+        """Read ground truth data from video_infos."""
+        ground_truth = {}
+        for video_info in self.video_infos:
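+            # Video names in ActivityNet carry a 'v_' prefix (see the
+            # docstring example above); strip it to get the bare video id.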
+            video_id = video_info['video_name'][2:]
+            this_video_ground_truths = []
+            for ann in video_info['annotations']:
+                t_start, t_end = ann['segment']
+                label = ann['label']
+                this_video_ground_truths.append([t_start, t_end, label])
+            ground_truth[video_id] = np.array(this_video_ground_truths)
+        return ground_truth
+
+    @staticmethod
+    def proposals2json(results, show_progress=False):
+        """Convert all proposals to a final dict(json) format.
+
+        Args:
+            results (list[dict]): All proposals.
+            show_progress (bool): Whether to show the progress bar.
+                Default: False.
+
+        Returns:
+            dict: The final result dict. E.g.
+
+            .. code-block:: Python
+
+                dict(video-1=[dict(segment=[1.1, 2.0], score=0.9),
+                              dict(segment=[50.1, 129.3], score=0.6)])
+        """
+        result_dict = {}
+        print('Convert proposals to json format')
+        if show_progress:
+            prog_bar = mmcv.ProgressBar(len(results))
+        for result in results:
+            video_name = result['video_name']
+            result_dict[video_name[2:]] = result['proposal_list']
+            if show_progress:
+                prog_bar.update()
+        return result_dict
+
+    @staticmethod
+    def _import_proposals(results):
+        """Read predictions from results."""
+        proposals = {}
+        num_proposals = 0
+        for result in results:
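+            # Strip the same 'v_' prefix so the proposal ids line up with
+            # the ground truth keys.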
+            video_id = result['video_name'][2:]
+            this_video_proposals = []
+            for proposal in result['proposal_list']:
+                t_start, t_end = proposal['segment']
+                score = proposal['score']
+                this_video_proposals.append([t_start, t_end, score])
+                num_proposals += 1
+            proposals[video_id] = np.array(this_video_proposals)
+        return proposals, num_proposals
+
+    def dump_results(self, results, out, output_format, version='VERSION 1.3'):
+        """Dump data to json/csv files."""
+        if output_format == 'json':
+            result_dict = self.proposals2json(results)
+            output_dict = {
+                'version': version,
+                'results': result_dict,
+                'external_data': {}
+            }
+            mmcv.dump(output_dict, out)
+        elif output_format == 'csv':
+            # TODO: add csv handler to mmcv and use mmcv.dump
+            os.makedirs(out, exist_ok=True)
+            header = 'action,start,end,tmin,tmax'
+            for result in results:
+                video_name, outputs = result
+                output_path = osp.join(out, video_name + '.csv')
+                np.savetxt(
+                    output_path,
+                    outputs,
+                    header=header,
+                    delimiter=',',
+                    comments='')
+        else:
+            raise ValueError(
+                f'The output format {output_format} is not supported.')
+
+    def evaluate(
+            self,
+            results,
+            metrics='AR@AN',
+            metric_options={
+                'AR@AN':
+                dict(
+                    max_avg_proposals=100,
+                    temporal_iou_thresholds=np.linspace(0.5, 0.95, 10))
+            },
+            logger=None,
+            **deprecated_kwargs):
+        """Evaluation in feature dataset.
+
+        Args:
+            results (list[dict]): Output results.
+            metrics (str | sequence[str]): Metrics to be performed.
+                Default: 'AR@AN'.
+            metric_options (dict): Dict for metric options. Options are
+                ``max_avg_proposals``, ``temporal_iou_thresholds`` for
+                ``AR@AN``.
+                Default: ``{'AR@AN': dict(max_avg_proposals=100,
+                temporal_iou_thresholds=np.linspace(0.5, 0.95, 10))}``.
+            logger (logging.Logger | None): Training logger. Default: None.
+            deprecated_kwargs (dict): Used to hold deprecated arguments.
+                See 'https://github.com/open-mmlab/mmaction2/pull/286'.
+
+        Returns:
+            dict: Evaluation results for evaluation metrics.
+        """
+        # Protect ``metric_options`` since it uses mutable value as default
+        metric_options = copy.deepcopy(metric_options)
+
+        if deprecated_kwargs != {}:
+            warnings.warn(
+                'Option arguments for metrics have been changed to '
+                "`metric_options`, see 'https://github.com/open-mmlab/mmaction2/pull/286' "  # noqa: E501
+                'for more details')
+            metric_options['AR@AN'] = dict(metric_options['AR@AN'],
+                                           **deprecated_kwargs)
+
+        if not isinstance(results, list):
+            raise TypeError(f'results must be a list, but got {type(results)}')
+        assert len(results) == len(self), (
+            f'The length of results is not equal to the dataset len: '
+            f'{len(results)} != {len(self)}')
+
+        metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]
+        allowed_metrics = ['AR@AN']
+        for metric in metrics:
+            if metric not in allowed_metrics:
+                raise KeyError(f'metric {metric} is not supported')
+
+        eval_results = OrderedDict()
+        ground_truth = self._import_ground_truth()
+        proposals, num_proposals = self._import_proposals(results)
+
+        for metric in metrics:
+            if metric == 'AR@AN':
+                temporal_iou_thresholds = metric_options.setdefault(
+                    'AR@AN', {}).setdefault('temporal_iou_thresholds',
+                                            np.linspace(0.5, 0.95, 10))
+                max_avg_proposals = metric_options.setdefault(
+                    'AR@AN', {}).setdefault('max_avg_proposals', 100)
+                if isinstance(temporal_iou_thresholds, list):
+                    temporal_iou_thresholds = np.array(temporal_iou_thresholds)
+
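+                # ``recall`` has shape (num_thresholds, max_avg_proposals);
+                # each AR@N below averages recall over all temporal IoU
+                # thresholds at an average of N proposals per video.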
+                recall, _, _, auc = (
+                    average_recall_at_avg_proposals(
+                        ground_truth,
+                        proposals,
+                        num_proposals,
+                        max_avg_proposals=max_avg_proposals,
+                        temporal_iou_thresholds=temporal_iou_thresholds))
+                eval_results['auc'] = auc
+                eval_results['AR@1'] = np.mean(recall[:, 0])
+                eval_results['AR@5'] = np.mean(recall[:, 4])
+                eval_results['AR@10'] = np.mean(recall[:, 9])
+                eval_results['AR@100'] = np.mean(recall[:, 99])
+
+        return eval_results
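
Below is a minimal usage sketch (not part of the patch): it builds the dataset
and runs ``evaluate`` on dummy proposals. The annotation path, the empty
pipeline, and the one-proposal-per-video results are placeholder assumptions;
real configs load features in the pipeline and supply genuine proposal lists.

    import numpy as np

    from mmaction.datasets import build_dataset

    dataset = build_dataset(
        dict(
            type='ActivityNetDataset',
            ann_file='anno_val.json',  # json laid out as in the docstring
            pipeline=[],  # placeholder; real configs load features here
            test_mode=True))

    # One fake proposal per video, just to exercise the metric plumbing.
    results = [
        dict(
            video_name=info['video_name'],
            proposal_list=[dict(segment=[0.0, 1.0], score=0.5)])
        for info in dataset.video_infos
    ]

    eval_results = dataset.evaluate(
        results,
        metrics='AR@AN',
        metric_options={
            'AR@AN':
            dict(
                max_avg_proposals=100,
                temporal_iou_thresholds=np.linspace(0.5, 0.95, 10))
        })
    print(eval_results)  # OrderedDict with 'auc', 'AR@1', ..., 'AR@100'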