--- a
+++ b/mmaction/datasets/ava_dataset.py
@@ -0,0 +1,392 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
+import os.path as osp
+from collections import defaultdict
+from datetime import datetime
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+
+from ..core.evaluation.ava_utils import ava_eval, read_labelmap, results2csv
+from ..utils import get_root_logger
+from .base import BaseDataset
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class AVADataset(BaseDataset):
+    """AVA dataset for spatial temporal detection.
+
+    Based on official AVA annotation files, the dataset loads raw frames,
+    bounding boxes, proposals and applies specified transformations to return
+    a dict containing the frame tensors and other information.
+
+    This dataset can load information from the following files:
+
+    .. code-block:: txt
+
+        ann_file -> ava_{train, val}_{v2.1, v2.2}.csv
+        exclude_file -> ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv
+        label_file -> ava_action_list_{v2.1, v2.2}.pbtxt /
+                      ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt
+        proposal_file -> ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl
+
+    Particularly, the proposal_file is a pickle file which contains
+    ``img_key`` (in format of ``{video_id},{timestamp}``). Example of a pickle
+    file:
+
+    .. code-block:: JSON
+
+        {
+            ...
+            '0f39OWEqJ24,0902':
+                array([[0.011   , 0.157   , 0.655   , 0.983   , 0.998163]]),
+            '0f39OWEqJ24,0912':
+                array([[0.054   , 0.088   , 0.91    , 0.998   , 0.068273],
+                       [0.016   , 0.161   , 0.519   , 0.974   , 0.984025],
+                       [0.493   , 0.283   , 0.981   , 0.984   , 0.983621]]),
+            ...
+        }
+
+    Args:
+        ann_file (str): Path to the annotation file like
+            ``ava_{train, val}_{v2.1, v2.2}.csv``.
+        exclude_file (str): Path to the excluded timestamp file like
+            ``ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv``.
+        pipeline (list[dict | callable]): A sequence of data transforms.
+        label_file (str): Path to the label file like
+            ``ava_action_list_{v2.1, v2.2}.pbtxt`` or
+            ``ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt``.
+            Default: None.
+        filename_tmpl (str): Template for each filename.
+            Default: 'img_{:05}.jpg'.
+        start_index (int): Specify a start index for frames in consideration of
+            different filename format. However, when taking videos as input,
+            it should be set to 0, since frames loaded from videos count
+            from 0. Default: 0.
+        proposal_file (str): Path to the proposal file like
+            ``ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl``.
+            Default: None.
+        person_det_score_thr (float): The threshold of person detection scores,
+            bboxes with scores above the threshold will be used. Default: 0.9.
+            Note that 0 <= person_det_score_thr <= 1. If no proposal has
+            detection score larger than the threshold, the one with the largest
+            detection score will be used.
+        num_classes (int): The number of classes of the dataset. Default: 81.
+            (AVA has 80 action classes, another 1-dim is added for potential
+            usage)
+        custom_classes (list[int]): A subset of class ids from origin dataset.
+            Please note that 0 should NOT be selected, and ``num_classes``
+            should be equal to ``len(custom_classes) + 1``
+        data_prefix (str): Path to a directory where videos are held.
+            Default: None.
+        test_mode (bool): Store True when building test or validation dataset.
+            Default: False.
+        modality (str): Modality of data. Support 'RGB', 'Flow'.
+                        Default: 'RGB'.
+        num_max_proposals (int): Max proposals number to store. Default: 1000.
+        timestamp_start (int): The start point of included timestamps. The
+            default value is referred from the official website. Default: 900.
+        timestamp_end (int): The end point of included timestamps. The
+            default value is referred from the official website. Default: 1800.
+    """
+
+    _FPS = 30
+
+    def __init__(self,
+                 ann_file,
+                 exclude_file,
+                 pipeline,
+                 label_file=None,
+                 filename_tmpl='img_{:05}.jpg',
+                 start_index=0,
+                 proposal_file=None,
+                 person_det_score_thr=0.9,
+                 num_classes=81,
+                 custom_classes=None,
+                 data_prefix=None,
+                 test_mode=False,
+                 modality='RGB',
+                 num_max_proposals=1000,
+                 timestamp_start=900,
+                 timestamp_end=1800):
+        """Store the settings, build the base dataset and load proposals."""
+        # since it inherits from `BaseDataset`, some arguments
+        # should be assigned before performing `load_annotations()`
+        self.custom_classes = custom_classes
+        if custom_classes is not None:
+            assert num_classes == len(custom_classes) + 1
+            assert 0 not in custom_classes
+            # Closes the file; assumes read_labelmap consumes it eagerly.
+            with open(label_file) as label_fin:
+                _, class_whitelist = read_labelmap(label_fin)
+            assert set(custom_classes).issubset(class_whitelist)
+            # A leading 0 keeps remapped labels >= 1; list() also allows tuples.
+            self.custom_classes = tuple([0] + list(custom_classes))
+        self.exclude_file = exclude_file
+        self.label_file = label_file
+        self.proposal_file = proposal_file
+        assert 0 <= person_det_score_thr <= 1, (
+            'The value of person_det_score_thr should in [0, 1]. ')
+        self.person_det_score_thr = person_det_score_thr
+        self.num_classes = num_classes
+        self.filename_tmpl = filename_tmpl
+        self.num_max_proposals = num_max_proposals
+        self.timestamp_start = timestamp_start
+        self.timestamp_end = timestamp_end
+        self.logger = get_root_logger()
+        super().__init__(
+            ann_file,
+            pipeline,
+            data_prefix,
+            test_mode,
+            start_index=start_index,
+            modality=modality,
+            num_classes=num_classes)
+
+        self.proposals = (None if self.proposal_file is None else
+                          mmcv.load(self.proposal_file))
+
+        if not test_mode:
+            valid_indexes = self.filter_exclude_file()
+            self.logger.info(
+                f'{len(valid_indexes)} out of {len(self.video_infos)} '
+                f'frames are valid.')
+            self.video_infos = [self.video_infos[i] for i in valid_indexes]
+
+    def parse_img_record(self, img_records):
+        """Merge image records of the same entity at the same time.
+
+        Records whose ``entity_box`` arrays are equal are treated as one
+        entity: its box and ``entity_id`` come from the first such record
+        and all of their labels are folded into one multi-hot vector.
+
+        Args:
+            img_records (list[dict]): List of img_records (lines in AVA
+                annotations).
+
+        Returns:
+            tuple(np.ndarray): Stacked bboxes, multi-hot action labels and
+                entity_ids, one row per distinct box in first-seen order.
+        """
+        bboxes, labels, entity_ids = [], [], []
+        pending = img_records
+        while len(pending) > 0:
+            anchor = pending[0]
+            anchor_box = anchor['entity_box']
+
+            # Partition the pending records: those sharing the anchor's box
+            # are merged now, the others wait for a later iteration.
+            same_box, others = [], []
+            for record in pending:
+                if np.array_equal(record['entity_box'], anchor_box):
+                    same_box.append(record)
+                else:
+                    others.append(record)
+            assert len(others) + len(same_box) == len(pending)
+            pending = others
+
+            bboxes.append(anchor_box)
+            entity_ids.append(anchor['entity_id'])
+
+            # The format can be directly used by BCELossWithLogits
+            label = np.zeros(self.num_classes, dtype=np.float32)
+            hit_classes = np.array([rec['label'] for rec in same_box])
+            label[hit_classes] = 1.
+            labels.append(label)
+
+        # Turn the per-entity lists into arrays.
+        bboxes = np.stack(bboxes)
+        labels = np.stack(labels)
+        entity_ids = np.stack(entity_ids)
+        return bboxes, labels, entity_ids
+
+    def filter_exclude_file(self):
+        """Return indexes of ``self.video_infos`` absent from exclude_file.
+
+        Non-blank lines of ``self.exclude_file`` are parsed as
+        ``video_id,timestamp`` pairs; when it is None every index is kept.
+        Indexes are returned in ascending order as a list.
+        """
+        if self.exclude_file is None:
+            return list(range(len(self.video_infos)))
+        # Build a set once (O(1) lookups); ``with`` closes the file handle.
+        with open(self.exclude_file) as fin:
+            excluded = {(vid, int(ts)) for vid, ts in
+                        (x.strip().split(',') for x in fin if x.strip())}
+        return [
+            i for i, info in enumerate(self.video_infos)
+            if (info['video_id'], info['timestamp']) not in excluded
+        ]
+
+    def load_annotations(self):
+        """Load AVA annotations from ``self.ann_file``.
+
+        Returns:
+            list[dict]: One entry per ``{video_id},{timestamp:04d}`` key,
+                with the merged boxes, labels and entity ids under ``ann``.
+        """
+        # Identical for every record, so compute it once instead of per line.
+        shot_info = (0, (self.timestamp_end - self.timestamp_start) *
+                     self._FPS)
+        records_dict_by_img = defaultdict(list)
+        with open(self.ann_file, 'r') as fin:
+            for line in fin:
+                if not line.strip():
+                    continue  # tolerate blank lines, e.g. a trailing newline
+                line_split = line.strip().split(',')
+                label = int(line_split[6])
+                if self.custom_classes is not None:
+                    if label not in self.custom_classes:
+                        continue
+                    # Remap to the position inside ``custom_classes``.
+                    label = self.custom_classes.index(label)
+                video_id = line_split[0]
+                timestamp = int(line_split[1])
+                record = dict(
+                    video_id=video_id,
+                    timestamp=timestamp,
+                    entity_box=np.array(list(map(float, line_split[2:6]))),
+                    label=label,
+                    entity_id=int(line_split[7]),
+                    shot_info=shot_info)
+                img_key = f'{video_id},{timestamp:04d}'
+                records_dict_by_img[img_key].append(record)
+
+        video_infos = []
+        for img_key, img_records in records_dict_by_img.items():
+            video_id, timestamp = img_key.split(',')
+            bboxes, labels, entity_ids = self.parse_img_record(img_records)
+            frame_dir = video_id
+            if self.data_prefix is not None:
+                frame_dir = osp.join(self.data_prefix, frame_dir)
+            video_infos.append(dict(
+                frame_dir=frame_dir,
+                video_id=video_id,
+                timestamp=int(timestamp),
+                img_key=img_key,
+                shot_info=shot_info,
+                fps=self._FPS,
+                ann=dict(
+                    gt_bboxes=bboxes, gt_labels=labels,
+                    entity_ids=entity_ids)))
+        return video_infos
+
+    def _prepare_frames(self, idx):
+        """Build the pipeline input shared by training and testing.
+
+        Args:
+            idx (int): Index into ``self.video_infos``.
+
+        Returns:
+            The output of ``self.pipeline`` applied to a deep copy of the
+            selected video info, extended with the loading settings, the
+            optional proposals and the flattened ``ann`` entries.
+        """
+        results = copy.deepcopy(self.video_infos[idx])
+        img_key = results['img_key']
+
+        results['filename_tmpl'] = self.filename_tmpl
+        results['modality'] = self.modality
+        results['start_index'] = self.start_index
+        results['timestamp_start'] = self.timestamp_start
+        results['timestamp_end'] = self.timestamp_end
+
+        if self.proposals is not None:
+            if img_key not in self.proposals:
+                # No stored proposals for this frame: fall back to one box
+                # (presumably the whole frame, normalized) with score 1.
+                results['proposals'] = np.array([[0, 0, 1, 1]])
+                results['scores'] = np.array([1])
+            else:
+                proposals = self.proposals[img_key]
+                assert proposals.shape[-1] in [4, 5]
+                if proposals.shape[-1] == 5:
+                    # Cap the threshold at the best score so that at least
+                    # one proposal always survives the filtering.
+                    thr = min(self.person_det_score_thr, max(proposals[:, 4]))
+                    positive_inds = (proposals[:, 4] >= thr)
+                    proposals = proposals[positive_inds]
+                    proposals = proposals[:self.num_max_proposals]
+                    results['proposals'] = proposals[:, :4]
+                    results['scores'] = proposals[:, 4]
+                else:
+                    # Boxes only: no ``scores`` entry is added in this case.
+                    proposals = proposals[:self.num_max_proposals]
+                    results['proposals'] = proposals
+
+        ann = results.pop('ann')
+        # Follow the mmdet variable naming style.
+        results['gt_bboxes'] = ann['gt_bboxes']
+        results['gt_labels'] = ann['gt_labels']
+        results['entity_ids'] = ann['entity_ids']
+
+        return self.pipeline(results)
+
+    def prepare_train_frames(self, idx):
+        """Prepare the frames for training given the index.
+
+        Args:
+            idx (int): Index into ``self.video_infos``.
+
+        Returns:
+            The output of ``self.pipeline`` for the selected sample.
+        """
+        return self._prepare_frames(idx)
+
+    def prepare_test_frames(self, idx):
+        """Prepare the frames for testing given the index.
+
+        Args:
+            idx (int): Index into ``self.video_infos``.
+
+        Returns:
+            The output of ``self.pipeline`` for the selected sample.
+        """
+        return self._prepare_frames(idx)
+
+    def dump_results(self, results, out):
+        """Dump predictions into the csv file ``out`` via ``results2csv``."""
+        assert out.endswith('csv')
+        results2csv(self, results, out, self.custom_classes)
+
+    def evaluate(self, results, metrics=('mAP', ), metric_options=None,
+                 logger=None):
+        """Evaluate the prediction results and report mAP.
+
+        Args:
+            results: Predictions, forwarded unchanged to ``results2csv``.
+            metrics (tuple[str]): Must contain exactly one entry, ``'mAP'``.
+            metric_options (dict | None): Unused by this method.
+            logger: Forwarded to ``print_log``; when None the progress
+                message is prefixed with a newline.
+
+        Returns:
+            dict: The evaluation results returned by ``ava_eval``.
+        """
+        assert len(metrics) == 1 and metrics[0] == 'mAP', (
+            'For evaluation on AVADataset, you need to use metrics "mAP" '
+            'See https://github.com/open-mmlab/mmaction2/pull/567 '
+            'for more info.')
+        time_now = datetime.now().strftime('%Y%m%d_%H%M%S')
+        temp_file = f'AVA_{time_now}_result.csv'
+        ret = {}
+        try:
+            results2csv(self, results, temp_file, self.custom_classes)
+            for metric in metrics:
+                prefix = '\n' if logger is None else ''
+                print_log(f'{prefix}Evaluating {metric} ...', logger=logger)
+                eval_result = ava_eval(
+                    temp_file, metric, self.label_file, self.ann_file,
+                    self.exclude_file, custom_classes=self.custom_classes)
+                log_msg = ''.join(
+                    f'\n{k}\t{v: .4f}' for k, v in eval_result.items())
+                print_log(log_msg, logger=logger)
+                ret.update(eval_result)
+        finally:
+            # Remove the scratch csv even when dumping or evaluation fails.
+            if osp.exists(temp_file):
+                os.remove(temp_file)
+        return ret