Diff of /cnnmodel/dataset.py [000000] .. [8c4e02]

Switch to side-by-side view

--- a
+++ b/cnnmodel/dataset.py
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from torchvision.datasets import DatasetFolder
+
+
+class CNNDataset(Dataset):
+    def __init__(self, root):
+        self.dataset_folder = DatasetFolder(root=root, loader=CNNDataset._npy_loader, extensions=('_mfcc.npy',))
+        self.len_ = len(self.dataset_folder)
+        self.folder_to_index = self.dataset_folder.class_to_idx
+
+    @staticmethod
+    def _npy_loader(path):
+        mfcc = np.load(path)
+        non_mfcc_file_path = path.replace('mfcc', 'other')
+        non_mfcc = np.load(non_mfcc_file_path)
+
+        # in_channels x height x width
+        assert mfcc.shape == (3, 13, 30)
+        assert non_mfcc.shape == (18, )
+
+        mfcc = torch.from_numpy(mfcc).float()
+        non_mfcc = torch.from_numpy(non_mfcc).float()
+
+        return mfcc, non_mfcc, path
+
+    def __getitem__(self, index):
+
+        return self.dataset_folder[index]
+
+    def __len__(self):
+        return self.len_