Diff of /deepheart/parser.py [000000] .. [d3af21]

import os
import pickle
from collections import namedtuple

import numpy as np
from scipy.io import wavfile
from scipy.fftpack import fft
from scipy.signal import butter, lfilter
from sklearn.preprocessing import normalize
# sklearn.cross_validation was deprecated and later removed; train_test_split
# now lives in sklearn.model_selection and check_random_state in sklearn.utils.
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state


class PCG:
    """
    PCG is a container for loading phonocardiogram (PCG) data for the
    [2016 PhysioNet challenge](http://physionet.org/challenge/2016). Raw wav
    files are parsed into features, class labels are extracted from header
    files, and the data is split into training and testing groups.
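
    Example
    -------
    A minimal usage sketch; the data path below is hypothetical::

        pcg = PCG("/path/to/physionet/training_data")
        pcg.initialize_wav_data()
        X_batch, y_batch = pcg.get_mini_batch(64)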
    """
    def __init__(self, basepath, random_state=42):
        self.basepath = basepath
        self.class_name_to_id = {"normal": 0, "abnormal": 1}
        self.nclasses = len(self.class_name_to_id)

        self.train = None
        self.test = None

        self.n_samples = 0

        self.X = None
        self.y = None

        self.random_state = random_state

    def initialize_wav_data(self):
        """
        Load the original wav files and extract features. Warning: this can
        take a while because of the slow FFTs.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        self.__load_wav_file()
        self.__split_train_test()
        self.save("/tmp")

    def save(self, save_path):
        """
        Persist the PCG class to disk.

        Parameters
        ----------
        save_path: str
            Location on disk to store the parsed PCG metadata

        Returns
        -------
        None
        """
        os.makedirs(save_path, exist_ok=True)
        np.save(os.path.join(save_path, "X.npy"), self.X)
        np.save(os.path.join(save_path, "y.npy"), self.y)
        # Pickle needs a binary-mode file handle.
        with open(os.path.join(save_path, "meta"), "wb") as fout:
            pickle.dump((self.basepath, self.class_name_to_id, self.nclasses,
                         self.n_samples, self.random_state), fout)

    def load(self, load_path):
        """
        Load a previously stored PCG class.

        Parameters
        ----------
        load_path: str
            Location on disk to load parsed PCG data from

        Returns
        -------
        None
        """
        self.X = np.load(os.path.join(load_path, "X.npy"))
        self.y = np.load(os.path.join(load_path, "y.npy"))
        with open(os.path.join(load_path, "meta"), "rb") as fin:
            (self.basepath, self.class_name_to_id, self.nclasses,
             self.n_samples, self.random_state) = pickle.load(fin)
        self.__split_train_test()

    def __load_wav_file(self, doFFT=True):
        """
        Loads the PhysioNet 2016 challenge dataset from self.basepath by
        crawling the path. For each discovered wav file:

        * Attempt to parse the header file for the class label
        * Attempt to load the wav file
        * Calculate features from the wav file. If doFFT, the features are
          the Fourier transform of the original signal; otherwise, the
          features are the raw signal truncated to a fixed length.

        Parameters
        ----------
        doFFT: bool
            True if the features to be calculated are the FFT of the
            original signal

        Returns
        -------
        None
        """

        # First pass over the data to count samples and ensure each wav file
        # has an associated, parsable header file.
        wav_file_names = []
        class_labels = []
        for root, dirs, files in os.walk(self.basepath):
            # Ignore validation for now!
            if "validation" in root:
                continue
            for file in files:
                if file.endswith('.wav'):
                    try:
                        # rstrip(".wav") would strip characters, not the
                        # suffix, so use splitext to get the base name.
                        base_file_name = os.path.splitext(file)[0]
                        label_file_name = os.path.join(root, base_file_name + ".hea")

                        class_label = self.__parse_class_label(label_file_name)
                        class_labels.append(self.class_name_to_id[class_label])
                        wav_file_names.append(os.path.join(root, file))

                        self.n_samples += 1
                    except InvalidHeaderFileException as e:
                        print(e)

        if doFFT:
            # Feature vector: the first 400 FFT magnitudes of the raw signal
            # concatenated with the first 200 FFT magnitudes of a low-pass
            # filtered copy, 600 dimensions in total.
            fft_embedding_size = 400
            lowpass_embedding_size = 200
            X = np.zeros([self.n_samples, fft_embedding_size + lowpass_embedding_size])
        else:
            # Truncate each wav file to the minimum file size (10611 samples).
            # Note: this is bad and causes loss of information!
            embedding_size = 10611
            X = np.zeros([self.n_samples, embedding_size])

        for idx, wavfname in enumerate(wav_file_names):
            rate, wf = wavfile.read(wavfname)
            wf = normalize(wf.reshape(1, -1))

            if doFFT:
                # We only care about the magnitude of each frequency
                wf_fft = np.abs(fft(wf))
                wf_fft = wf_fft[:, :fft_embedding_size].reshape(-1)

                # Filter out high frequencies with a Butterworth low-pass
                # filter. The human heart maxes out around 150 bpm = 2.5 Hz,
                # so filter out any frequency significantly above this.
                nyquist = 0.5 * rate
                cutoff_freq = 4.0  # Hz
                # butter() returns the numerator (b) and denominator (a)
                # coefficients of the filter.
                b, a = butter(5, cutoff_freq / nyquist, btype='low', analog=False)
                wf_low_pass = lfilter(b, a, wf)
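                # Worked example of the normalized cutoff passed to butter():
                # assuming a 2 kHz sampling rate (as in the 2016 challenge
                # recordings), nyquist = 1000 Hz, so the cutoff becomes
                # 4.0 / 1000 = 0.004 on butter()'s 0..1 Nyquist scale.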

                # FFT the filtered signal
                wf_low_pass_fft = np.abs(fft(wf_low_pass))
                wf_low_pass_fft = wf_low_pass_fft[:, :lowpass_embedding_size].reshape(-1)

                features = np.concatenate((wf_fft, wf_low_pass_fft))
            else:
                # wf has shape (1, n) after normalize, so flatten it before
                # truncating to the embedding size.
                features = wf.reshape(-1)[:embedding_size]

            X[idx, :] = features

        self.X = X

        class_labels = np.array(class_labels)

        # Map from dense class labels to one-hot
        self.y = np.eye(self.nclasses)[class_labels]
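        # e.g. class_labels [0, 1, 0] becomes
        # [[1, 0], [0, 1], [1, 0]]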

    def __parse_class_label(self, label_file_name):
        """
        Parses PhysioBank header files, where the class label is located in
        the last line of the file. An example header file could contain:

        f0112 1 2000 60864
        f0112.wav 16+44 1 16 0 0 0 0 PCG
        # Normal

        Parameters
        ----------
        label_file_name: str
            Path to a specific header file

        Returns
        -------
        class_label: str
            One of `normal` or `abnormal`
        """
        with open(label_file_name, 'r') as fin:
            header = fin.readlines()

        comments = [line for line in header if line.startswith("#")]
        if len(comments) != 1:
            raise InvalidHeaderFileException("Invalid label file %s" % label_file_name)

        class_label = comments[0].lstrip("#").strip().lower()

        if class_label not in self.class_name_to_id:
            raise InvalidHeaderFileException("Invalid class label %s" % class_label)

        return class_label

    def __split_train_test(self):
        """
        Splits the internal features (self.X) and class labels (self.y) into
        balanced training and test sets using sklearn's helper function.

        Notes:
         * If self.random_state is None, splits will be randomly seeded;
           otherwise, self.random_state defines the random seed used to
           deterministically split the training and test data.
         * For now, class balancing is done by subsampling the
           over-represented class. Ideally this would be pushed down to the
           cost function in TensorFlow.

        Returns
        -------
        None
        """
        mlData = namedtuple('ml_data', 'X y')

        # Column 0 counts normal samples, column 1 abnormal ones.
        num_normal, num_abnormal = np.sum(self.y, axis=0)

        # Remove samples to rebalance classes
        # TODO: push this down into the cost function
        undersample_rate = num_abnormal / num_normal
        over_represented_idxs = self.y[:, 1] == 0
        under_represented_idxs = self.y[:, 1] == 1
        # Keep each over-represented sample with probability undersample_rate.
        random_idxs_to_keep = np.random.rand(self.y.shape[0]) < undersample_rate
        sample_idxs = (over_represented_idxs & random_idxs_to_keep |
                       under_represented_idxs)
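        # Worked example: with 300 normal and 100 abnormal samples,
        # undersample_rate = 100 / 300 ~= 0.33, so each normal sample is kept
        # with probability ~0.33, leaving roughly 100 samples per class.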

        X_balanced = self.X[sample_idxs, :]
        y_balanced = self.y[sample_idxs, :]

        X_train, X_test, y_train, y_test = train_test_split(
            X_balanced, y_balanced, test_size=0.25,
            random_state=self.random_state)

        self.train = mlData(X=X_train, y=y_train)
        self.test = mlData(X=X_test, y=y_test)

    def get_mini_batch(self, batch_size):
        """
        Helper function for sampling mini-batches from the training set.
        Note: a fresh random state is drawn on every call; if a fixed seed
        were used here instead, the same mini batch would be sampled on
        every call.

        Parameters
        ----------
        batch_size: int
            Number of elements to return in the mini batch

        Returns
        -------
        X: np.ndarray
            A feature matrix subsampled from self.train

        y: np.ndarray
            A one-hot matrix of class labels subsampled from self.train
        """
        # Deliberately not seeded with self.random_state: reusing a fixed
        # seed here would return the identical mini batch every time.
        random_state = check_random_state(None)
        n_training_samples = self.train.X.shape[0]
        # randint excludes the high endpoint, so pass n_training_samples
        # (not n_training_samples - 1) to keep the last sample reachable.
        minibatch_indices = random_state.randint(0, n_training_samples, batch_size)

        return self.train.X[minibatch_indices, :], self.train.y[minibatch_indices, :]


class InvalidHeaderFileException(Exception):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
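

if __name__ == "__main__":
    # Minimal end-to-end sketch; the data path below is hypothetical, and
    # initialize_wav_data() caches the parsed arrays under /tmp.
    pcg = PCG("/path/to/physionet/training_data")
    pcg.initialize_wav_data()

    # Later runs can skip the slow FFT pass by reloading the cached arrays.
    pcg2 = PCG("/path/to/physionet/training_data")
    pcg2.load("/tmp")
    X_batch, y_batch = pcg2.get_mini_batch(64)
    print(X_batch.shape, y_batch.shape)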