--- a +++ b/dataloader.py @@ -0,0 +1,275 @@ +''' +https://github.com/akaraspt/deepsleepnet +Copyright 2017 Akara Supratak and Hao Dong. All rights reserved. +''' +import os +import numpy as np +import re +class SeqDataLoader(object): + def __init__(self, data_dir, n_folds, fold_idx,classes): + self.data_dir = data_dir + self.n_folds = n_folds + self.fold_idx = fold_idx + self.classes = classes + + def load_npz_file(self, npz_file): + """Load data_2013 and labels from a npz file.""" + with np.load(npz_file) as f: + data = f["x"] + labels = f["y"] + sampling_rate = f["fs"] + return data, labels, sampling_rate + + def save_to_npz_file(self, data, labels, sampling_rate, filename): + + # Save + save_dict = { + "x": data, + "y": labels, + "fs": sampling_rate, + + } + np.savez(filename, **save_dict) + def _load_npz_list_files(self, npz_files): + """Load data_2013 and labels from list of npz files.""" + data = [] + labels = [] + fs = None + for npz_f in npz_files: + print ("Loading {} ...".format(npz_f)) + tmp_data, tmp_labels, self.sampling_rate = self.load_npz_file(npz_f) + if fs is None: + fs = self.sampling_rate + elif fs != self.sampling_rate: + raise Exception("Found mismatch in sampling rate.") + + # Reshape the data_2013 to match the input of the model - conv2d + tmp_data = np.squeeze(tmp_data) + # tmp_data = tmp_data[:, :, np.newaxis, np.newaxis] + + # # Reshape the data_2013 to match the input of the model - conv1d + # tmp_data = tmp_data[:, :, np.newaxis] + + # Casting + tmp_data = tmp_data.astype(np.float32) + tmp_labels = tmp_labels.astype(np.int32) + + # normalize each 30s sample such that each has zero mean and unit vairance + tmp_data = (tmp_data - np.expand_dims(tmp_data.mean(axis=1),axis= 1)) / np.expand_dims(tmp_data.std(axis=1),axis=1) + + + data.append(tmp_data) + labels.append(tmp_labels) + + return data, labels + + def _load_cv_data(self, list_files): + """Load sequence training and cross-validation sets.""" + # Split files for training and validation sets + val_files = np.array_split(list_files, self.n_folds) + train_files = np.setdiff1d(list_files, val_files[self.fold_idx]) + + # Load a npz file + print ("Load training set:") + data_train, label_train = self._load_npz_list_files(train_files) + print (" ") + print ("Load validation set:") + data_val, label_val = self._load_npz_list_files(val_files[self.fold_idx]) + print (" ") + + return data_train, label_train, data_val, label_val + + def load_test_data(self): + # Remove non-mat files, and perform ascending sort + allfiles = os.listdir(self.data_dir) + npzfiles = [] + for idx, f in enumerate(allfiles): + if ".npz" in f: + npzfiles.append(os.path.join(self.data_dir, f)) + npzfiles.sort() + + # Files for validation sets + val_files = np.array_split(npzfiles, self.n_folds) + val_files = val_files[self.fold_idx] + + print ("\n========== [Fold-{}] ==========\n".format(self.fold_idx)) + + print ("Load validation set:") + data_val, label_val = self._load_npz_list_files(val_files) + + return data_val, label_val + + def load_data(self, seq_len = 10, shuffle = True, n_files=None): + # Remove non-mat files, and perform ascending sort + allfiles = os.listdir(self.data_dir) + npzfiles = [] + for idx, f in enumerate(allfiles): + if ".npz" in f: + npzfiles.append(os.path.join(self.data_dir, f)) + npzfiles.sort() + + if n_files is not None: + npzfiles = npzfiles[:n_files] + + # subject_files = [] + # for idx, f in enumerate(allfiles): + # if self.fold_idx < 10: + # pattern = re.compile("[a-zA-Z0-9]*0{}[1-9]E0\.npz$".format(self.fold_idx)) + # else: + # pattern = re.compile("[a-zA-Z0-9]*{}[1-9]E0\.npz$".format(self.fold_idx)) + # if pattern.match(f): + # subject_files.append(os.path.join(self.data_dir, f)) + + # randomize the order of the file names just for one time! + r_permute = np.random.permutation(len(npzfiles)) + filename = "r_permute.npz" + if (os.path.isfile(filename)): + with np.load(filename) as f: + r_permute = f["inds"] + else: + save_dict = { + "inds": r_permute, + + } + np.savez(filename, **save_dict) + + npzfiles = np.asarray(npzfiles)[r_permute] + train_files = np.array_split(npzfiles, self.n_folds) + subject_files = train_files[self.fold_idx] + + + train_files = list(set(npzfiles) - set(subject_files)) + # train_files.sort() + # subject_files.sort() + + # Load training and validation sets + print ("\n========== [Fold-{}] ==========\n".format(self.fold_idx)) + print ("Load training set:") + data_train, label_train = self._load_npz_list_files(train_files) + print (" ") + print ("Load Test set:") + data_test, label_test = self._load_npz_list_files(subject_files) + print (" ") + + print ("Training set: n_subjects={}".format(len(data_train))) + n_train_examples = 0 + for d in data_train: + print d.shape + n_train_examples += d.shape[0] + print ("Number of examples = {}".format(n_train_examples)) + self.print_n_samples_each_class(np.hstack(label_train),self.classes) + print (" ") + print ("Test set: n_subjects = {}".format(len(data_test))) + n_test_examples = 0 + for d in data_test: + print d.shape + n_test_examples += d.shape[0] + print ("Number of examples = {}".format(n_test_examples)) + self.print_n_samples_each_class(np.hstack(label_test),self.classes) + print (" ") + + data_train = np.vstack(data_train) + label_train = np.hstack(label_train) + data_train = [data_train[i:i + seq_len] for i in range(0, len(data_train), seq_len)] + label_train = [label_train[i:i + seq_len] for i in range(0, len(label_train), seq_len)] + if data_train[-1].shape[0]!=seq_len: + data_train.pop() + label_train.pop() + + data_train = np.asarray(data_train) + label_train = np.asarray(label_train) + + data_test = np.vstack(data_test) + label_test = np.hstack(label_test) + data_test = [data_test[i:i + seq_len] for i in range(0, len(data_test), seq_len)] + label_test = [label_test[i:i + seq_len] for i in range(0, len(label_test), seq_len)] + + if data_test[-1].shape[0]!=seq_len: + data_test.pop() + label_test.pop() + + data_test = np.asarray(data_test) + label_test = np.asarray(label_test) + + # shuffle + if shuffle is True: + # training data_2013 + permute = np.random.permutation(len(label_train)) + data_train = np.asarray(data_train) + data_train = data_train[permute] + label_train = label_train[permute] + + # test data_2013 + permute = np.random.permutation(len(label_test)) + data_test = np.asarray(data_test) + data_test = data_test[permute] + label_test = label_test[permute] + + return data_train, label_train, data_test, label_test + + @staticmethod + def load_subject_data(data_dir, subject_idx): + # Remove non-mat files, and perform ascending sort + allfiles = os.listdir(data_dir) + subject_files = [] + for idx, f in enumerate(allfiles): + if subject_idx < 10: + pattern = re.compile("[a-zA-Z0-9]*0{}[1-9]E0\.npz$".format(subject_idx)) + else: + pattern = re.compile("[a-zA-Z0-9]*{}[1-9]E0\.npz$".format(subject_idx)) + if pattern.match(f): + subject_files.append(os.path.join(data_dir, f)) + + # Files for validation sets + if len(subject_files) == 0 or len(subject_files) > 2: + raise Exception("Invalid file pattern") + + def load_npz_file(npz_file): + """Load data_2013 and labels from a npz file.""" + with np.load(npz_file) as f: + data = f["x"] + labels = f["y"] + sampling_rate = f["fs"] + return data, labels, sampling_rate + + def load_npz_list_files(npz_files): + """Load data_2013 and labels from list of npz files.""" + data = [] + labels = [] + fs = None + for npz_f in npz_files: + print ("Loading {} ...".format(npz_f)) + tmp_data, tmp_labels, sampling_rate = load_npz_file(npz_f) + if fs is None: + fs = sampling_rate + elif fs != sampling_rate: + raise Exception("Found mismatch in sampling rate.") + + # Reshape the data_2013 to match the input of the model - conv2d + tmp_data = np.squeeze(tmp_data) + # tmp_data = tmp_data[:, :, np.newaxis, np.newaxis] + + # # Reshape the data_2013 to match the input of the model - conv1d + # tmp_data = tmp_data[:, :, np.newaxis] + + # Casting + tmp_data = tmp_data.astype(np.float32) + tmp_labels = tmp_labels.astype(np.int32) + + data.append(tmp_data) + labels.append(tmp_labels) + + return data, labels + + print ("Load data_2013 from: {}".format(subject_files)) + data, labels = load_npz_list_files(subject_files) + + return data, labels + + @staticmethod + def print_n_samples_each_class(labels,classes): + class_dict = dict(zip(range(len(classes)),classes)) + unique_labels = np.unique(labels) + for c in unique_labels: + n_samples = len(np.where(labels == c)[0]) + print ("{}: {}".format(class_dict[c], n_samples)) \ No newline at end of file