--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,275 @@
+'''
+https://github.com/akaraspt/deepsleepnet
+Copyright 2017 Akara Supratak and Hao Dong.  All rights reserved.
+'''
+import os
+import numpy as np
+import re
class SeqDataLoader(object):
    """Load sleep-epoch data from npz files and split it into CV folds.

    Each npz file is expected to contain:
        "x"  : per-epoch samples, shape (n_epochs, n_points[, 1])
        "y"  : integer class label per epoch
        "fs" : sampling rate  # assumed identical across all files
    """

    def __init__(self, data_dir, n_folds, fold_idx, classes):
        self.data_dir = data_dir  # directory containing the *.npz files
        self.n_folds = n_folds    # number of cross-validation folds
        self.fold_idx = fold_idx  # fold used as the validation/test set
        self.classes = classes    # class names; index == label value

    def load_npz_file(self, npz_file):
        """Return (data, labels, sampling_rate) stored in one npz file."""
        with np.load(npz_file) as f:
            data = f["x"]
            labels = f["y"]
            sampling_rate = f["fs"]
        return data, labels, sampling_rate

    def save_to_npz_file(self, data, labels, sampling_rate, filename):
        """Save one subject's data to an npz file (inverse of load_npz_file)."""
        np.savez(filename, x=data, y=labels, fs=sampling_rate)

    def _list_npz_files(self):
        """Return the ascending-sorted list of *.npz paths in self.data_dir."""
        # endswith (not substring test) so e.g. "x.npz.bak" is excluded.
        return sorted(
            os.path.join(self.data_dir, f)
            for f in os.listdir(self.data_dir)
            if f.endswith(".npz")
        )

    def _load_npz_list_files(self, npz_files):
        """Load, cast and per-epoch-normalize data from a list of npz files.

        Returns (data, labels): two parallel lists, one entry per file.
        Raises Exception if the files disagree on sampling rate.
        """
        data = []
        labels = []
        fs = None
        for npz_f in npz_files:
            print("Loading {} ...".format(npz_f))
            # self.sampling_rate is kept as an instance attribute for
            # backward compatibility -- callers may read it after loading.
            tmp_data, tmp_labels, self.sampling_rate = self.load_npz_file(npz_f)
            if fs is None:
                fs = self.sampling_rate
            elif fs != self.sampling_rate:
                raise Exception("Found mismatch in sampling rate.")

            # Drop the trailing channel axis so samples are (n_epochs, n_points).
            tmp_data = np.squeeze(tmp_data)
            tmp_data = tmp_data.astype(np.float32)
            tmp_labels = tmp_labels.astype(np.int32)

            # Normalize each 30-s epoch to zero mean and unit variance.
            # NOTE(review): a constant epoch (std == 0) would yield NaNs here.
            mean = tmp_data.mean(axis=1, keepdims=True)
            std = tmp_data.std(axis=1, keepdims=True)
            tmp_data = (tmp_data - mean) / std

            data.append(tmp_data)
            labels.append(tmp_labels)

        return data, labels

    def _load_cv_data(self, list_files):
        """Split list_files into train/validation for the current fold and
        load both sets."""
        val_files = np.array_split(list_files, self.n_folds)
        train_files = np.setdiff1d(list_files, val_files[self.fold_idx])

        print("Load training set:")
        data_train, label_train = self._load_npz_list_files(train_files)
        print(" ")
        print("Load validation set:")
        data_val, label_val = self._load_npz_list_files(val_files[self.fold_idx])
        print(" ")

        return data_train, label_train, data_val, label_val

    def load_test_data(self):
        """Load only the validation fold (self.fold_idx) of the data set."""
        npzfiles = self._list_npz_files()
        val_files = np.array_split(npzfiles, self.n_folds)[self.fold_idx]

        print("\n========== [Fold-{}] ==========\n".format(self.fold_idx))
        print("Load validation set:")
        data_val, label_val = self._load_npz_list_files(val_files)

        return data_val, label_val

    def _print_summary(self, data_list, label_list):
        """Print per-subject shapes, total example count and class balance."""
        n_examples = 0
        for d in data_list:
            print(d.shape)  # fixed: was a Python 2 print statement
            n_examples += d.shape[0]
        print("Number of examples = {}".format(n_examples))
        self.print_n_samples_each_class(np.hstack(label_list), self.classes)
        print(" ")

    @staticmethod
    def _make_sequences(data_list, label_list, seq_len):
        """Stack per-subject arrays and cut them into non-overlapping
        sequences of seq_len epochs; a trailing partial sequence is dropped.

        Returns (data, labels) with shapes (n_seq, seq_len, ...) / (n_seq, seq_len).
        """
        data = np.vstack(data_list)
        labels = np.hstack(label_list)
        data = [data[i:i + seq_len] for i in range(0, len(data), seq_len)]
        labels = [labels[i:i + seq_len] for i in range(0, len(labels), seq_len)]
        if data[-1].shape[0] != seq_len:
            data.pop()
            labels.pop()
        return np.asarray(data), np.asarray(labels)

    def load_data(self, seq_len=10, shuffle=True, n_files=None):
        """Load the whole data set split by subject into train/test sets.

        Parameters
        ----------
        seq_len : epochs per sequence fed to the sequence model; a trailing
            partial sequence is dropped.
        shuffle : shuffle the order of sequences (never of epochs within one).
        n_files : if given, use only the first n_files npz files.

        Returns (data_train, label_train, data_test, label_test).
        """
        npzfiles = self._list_npz_files()
        if n_files is not None:
            npzfiles = npzfiles[:n_files]

        # Randomize the file order exactly once and cache the permutation on
        # disk so every fold of the experiment sees the same subject split.
        # NOTE(review): if the number of files changes between runs the cached
        # permutation becomes stale -- delete r_permute.npz to regenerate it.
        filename = "r_permute.npz"
        if os.path.isfile(filename):
            with np.load(filename) as f:
                r_permute = f["inds"]
        else:
            r_permute = np.random.permutation(len(npzfiles))
            np.savez(filename, inds=r_permute)

        npzfiles = np.asarray(npzfiles)[r_permute]
        folds = np.array_split(npzfiles, self.n_folds)
        subject_files = folds[self.fold_idx]  # held-out test subjects
        train_files = list(set(npzfiles) - set(subject_files))

        # Load training and test sets.
        print("\n========== [Fold-{}] ==========\n".format(self.fold_idx))
        print("Load training set:")
        data_train, label_train = self._load_npz_list_files(train_files)
        print(" ")
        print("Load Test set:")
        data_test, label_test = self._load_npz_list_files(subject_files)
        print(" ")

        print("Training set: n_subjects={}".format(len(data_train)))
        self._print_summary(data_train, label_train)
        print("Test set: n_subjects = {}".format(len(data_test)))
        self._print_summary(data_test, label_test)

        # Chunk the epoch streams into fixed-length sequences.
        data_train, label_train = self._make_sequences(data_train, label_train, seq_len)
        data_test, label_test = self._make_sequences(data_test, label_test, seq_len)

        if shuffle:  # truthy test (was `is True`, which skipped e.g. shuffle=1)
            permute = np.random.permutation(len(label_train))
            data_train = data_train[permute]
            label_train = label_train[permute]

            permute = np.random.permutation(len(label_test))
            data_test = data_test[permute]
            label_test = label_test[permute]

        return data_train, label_train, data_test, label_test

    @staticmethod
    def load_subject_data(data_dir, subject_idx):
        """Load raw (cast but NOT normalized) data for a single subject.

        Files are matched by the Sleep-EDF naming convention
        ``...{subject:02d}[1-9]E0.npz``; one or two files must match.
        """
        # Zero-pad single-digit subject numbers inside the file-name pattern.
        if subject_idx < 10:
            pattern = re.compile(r"[a-zA-Z0-9]*0{}[1-9]E0\.npz$".format(subject_idx))
        else:
            pattern = re.compile(r"[a-zA-Z0-9]*{}[1-9]E0\.npz$".format(subject_idx))
        subject_files = [
            os.path.join(data_dir, f)
            for f in os.listdir(data_dir)
            if pattern.match(f)
        ]

        # A subject has at most two recordings; anything else is a bad pattern.
        if len(subject_files) == 0 or len(subject_files) > 2:
            raise Exception("Invalid file pattern")

        def _load_one(npz_file):
            """Return (data, labels, sampling_rate) from a single npz file."""
            with np.load(npz_file) as f:
                return f["x"], f["y"], f["fs"]

        print("Load data_2013 from: {}".format(subject_files))
        data = []
        labels = []
        fs = None
        for npz_f in subject_files:
            print("Loading {} ...".format(npz_f))
            tmp_data, tmp_labels, sampling_rate = _load_one(npz_f)
            if fs is None:
                fs = sampling_rate
            elif fs != sampling_rate:
                raise Exception("Found mismatch in sampling rate.")

            # Drop the channel axis and cast; normalization is NOT applied here.
            data.append(np.squeeze(tmp_data).astype(np.float32))
            labels.append(tmp_labels.astype(np.int32))

        return data, labels

    @staticmethod
    def print_n_samples_each_class(labels, classes):
        """Print how many samples of each class occur in `labels`."""
        class_dict = dict(zip(range(len(classes)), classes))
        for c in np.unique(labels):
            n_samples = len(np.where(labels == c)[0])
            print("{}: {}".format(class_dict[c], n_samples))
\ No newline at end of file