Diff of /dataloader.py [000000] .. [3f9044]

Switch to unified view

a b/dataloader.py
1
'''
2
https://github.com/akaraspt/deepsleepnet
3
Copyright 2017 Akara Supratak and Hao Dong.  All rights reserved.
4
'''
5
import os
6
import numpy as np
7
import re
8
class SeqDataLoader(object):
9
    def __init__(self, data_dir, n_folds, fold_idx,classes):
10
        self.data_dir = data_dir
11
        self.n_folds = n_folds
12
        self.fold_idx = fold_idx
13
        self.classes = classes
14
15
    def load_npz_file(self, npz_file):
        """Read a single .npz recording.

        Returns a ``(data, labels, sampling_rate)`` tuple taken from the
        "x", "y" and "fs" entries of the archive.
        """
        with np.load(npz_file) as archive:
            return archive["x"], archive["y"], archive["fs"]
22
23
    def save_to_npz_file(self, data, labels, sampling_rate, filename):
        """Write *data*, *labels* and *sampling_rate* to *filename* as an
        npz archive under the keys "x", "y" and "fs" (the same keys that
        load_npz_file reads back)."""
        np.savez(filename, x=data, y=labels, fs=sampling_rate)
33
    def _load_npz_list_files(self, npz_files):
34
        """Load data_2013 and labels from list of npz files."""
35
        data = []
36
        labels = []
37
        fs = None
38
        for npz_f in npz_files:
39
            print ("Loading {} ...".format(npz_f))
40
            tmp_data, tmp_labels, self.sampling_rate = self.load_npz_file(npz_f)
41
            if fs is None:
42
                fs = self.sampling_rate
43
            elif fs != self.sampling_rate:
44
                raise Exception("Found mismatch in sampling rate.")
45
46
            # Reshape the data_2013 to match the input of the model - conv2d
47
            tmp_data = np.squeeze(tmp_data)
48
            # tmp_data = tmp_data[:, :, np.newaxis, np.newaxis]
49
50
            # # Reshape the data_2013 to match the input of the model - conv1d
51
            # tmp_data = tmp_data[:, :, np.newaxis]
52
53
            # Casting
54
            tmp_data = tmp_data.astype(np.float32)
55
            tmp_labels = tmp_labels.astype(np.int32)
56
57
            # normalize each 30s sample such that each has zero mean and unit vairance
58
            tmp_data = (tmp_data - np.expand_dims(tmp_data.mean(axis=1),axis= 1)) / np.expand_dims(tmp_data.std(axis=1),axis=1)
59
60
61
            data.append(tmp_data)
62
            labels.append(tmp_labels)
63
64
        return data, labels
65
66
    def _load_cv_data(self, list_files):
67
        """Load sequence training and cross-validation sets."""
68
        # Split files for training and validation sets
69
        val_files = np.array_split(list_files, self.n_folds)
70
        train_files = np.setdiff1d(list_files, val_files[self.fold_idx])
71
72
        # Load a npz file
73
        print ("Load training set:")
74
        data_train, label_train = self._load_npz_list_files(train_files)
75
        print (" ")
76
        print ("Load validation set:")
77
        data_val, label_val = self._load_npz_list_files(val_files[self.fold_idx])
78
        print (" ")
79
80
        return data_train, label_train, data_val, label_val
81
82
    def load_test_data(self):
83
        # Remove non-mat files, and perform ascending sort
84
        allfiles = os.listdir(self.data_dir)
85
        npzfiles = []
86
        for idx, f in enumerate(allfiles):
87
            if ".npz" in f:
88
                npzfiles.append(os.path.join(self.data_dir, f))
89
        npzfiles.sort()
90
91
        # Files for validation sets
92
        val_files = np.array_split(npzfiles, self.n_folds)
93
        val_files = val_files[self.fold_idx]
94
95
        print ("\n========== [Fold-{}] ==========\n".format(self.fold_idx))
96
97
        print ("Load validation set:")
98
        data_val, label_val = self._load_npz_list_files(val_files)
99
100
        return data_val, label_val
101
102
    def load_data(self, seq_len = 10, shuffle = True, n_files=None):
103
        # Remove non-mat files, and perform ascending sort
104
        allfiles = os.listdir(self.data_dir)
105
        npzfiles = []
106
        for idx, f in enumerate(allfiles):
107
            if ".npz" in f:
108
                npzfiles.append(os.path.join(self.data_dir, f))
109
        npzfiles.sort()
110
111
        if n_files is not None:
112
            npzfiles = npzfiles[:n_files]
113
114
        # subject_files = []
115
        # for idx, f in enumerate(allfiles):
116
        #     if self.fold_idx < 10:
117
        #         pattern = re.compile("[a-zA-Z0-9]*0{}[1-9]E0\.npz$".format(self.fold_idx))
118
        #     else:
119
        #         pattern = re.compile("[a-zA-Z0-9]*{}[1-9]E0\.npz$".format(self.fold_idx))
120
        #     if pattern.match(f):
121
        #         subject_files.append(os.path.join(self.data_dir, f))
122
123
        # randomize the order of the file names just for one time!
124
        r_permute = np.random.permutation(len(npzfiles))
125
        filename = "r_permute.npz"
126
        if (os.path.isfile(filename)):
127
            with np.load(filename) as f:
128
                r_permute = f["inds"]
129
        else:
130
            save_dict = {
131
                "inds": r_permute,
132
133
            }
134
            np.savez(filename, **save_dict)
135
136
        npzfiles = np.asarray(npzfiles)[r_permute]
137
        train_files = np.array_split(npzfiles, self.n_folds)
