--- a +++ b/cnnmodel/dataset.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +from torch.utils.data import Dataset +from torchvision.datasets import DatasetFolder + + +class CNNDataset(Dataset): + def __init__(self, root): + self.dataset_folder = DatasetFolder(root=root, loader=CNNDataset._npy_loader, extensions=('_mfcc.npy',)) + self.len_ = len(self.dataset_folder) + self.folder_to_index = self.dataset_folder.class_to_idx + + @staticmethod + def _npy_loader(path): + mfcc = np.load(path) + non_mfcc_file_path = path.replace('mfcc', 'other') + non_mfcc = np.load(non_mfcc_file_path) + + # in_channels x height x width + assert mfcc.shape == (3, 13, 30) + assert non_mfcc.shape == (18, ) + + mfcc = torch.from_numpy(mfcc).float() + non_mfcc = torch.from_numpy(non_mfcc).float() + + return mfcc, non_mfcc, path + + def __getitem__(self, index): + + return self.dataset_folder[index] + + def __len__(self): + return self.len_