# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np

from ..utils import get_root_logger
from .base import BaseDataset
from .builder import DATASETS


@DATASETS.register_module()
class PoseDataset(BaseDataset):
    """Pose dataset for action recognition.

    The dataset loads pose and apply specified transforms to return a
    dict containing pose information.

    The ann_file is a pickle file, the pickle file contains a list of
    annotations, the fields of an annotation include frame_dir(video_id),
    total_frames, label, kp, kpscore. When ``split`` is given, the pickle
    file is instead expected to hold a dict with the keys ``'split'`` and
    ``'annotations'``.

    Args:
        ann_file (str): Path to the annotation file. Must end with '.pkl'.
        pipeline (list[dict | callable]): A sequence of data transforms.
        split (str | None): The dataset split used. Only applicable to UCF or
            HMDB. Allowed choices are 'train1', 'test1', 'train2', 'test2',
            'train3', 'test3'. Default: None.
        valid_ratio (float | None): The valid_ratio for videos in KineticsPose.
            For a video with n frames, it is a valid training sample only if
            at least n * valid_ratio frames have human pose. None means not
            applicable (only applicable to Kinetics Pose). Default: None.
        box_thr (str | None): The threshold for human proposals. Only boxes
            with confidence score no smaller than `box_thr` are kept. None
            means not applicable (only applicable to Kinetics Pose [ours]).
            Allowed choices are '0.5', '0.6', '0.7', '0.8', '0.9'. Note that
            box filtering is only performed when `valid_ratio` is also set.
            Default: None.
        class_prob (dict | None): The per class sampling probability. If not
            None, it will override the class_prob calculated in
            BaseDataset.__init__(). Default: None.
        **kwargs: Keyword arguments for ``BaseDataset``.
    """

    def __init__(self,
                 ann_file,
                 pipeline,
                 split=None,
                 valid_ratio=None,
                 box_thr=None,
                 class_prob=None,
                 **kwargs):
        modality = 'Pose'
        # split, applicable to ucf or hmdb. It is stored before calling the
        # base constructor -- presumably BaseDataset.__init__() invokes
        # load_annotations(), which reads self.split (verify against base).
        self.split = split

        super().__init__(
            ann_file, pipeline, start_index=0, modality=modality, **kwargs)

        # box_thr, which should be a string
        self.box_thr = box_thr
        if self.box_thr is not None:
            assert box_thr in ['0.5', '0.6', '0.7', '0.8', '0.9']

        # Thresholding Training Examples
        self.valid_ratio = valid_ratio
        if self.valid_ratio is not None:
            assert isinstance(self.valid_ratio, float)
            if self.box_thr is None:
                # Without a box threshold, use the generic valid frame count.
                self.video_infos = [
                    x for x in self.video_infos
                    if x['valid_frames'] / x['total_frames'] >= valid_ratio
                ]
            else:
                # With a box threshold, use the valid frame count recorded
                # for that specific threshold.
                key = f'valid@{self.box_thr}'
                self.video_infos = [
                    x for x in self.video_infos
                    if x[key] / x['total_frames'] >= valid_ratio
                ]
                # '0.5' needs no per-box filtering here -- presumably the
                # stored annotations already use that threshold (TODO
                # confirm against the annotation generation script).
                if self.box_thr != '0.5':
                    score_thr = float(self.box_thr)
                    for item in self.video_infos:
                        inds = [
                            i for i, score in enumerate(item['box_score'])
                            if score >= score_thr
                        ]
                        # Force an integer dtype: np.array([]) would be
                        # float64, which cannot be used as an index array
                        # when no box passes the threshold.
                        item['anno_inds'] = np.array(inds, dtype=np.int64)

        if class_prob is not None:
            self.class_prob = class_prob

        logger = get_root_logger()
        logger.info(f'{len(self)} videos remain after valid thresholding')

    def load_annotations(self):
        """Load annotation file to get video information.

        Returns:
            list[dict]: The annotations loaded from the pickle file.
        """
        assert self.ann_file.endswith('.pkl')
        return self.load_pkl_annotations()

    def load_pkl_annotations(self):
        """Load the pickle annotation file and prefix paths.

        If ``self.split`` is set, only annotations whose identifier
        ('filename' if present in the first annotation, else 'frame_dir')
        is listed in the requested split are kept.

        Returns:
            list[dict]: Annotations with 'filename' / 'frame_dir' joined
            onto ``self.data_prefix``.
        """
        # NOTE(review): the file is a pickle; only load annotation files
        # from trusted sources.
        data = mmcv.load(self.ann_file)

        if self.split:
            split, data = data['split'], data['annotations']
            # Guard against an empty annotation list before peeking at the
            # first item.
            has_filename = bool(data) and 'filename' in data[0]
            identifier = 'filename' if has_filename else 'frame_dir'
            # Build the lookup set once instead of scanning the split list
            # for every annotation.
            selected = set(split[self.split])
            data = [x for x in data if x[identifier] in selected]

        for item in data:
            # Sometimes we may need to load anno from the file
            if 'filename' in item:
                item['filename'] = osp.join(self.data_prefix, item['filename'])
            if 'frame_dir' in item:
                item['frame_dir'] = osp.join(self.data_prefix,
                                             item['frame_dir'])
        return data