--- a
+++ b/mmaction/datasets/base.py
@@ -0,0 +1,289 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+import warnings
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict, defaultdict
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.utils import print_log
+from torch.utils.data import Dataset
+
+from ..core import (mean_average_precision, mean_class_accuracy,
+                    mmit_mean_average_precision, top_k_accuracy)
+from .pipelines import Compose
+
+
+class BaseDataset(Dataset, metaclass=ABCMeta):
+    """Base class for datasets.
+
+    All video datasets should subclass it.
+    All subclasses should overwrite:
+
+    - Methods:`load_annotations`, loading information from an annotation
+      file.
+    - Methods:`prepare_train_frames`, providing training data.
+    - Methods:`prepare_test_frames`, providing testing data.
+
+    Args:
+        ann_file (str): Path to the annotation file.
+        pipeline (list[dict | callable]): A sequence of data transforms.
+        data_prefix (str | None): Path to a directory where videos are held.
+            Default: None.
+        test_mode (bool): Store True when building test or validation dataset.
+            Default: False.
+        multi_class (bool): Determines whether the dataset is a multi-class
+            dataset. Default: False.
+        num_classes (int | None): Number of classes of the dataset, used in
+            multi-class datasets. Default: None.
+        start_index (int): Start index of frames, which can be adjusted to
+            match different filename formats. However, when taking videos as
+            input, it should be set to 0, since frames loaded from videos are
+            counted from 0. Default: 1.
+        modality (str): Modality of data. Support 'RGB', 'Flow', 'Audio'.
+            Default: 'RGB'.
+        sample_by_class (bool): Whether to sample by class; should be set to
+            `True` when performing inter-class data balancing. Only compatible
+            with `multi_class == False`. Only applies for training.
+            Default: False.
+        power (float): The probability of sampling an item is proportional to
+            a power of its label frequency (freq ^ power). `power == 1`
+            indicates uniformly sampling all data; `power == 0` indicates
+            uniformly sampling all classes. Default: 0.
+        dynamic_length (bool): Whether the dataset length is dynamic (used by
+            ClassSpecificDistributedSampler). Default: False.
+ """ + + def __init__(self, + ann_file, + pipeline, + data_prefix=None, + test_mode=False, + multi_class=False, + num_classes=None, + start_index=1, + modality='RGB', + sample_by_class=False, + power=0, + dynamic_length=False): + super().__init__() + + self.ann_file = ann_file + self.data_prefix = osp.realpath( + data_prefix) if data_prefix is not None and osp.isdir( + data_prefix) else data_prefix + self.test_mode = test_mode + self.multi_class = multi_class + self.num_classes = num_classes + self.start_index = start_index + self.modality = modality + self.sample_by_class = sample_by_class + self.power = power + self.dynamic_length = dynamic_length + + assert not (self.multi_class and self.sample_by_class) + + self.pipeline = Compose(pipeline) + self.video_infos = self.load_annotations() + if self.sample_by_class: + self.video_infos_by_class = self.parse_by_class() + + class_prob = [] + for _, samples in self.video_infos_by_class.items(): + class_prob.append(len(samples) / len(self.video_infos)) + class_prob = [x**self.power for x in class_prob] + + summ = sum(class_prob) + class_prob = [x / summ for x in class_prob] + + self.class_prob = dict(zip(self.video_infos_by_class, class_prob)) + + @abstractmethod + def load_annotations(self): + """Load the annotation according to ann_file into video_infos.""" + + # json annotations already looks like video_infos, so for each dataset, + # this func should be the same + def load_json_annotations(self): + """Load json annotation file to get video information.""" + video_infos = mmcv.load(self.ann_file) + num_videos = len(video_infos) + path_key = 'frame_dir' if 'frame_dir' in video_infos[0] else 'filename' + for i in range(num_videos): + path_value = video_infos[i][path_key] + if self.data_prefix is not None: + path_value = osp.join(self.data_prefix, path_value) + video_infos[i][path_key] = path_value + if self.multi_class: + assert self.num_classes is not None + else: + assert len(video_infos[i]['label']) == 1 + video_infos[i]['label'] = video_infos[i]['label'][0] + return video_infos + + def parse_by_class(self): + video_infos_by_class = defaultdict(list) + for item in self.video_infos: + label = item['label'] + video_infos_by_class[label].append(item) + return video_infos_by_class + + @staticmethod + def label2array(num, label): + arr = np.zeros(num, dtype=np.float32) + arr[label] = 1. + return arr + + def evaluate(self, + results, + metrics='top_k_accuracy', + metric_options=dict(top_k_accuracy=dict(topk=(1, 5))), + logger=None, + **deprecated_kwargs): + """Perform evaluation for common datasets. + + Args: + results (list): Output results. + metrics (str | sequence[str]): Metrics to be performed. + Defaults: 'top_k_accuracy'. + metric_options (dict): Dict for metric options. Options are + ``topk`` for ``top_k_accuracy``. + Default: ``dict(top_k_accuracy=dict(topk=(1, 5)))``. + logger (logging.Logger | None): Logger for recording. + Default: None. + deprecated_kwargs (dict): Used for containing deprecated arguments. + See 'https://github.com/open-mmlab/mmaction2/pull/286'. + + Returns: + dict: Evaluation results dict. 
+ """ + # Protect ``metric_options`` since it uses mutable value as default + metric_options = copy.deepcopy(metric_options) + + if deprecated_kwargs != {}: + warnings.warn( + 'Option arguments for metrics has been changed to ' + "`metric_options`, See 'https://github.com/open-mmlab/mmaction2/pull/286' " # noqa: E501 + 'for more details') + metric_options['top_k_accuracy'] = dict( + metric_options['top_k_accuracy'], **deprecated_kwargs) + + if not isinstance(results, list): + raise TypeError(f'results must be a list, but got {type(results)}') + assert len(results) == len(self), ( + f'The length of results is not equal to the dataset len: ' + f'{len(results)} != {len(self)}') + + metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics] + allowed_metrics = [ + 'top_k_accuracy', 'mean_class_accuracy', 'mean_average_precision', + 'mmit_mean_average_precision' + ] + + for metric in metrics: + if metric not in allowed_metrics: + raise KeyError(f'metric {metric} is not supported') + + eval_results = OrderedDict() + gt_labels = [ann['label'] for ann in self.video_infos] + + for metric in metrics: + msg = f'Evaluating {metric} ...' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + if metric == 'top_k_accuracy': + topk = metric_options.setdefault('top_k_accuracy', + {}).setdefault( + 'topk', (1, 5)) + if not isinstance(topk, (int, tuple)): + raise TypeError('topk must be int or tuple of int, ' + f'but got {type(topk)}') + if isinstance(topk, int): + topk = (topk, ) + + top_k_acc = top_k_accuracy(results, gt_labels, topk) + log_msg = [] + for k, acc in zip(topk, top_k_acc): + eval_results[f'top{k}_acc'] = acc + log_msg.append(f'\ntop{k}_acc\t{acc:.4f}') + log_msg = ''.join(log_msg) + print_log(log_msg, logger=logger) + continue + + if metric == 'mean_class_accuracy': + mean_acc = mean_class_accuracy(results, gt_labels) + eval_results['mean_class_accuracy'] = mean_acc + log_msg = f'\nmean_acc\t{mean_acc:.4f}' + print_log(log_msg, logger=logger) + continue + + if metric in [ + 'mean_average_precision', 'mmit_mean_average_precision' + ]: + gt_labels_arrays = [ + self.label2array(self.num_classes, label) + for label in gt_labels + ] + if metric == 'mean_average_precision': + mAP = mean_average_precision(results, gt_labels_arrays) + eval_results['mean_average_precision'] = mAP + log_msg = f'\nmean_average_precision\t{mAP:.4f}' + elif metric == 'mmit_mean_average_precision': + mAP = mmit_mean_average_precision(results, + gt_labels_arrays) + eval_results['mmit_mean_average_precision'] = mAP + log_msg = f'\nmmit_mean_average_precision\t{mAP:.4f}' + print_log(log_msg, logger=logger) + continue + + return eval_results + + @staticmethod + def dump_results(results, out): + """Dump data to json/yaml/pickle strings or files.""" + return mmcv.dump(results, out) + + def prepare_train_frames(self, idx): + """Prepare the frames for training given the index.""" + results = copy.deepcopy(self.video_infos[idx]) + results['modality'] = self.modality + results['start_index'] = self.start_index + + # prepare tensor in getitem + # If HVU, type(results['label']) is dict + if self.multi_class and isinstance(results['label'], list): + onehot = torch.zeros(self.num_classes) + onehot[results['label']] = 1. 
+            results['label'] = onehot
+
+        return self.pipeline(results)
+
+    def prepare_test_frames(self, idx):
+        """Prepare the frames for testing given the index."""
+        results = copy.deepcopy(self.video_infos[idx])
+        results['modality'] = self.modality
+        results['start_index'] = self.start_index
+
+        # prepare tensor in getitem
+        # If HVU, type(results['label']) is dict
+        if self.multi_class and isinstance(results['label'], list):
+            onehot = torch.zeros(self.num_classes)
+            onehot[results['label']] = 1.
+            results['label'] = onehot
+
+        return self.pipeline(results)
+
+    def __len__(self):
+        """Get the size of the dataset."""
+        return len(self.video_infos)
+
+    def __getitem__(self, idx):
+        """Get the sample for either training or testing given index."""
+        if self.test_mode:
+            return self.prepare_test_frames(idx)
+
+        return self.prepare_train_frames(idx)
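
Usage sketch (not part of the patch): a minimal subclass showing how `load_annotations` is expected to fill `video_infos`, and how `evaluate` consumes per-clip class scores. The `TinyVideoDataset` name, the one-"filename label"-pair-per-line annotation format, and the commented-out wiring below are illustrative assumptions, not code or APIs introduced by this change.

import os.path as osp

from mmaction.datasets.base import BaseDataset


class TinyVideoDataset(BaseDataset):
    """Hypothetical dataset whose ann_file holds one 'filename label' per line."""

    def load_annotations(self):
        video_infos = []
        with open(self.ann_file, 'r') as fin:
            for line in fin:
                filename, label = line.strip().split()
                if self.data_prefix is not None:
                    filename = osp.join(self.data_prefix, filename)
                # each entry carries at least a path key and an int label
                video_infos.append(dict(filename=filename, label=int(label)))
        return video_infos


# Assumed wiring, given an existing 'ann.txt', a decode/sample pipeline and
# numpy imported as np; evaluate accepts one score array per sample:
# dataset = TinyVideoDataset('ann.txt', pipeline=[...], num_classes=3)
# scores = [np.random.rand(3) for _ in range(len(dataset))]
# print(dataset.evaluate(scores, metrics=['top_k_accuracy',
#                                         'mean_class_accuracy']))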