# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import torch
from .base import BaseDataset
from .builder import DATASETS
@DATASETS.register_module()
class AudioDataset(BaseDataset):
    """Audio dataset for video recognition. Extracts the audio feature on-the-
    fly. Annotation file can be that of the rawframe dataset, or:

    .. code-block:: txt

        some/directory-1.wav 163 1
        some/directory-2.wav 122 1
        some/directory-3.wav 258 2
        some/directory-4.wav 234 2
        some/directory-5.wav 295 3
        some/directory-6.wav 121 3

    Each non-blank line is ``<path> <total_frames> <label> [<label> ...]``;
    more than one label is only accepted when ``multi_class`` is enabled.

    Args:
        ann_file (str): Path to the annotation file.
        pipeline (list[dict | callable]): A sequence of data transforms.
        suffix (str): The suffix of the audio file. Default: '.wav'.
        kwargs (dict): Other keyword args for `BaseDataset`.
    """

    def __init__(self, ann_file, pipeline, suffix='.wav', **kwargs):
        # Assigned before super().__init__ because load_annotations() reads
        # self.suffix; presumably BaseDataset.__init__ triggers annotation
        # loading — TODO confirm against BaseDataset.
        self.suffix = suffix
        super().__init__(ann_file, pipeline, modality='Audio', **kwargs)

    def load_annotations(self):
        """Load annotation file to get video information.

        ``.json`` annotation files are delegated to
        ``self.load_json_annotations()``; any other file is parsed as the
        whitespace-separated text format described in the class docstring.
        Blank lines are ignored.

        Returns:
            list[dict]: One dict per annotation line with keys
            ``'audio_path'`` (str), ``'total_frames'`` (int) and ``'label'``
            (an int, or a one-hot ``torch.Tensor`` of length
            ``self.num_classes`` when ``self.multi_class`` is set).
        """
        if self.ann_file.endswith('.json'):
            return self.load_json_annotations()

        video_infos = []
        with open(self.ann_file, 'r') as fin:
            for line in fin:
                line_split = line.strip().split()
                if not line_split:
                    # Tolerate blank lines (e.g. a trailing newline at EOF)
                    # instead of crashing with an IndexError.
                    continue
                assert len(line_split) >= 2, (
                    f'missing total_frames in line: {line}')

                filename, total_frames, *raw_labels = line_split

                # NOTE: the suffix is only appended when a data_prefix is
                # configured; without a prefix the path is used verbatim.
                if self.data_prefix is not None:
                    if not filename.endswith(self.suffix):
                        filename = osp.join(self.data_prefix,
                                            filename + self.suffix)
                    else:
                        filename = osp.join(self.data_prefix, filename)

                video_info = {
                    'audio_path': filename,
                    'total_frames': int(total_frames),
                }

                label = [int(x) for x in raw_labels]
                assert label, f'missing label in line: {line}'
                if self.multi_class:
                    assert self.num_classes is not None, (
                        'num_classes must be set when multi_class is True')
                    # Multi-label targets are stored as a one-hot vector.
                    onehot = torch.zeros(self.num_classes)
                    onehot[label] = 1.0
                    video_info['label'] = onehot
                else:
                    assert len(label) == 1, (
                        f'expected exactly one label in line: {line}')
                    video_info['label'] = label[0]
                video_infos.append(video_info)
        return video_infos