
--- /dev/null
+++ b/mmaction/datasets/hvu_dataset.py
@@ -0,0 +1,225 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+
+from ..core import mean_average_precision
+from .base import BaseDataset
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class HVUDataset(BaseDataset):
+    """HVU dataset, which supports the recognition tags of multiple categories.
+    Accept both video annotation files or rawframe annotation files.
+
+    The dataset loads videos or raw frames and applies specified transforms to
+    return a dict containing the frame tensors and other information.
+
+    The ann_file is a json file with multiple dictionaries: each dictionary
+    indicates a sample video with its filename and tags, and the tags are
+    organized into different categories. Example of a video dictionary:
+
+    .. code-block:: txt
+
+        {
+            'filename': 'gD_G1b0wV5I_001015_001035.mp4',
+            'label': {
+                'concept': [250, 131, 42, 51, 57, 155, 122],
+                'object': [1570, 508],
+                'event': [16],
+                'action': [180],
+                'scene': [206]
+            }
+        }
+
+    Example of a rawframe dictionary:
+
+    .. code-block:: txt
+
+        {
+            'frame_dir': 'gD_G1b0wV5I_001015_001035',
+            'total_frames': 61,
+            'label': {
+                'concept': [250, 131, 42, 51, 57, 155, 122],
+                'object': [1570, 508],
+                'event': [16],
+                'action': [180],
+                'scene': [206]
+            }
+        }
+
+    Args:
+        ann_file (str): Path to the annotation file, should be a json file.
+        pipeline (list[dict | callable]): A sequence of data transforms.
+        tag_categories (list[str]): List of category names of tags.
+        tag_category_nums (list[int]): List of number of tags in each category.
+        filename_tmpl (str | None): Template for each filename. If set to
+            None, the dataset is treated as a video dataset. Default: None.
+        **kwargs: Keyword arguments for ``BaseDataset``.
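+
+    Example:
+        A minimal construction sketch; the paths, category names and tag
+        counts below are illustrative only, not the real HVU configuration.
+
+        .. code-block:: python
+
+            dataset = HVUDataset(
+                ann_file='hvu_train.json',
+                pipeline=[],
+                tag_categories=['action', 'scene'],
+                tag_category_nums=[739, 248],
+                data_prefix='data/hvu/videos')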
+    """
+
+    def __init__(self,
+                 ann_file,
+                 pipeline,
+                 tag_categories,
+                 tag_category_nums,
+                 filename_tmpl=None,
+                 **kwargs):
+        assert len(tag_categories) == len(tag_category_nums)
+        self.tag_categories = tag_categories
+        self.tag_category_nums = tag_category_nums
+        self.filename_tmpl = filename_tmpl
+        self.num_categories = len(self.tag_categories)
+        self.num_tags = sum(self.tag_category_nums)
+        self.category2num = dict(zip(tag_categories, tag_category_nums))
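+        # ``start_idx[i]`` is the offset of the i-th category within the
+        # concatenated tag vector, e.g. category nums [3, 5, 2] give
+        # offsets [0, 3, 8].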
+        self.start_idx = [0]
+        for i in range(self.num_categories - 1):
+            self.start_idx.append(self.start_idx[-1] +
+                                  self.tag_category_nums[i])
+        self.category2startidx = dict(zip(tag_categories, self.start_idx))
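+        # Pop ``start_index`` so that HVU defaults to 0 instead of
+        # inheriting the ``BaseDataset`` default.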
+        self.start_index = kwargs.pop('start_index', 0)
+        self.dataset_type = None
+        super().__init__(
+            ann_file, pipeline, start_index=self.start_index, **kwargs)
+
+    def load_annotations(self):
+        """Load annotation file to get video information."""
+        assert self.ann_file.endswith('.json')
+        return self.load_json_annotations()
+
+    def load_json_annotations(self):
+        video_infos = mmcv.load(self.ann_file)
+        num_videos = len(video_infos)
+
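+        # Infer the dataset type from the first record: exactly one of
+        # ``filename`` (video dataset) and ``frame_dir`` (rawframe dataset)
+        # must be present.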
+        video_info0 = video_infos[0]
+        assert ('filename' in video_info0) != ('frame_dir' in video_info0)
+        path_key = 'filename' if 'filename' in video_info0 else 'frame_dir'
+        self.dataset_type = 'video' if path_key == 'filename' else 'rawframe'
+        if self.dataset_type == 'rawframe':
+            assert self.filename_tmpl is not None
+
+        for i in range(num_videos):
+            path_value = video_infos[i][path_key]
+            if self.data_prefix is not None:
+                path_value = osp.join(self.data_prefix, path_value)
+            video_infos[i][path_key] = path_value
+
+            # We will convert label to torch tensors in the pipeline
+            video_infos[i]['categories'] = self.tag_categories
+            video_infos[i]['category_nums'] = self.tag_category_nums
+            if self.dataset_type == 'rawframe':
+                video_infos[i]['filename_tmpl'] = self.filename_tmpl
+                video_infos[i]['start_index'] = self.start_index
+                video_infos[i]['modality'] = self.modality
+
+        return video_infos
+
+    @staticmethod
+    def label2array(num, label):
+        """Convert a list of tag indices into a binary indicator array of
+        length ``num``."""
+        arr = np.zeros(num, dtype=np.float32)
+        arr[label] = 1.
+        return arr
+
+    def evaluate(self,
+                 results,
+                 metrics='mean_average_precision',
+                 metric_options=None,
+                 logger=None):
+        """Evaluation in HVU Video Dataset. We only support evaluating mAP for
+        each tag categories. Since some tag categories are missing for some
+        videos, we can not evaluate mAP for all tags.
+
+        Args:
+            results (list): Output results.
+            metrics (str | sequence[str]): Metrics to be evaluated.
+                Default: 'mean_average_precision'.
+            metric_options (dict | None): Dict for metric options.
+                Default: None.
+            logger (logging.Logger | None): Logger for recording.
+                Default: None.
+
+        Returns:
+            dict: Evaluation results dict.
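+
+        Example:
+            A hypothetical call, assuming ``results`` holds one score
+            vector of length ``self.num_tags`` per video:
+
+            .. code-block:: python
+
+                eval_results = dataset.evaluate(results)
+                print(eval_results['action_mAP'])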
+        """
+        # Protect ``metric_options`` since it uses mutable value as default
+        metric_options = copy.deepcopy(metric_options)
+
+        if not isinstance(results, list):
+            raise TypeError(f'results must be a list, but got {type(results)}')
+        assert len(results) == len(self), (
+            f'The length of results is not equal to the dataset len: '
+            f'{len(results)} != {len(self)}')
+
+        metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]
+
+        # There should be only one metric in the metrics list:
+        # 'mean_average_precision'
+        assert len(metrics) == 1
+        metric = metrics[0]
+        assert metric == 'mean_average_precision'
+
+        gt_labels = [ann['label'] for ann in self.video_infos]
+
+        eval_results = OrderedDict()
+
+        for category in self.tag_categories:
+
+            start_idx = self.category2startidx[category]
+            num = self.category2num[category]
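+            # Keep only the videos annotated with this category and slice
+            # each prediction vector down to this category's tags.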
+            preds = [
+                result[start_idx:start_idx + num]
+                for video_idx, result in enumerate(results)
+                if category in gt_labels[video_idx]
+            ]
+            gts = [
+                gt_label[category] for gt_label in gt_labels
+                if category in gt_label
+            ]
+
+            gts = [self.label2array(num, item) for item in gts]
+
+            mAP = mean_average_precision(preds, gts)
+            eval_results[f'{category}_mAP'] = mAP
+            log_msg = f'\n{category}_mAP\t{mAP:.4f}'
+            print_log(log_msg, logger=logger)
+
+        return eval_results