[6d389a]: / mmaction / datasets / pose_dataset.py

Download this file

114 lines (96 with data), 4.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import numpy as np
from ..utils import get_root_logger
from .base import BaseDataset
from .builder import DATASETS
@DATASETS.register_module()
class PoseDataset(BaseDataset):
    """Pose dataset for action recognition.

    The dataset loads pose annotations and applies the specified transforms
    to return a dict containing pose information.

    The ann_file is a pickle file containing a list of annotations. The
    fields of an annotation include frame_dir (video_id), total_frames,
    label, kp and kpscore. When ``split`` is given, the pickle must instead
    hold a dict with a ``split`` entry (split name -> list of video
    identifiers) and an ``annotations`` entry (the list of annotations).

    Args:
        ann_file (str): Path to the annotation file (must end with '.pkl').
        pipeline (list[dict | callable]): A sequence of data transforms.
        split (str | None): The dataset split used. Only applicable to UCF or
            HMDB. Allowed choices are 'train1', 'test1', 'train2', 'test2',
            'train3', 'test3'. Default: None.
        valid_ratio (float | None): The valid_ratio for videos in KineticsPose.
            For a video with n frames, it is a valid training sample only if
            at least n * valid_ratio frames have human pose. Videos whose
            ``total_frames`` is 0 are dropped. None means not applicable
            (only applicable to Kinetics Pose). Default: None.
        box_thr (str | None): The threshold for human proposals. Only boxes
            with confidence score greater than or equal to `box_thr` are
            kept. The per-box selection (stored as ``anno_inds``) is only
            computed when ``valid_ratio`` is also set and ``box_thr`` is not
            '0.5'. None means not applicable (only applicable to Kinetics
            Pose [ours]). Allowed choices are '0.5', '0.6', '0.7', '0.8',
            '0.9'. Default: None.
        class_prob (dict | None): The per class sampling probability. If not
            None, it will override the class_prob calculated in
            BaseDataset.__init__(). Default: None.
        **kwargs: Keyword arguments for ``BaseDataset``.
    """

    # Box-score thresholds accepted for ``box_thr``; each has a matching
    # ``valid@<thr>`` frame count in the annotations.
    _BOX_THRS = ('0.5', '0.6', '0.7', '0.8', '0.9')

    def __init__(self,
                 ann_file,
                 pipeline,
                 split=None,
                 valid_ratio=None,
                 box_thr=None,
                 class_prob=None,
                 **kwargs):
        modality = 'Pose'
        # split, applicable to ucf or hmdb. It is set before calling
        # super().__init__() because BaseDataset.__init__ presumably loads
        # the annotations through load_annotations(), which reads self.split.
        self.split = split
        super().__init__(
            ann_file, pipeline, start_index=0, modality=modality, **kwargs)

        # box_thr is kept as a string so it can be embedded verbatim in the
        # ``valid@{box_thr}`` annotation key.
        self.box_thr = box_thr
        if self.box_thr is not None:
            assert box_thr in self._BOX_THRS, (
                f'box_thr must be one of {self._BOX_THRS}, got {box_thr!r}')

        # Thresholding Training Examples
        self.valid_ratio = valid_ratio
        if self.valid_ratio is not None:
            assert isinstance(self.valid_ratio, float), (
                f'valid_ratio must be a float, got {valid_ratio!r}')
            # Without a box threshold, use the overall count of frames with
            # pose; otherwise use the count precomputed for that threshold.
            if self.box_thr is None:
                key = 'valid_frames'
            else:
                key = f'valid@{self.box_thr}'
            self.video_infos = [
                x for x in self.video_infos
                if self._is_valid(x, key, valid_ratio)
            ]
            # No per-box selection is done for '0.5' (presumably the stored
            # boxes already satisfy it -- TODO confirm with the data spec).
            if self.box_thr is not None and self.box_thr != '0.5':
                self._select_boxes(float(self.box_thr))

        if class_prob is not None:
            self.class_prob = class_prob

        logger = get_root_logger()
        logger.info(f'{len(self)} videos remain after valid thresholding')

    @staticmethod
    def _is_valid(info, key, valid_ratio):
        """Return whether ``info[key] / info['total_frames'] >= valid_ratio``.

        A video with no frames has an undefined ratio; it is rejected
        instead of raising ZeroDivisionError.
        """
        total = info['total_frames']
        return total > 0 and info[key] / total >= valid_ratio

    def _select_boxes(self, box_thr):
        """Store in each item's ``anno_inds`` the indices of boxes whose
        ``box_score`` is >= ``box_thr``."""
        for item in self.video_infos:
            inds = [
                i for i, score in enumerate(item['box_score'])
                if score >= box_thr
            ]
            # Explicit integer dtype: np.array([]) would default to float64,
            # which is not usable as an index array when no box passes.
            item['anno_inds'] = np.array(inds, dtype=np.int64)

    def load_annotations(self):
        """Load annotation file to get video information.

        Only pickle (``.pkl``) annotation files are supported.
        """
        assert self.ann_file.endswith('.pkl'), (
            f'PoseDataset only supports .pkl annotation files, '
            f'got {self.ann_file}')
        return self.load_pkl_annotations()

    def load_pkl_annotations(self):
        """Load video infos from a pickle annotation file.

        If ``self.split`` is set, the loaded object must be a dict with a
        ``split`` entry (split name -> video identifiers) and an
        ``annotations`` list. Only annotations whose ``filename`` (or, when
        the annotations have no ``filename`` field, ``frame_dir``) is listed
        under ``self.split`` are kept. Otherwise the loaded object is used
        directly as the annotation list.

        ``filename`` and ``frame_dir`` entries are joined onto
        ``self.data_prefix``.

        Returns:
            list[dict]: The annotations of the selected videos.
        """
        data = mmcv.load(self.ann_file)

        if self.split:
            split, data = data['split'], data['annotations']
            # Build the membership set once: O(1) lookups instead of
            # scanning the split's list for every annotation.
            wanted = set(split[self.split])
            # An empty annotation list has no data[0] to inspect and
            # nothing to filter.
            if data:
                identifier = ('filename'
                              if 'filename' in data[0] else 'frame_dir')
                data = [x for x in data if x[identifier] in wanted]

        for item in data:
            # Resolve paths against data_prefix (sometimes the annotation
            # itself may need to be loaded from these files).
            if 'filename' in item:
                item['filename'] = osp.join(self.data_prefix,
                                            item['filename'])
            if 'frame_dir' in item:
                item['frame_dir'] = osp.join(self.data_prefix,
                                             item['frame_dir'])
        return data