138
        subject_files = train_files[self.fold_idx]
139
140
141
        train_files = list(set(npzfiles) - set(subject_files))
142
        # train_files.sort()
143
        # subject_files.sort()
144
145
        # Load training and validation sets
146
        print ("\n========== [Fold-{}] ==========\n".format(self.fold_idx))
147
        print ("Load training set:")
148
        data_train, label_train = self._load_npz_list_files(train_files)
149
        print (" ")
150
        print ("Load Test set:")
151
        data_test, label_test = self._load_npz_list_files(subject_files)
152
        print (" ")
153
154
        print ("Training set: n_subjects={}".format(len(data_train)))
155
        n_train_examples = 0
156
        for d in data_train:
157
            print d.shape
158
            n_train_examples += d.shape[0]
159
        print ("Number of examples = {}".format(n_train_examples))
160
        self.print_n_samples_each_class(np.hstack(label_train),self.classes)
161
        print (" ")
162
        print ("Test set: n_subjects = {}".format(len(data_test)))
163
        n_test_examples = 0
164
        for d in data_test:
165
            print d.shape
166
            n_test_examples += d.shape[0]
167
        print ("Number of examples = {}".format(n_test_examples))
168
        self.print_n_samples_each_class(np.hstack(label_test),self.classes)
169
        print (" ")
170
171
        data_train = np.vstack(data_train)
172
        label_train = np.hstack(label_train)
173
        data_train = [data_train[i:i + seq_len] for i in range(0, len(data_train), seq_len)]
174
        label_train = [label_train[i:i + seq_len] for i in range(0, len(label_train), seq_len)]
175
        if data_train[-1].shape[0]!=seq_len:
176
            data_train.pop()
177
            label_train.pop()
178
179
        data_train = np.asarray(data_train)
180
        label_train = np.asarray(label_train)
181
182
        data_test = np.vstack(data_test)
183
        label_test = np.hstack(label_test)
184
        data_test = [data_test[i:i + seq_len] for i in range(0, len(data_test), seq_len)]
185
        label_test = [label_test[i:i + seq_len] for i in range(0, len(label_test), seq_len)]
186
187
        if data_test[-1].shape[0]!=seq_len:
188
            data_test.pop()
189
            label_test.pop()
190
191
        data_test = np.asarray(data_test)
192
        label_test = np.asarray(label_test)
193
194
        # shuffle
195
        if shuffle is True:
196
            # training data_2013
197
            permute = np.random.permutation(len(label_train))
198
            data_train = np.asarray(data_train)
199
            data_train = data_train[permute]
200
            label_train = label_train[permute]
201
202
            # test data_2013
203
            permute = np.random.permutation(len(label_test))
204
            data_test = np.asarray(data_test)
205
            data_test = data_test[permute]
206
            label_test = label_test[permute]
207
208
        return data_train, label_train, data_test, label_test
209
210
    @staticmethod
211
    def load_subject_data(data_dir, subject_idx):
212
        # Remove non-mat files, and perform ascending sort
213
        allfiles = os.listdir(data_dir)
214
        subject_files = []
215
        for idx, f in enumerate(allfiles):
216
            if subject_idx < 10:
217
                pattern = re.compile("[a-zA-Z0-9]*0{}[1-9]E0\.npz$".format(subject_idx))
218
            else:
219
                pattern = re.compile("[a-zA-Z0-9]*{}[1-9]E0\.npz$".format(subject_idx))
220
            if pattern.match(f):
221
                subject_files.append(os.path.join(data_dir, f))
222
223
        # Files for validation sets
224
        if len(subject_files) == 0 or len(subject_files) > 2:
225
            raise Exception("Invalid file pattern")
226
227
        def load_npz_file(npz_file):
228
            """Load data_2013 and labels from a npz file."""
229
            with np.load(npz_file) as f:
230
                data = f["x"]
231
                labels = f["y"]
232
                sampling_rate = f["fs"]
233
            return data, labels, sampling_rate
234
235
        def load_npz_list_files(npz_files):
236
            """Load data_2013 and labels from list of npz files."""
237
            data = []
238
            labels = []
239
            fs = None
240
            for npz_f in npz_files:
241
                print ("Loading {} ...".format(npz_f))
242
                tmp_data, tmp_labels, sampling_rate = load_npz_file(npz_f)
243
                if fs is None:
244
                    fs = sampling_rate
245
                elif fs != sampling_rate:
246
                    raise Exception("Found mismatch in sampling rate.")
247
248
                # Reshape the data_2013 to match the input of the model - conv2d
249
                tmp_data = np.squeeze(tmp_data)
250
                # tmp_data = tmp_data[:, :, np.newaxis, np.newaxis]
251
252
                # # Reshape the data_2013 to match the input of the model - conv1d
253
                # tmp_data = tmp_data[:, :, np.newaxis]
254
255
                # Casting
256
                tmp_data = tmp_data.astype(np.float32)
257
                tmp_labels = tmp_labels.astype(np.int32)
258
259
                data.append(tmp_data)
260
                labels.append(tmp_labels)
261
262
            return data, labels
263
264
        print ("Load data_2013 from: {}".format(subject_files))
265
        data, labels = load_npz_list_files(subject_files)
266
267
        return data, labels
268
269
    @staticmethod
270
    def print_n_samples_each_class(labels,classes):
271
        class_dict = dict(zip(range(len(classes)),classes))
272
        unique_labels = np.unique(labels)
273
        for c in unique_labels:
274
            n_samples = len(np.where(labels == c)[0])
275
            print ("{}: {}".format(class_dict[c], n_samples))