# deeplearn-approach/train_model.py
'''
2
This script trains and cross-validates the model. The database is not
included in this repo; please download the CinC Challenge database and truncate/pad the data into an
NxM matrix, N being the number of recordings and M the window accepted by the network (i.e.
30 seconds).
6
7
8
For more information visit: https://github.com/fernandoandreotti/cinc-challenge2017
9
 
10
 Referencing this work
11
   Andreotti, F., Carr, O., Pimentel, M.A.F., Mahdi, A., & De Vos, M. (2017). Comparing Feature Based 
12
   Classifiers and Convolutional Neural Networks to Detect Arrhythmia from Short Segments of ECG. In 
13
   Computing in Cardiology. Rennes (France).
14
--
15
 cinc-challenge2017, version 1.0, Sept 2017
16
 Last updated : 27-09-2017
17
 Released under the GNU General Public License
18
 Copyright (C) 2017  Fernando Andreotti, Oliver Carr, Marco A.F. Pimentel, Adam Mahdi, Maarten De Vos
19
 University of Oxford, Department of Engineering Science, Institute of Biomedical Engineering
20
 fernando.andreotti@eng.ox.ac.uk
21
   
22
 This program is free software: you can redistribute it and/or modify
23
 it under the terms of the GNU General Public License as published by
24
 the Free Software Foundation, either version 3 of the License, or
25
 (at your option) any later version.
26
 
27
 This program is distributed in the hope that it will be useful,
28
 but WITHOUT ANY WARRANTY; without even the implied warranty of
29
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
30
 GNU General Public License for more details.
31
 
32
 You should have received a copy of the GNU General Public License
33
 along with this program.  If not, see <http://www.gnu.org/licenses/>.
34
'''
35
36
import matplotlib.pyplot as plt
37
import tensorflow as tf
38
import numpy as np
39
import scipy.io
40
import gc
41
import itertools
42
from sklearn.metrics import confusion_matrix
43
import sys
44
sys.path.insert(0, './preparation')
45
46
# Keras imports
47
import keras
48
from keras.models import Model
49
from keras.layers import Input, Conv1D, Dense, Flatten, Dropout,MaxPooling1D, Activation, BatchNormalization
50
from keras.callbacks import EarlyStopping, ModelCheckpoint
51
from keras.utils import plot_model
52
from keras import backend as K
53
from keras.callbacks import Callback,warnings
54
55
###################################################################
56
### Callback method for reducing learning rate during training  ###
57
###################################################################
58
class AdvancedLearnignRateScheduler(Callback):
    '''
    Callback that multiplies the optimizer learning rate by `decayRatio`
    when the monitored quantity has stopped improving.

    # Arguments
        monitor: quantity to be monitored (key looked up in the epoch logs).
        patience: value the no-improvement counter must reach before the
            learning rate is reduced.
        verbose: verbosity mode; any value > 0 prints a message each time
            the learning rate is reduced.
        mode: one of {auto, min, max}. In 'min' mode an epoch counts as an
            improvement when the monitored quantity decreases; in 'max'
            mode when it increases. In 'auto' mode 'max' is used if the
            monitor name contains 'acc', otherwise 'min'.
        decayRatio: factor the current learning rate is multiplied by.
    '''
    def __init__(self, monitor='val_loss', patience=0, verbose=0, mode='auto', decayRatio=0.1):
        # Name this class (not Callback) so that Callback.__init__ itself runs.
        super(AdvancedLearnignRateScheduler, self).__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.wait = 0
        self.decayRatio = decayRatio

        if mode not in ['auto', 'min', 'max']:
            # `mode` is a plain argument; it is never stored on the instance.
            warnings.warn('Mode %s is unknown, '
                          'fallback to auto mode.'
                          % (mode), RuntimeWarning)
            mode = 'auto'

        if mode == 'auto':
            # Accuracy-like metrics are maximised, everything else minimised.
            mode = 'max' if 'acc' in self.monitor else 'min'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        else:
            self.monitor_op = np.greater
            self.best = -np.inf

    def on_epoch_end(self, epoch, logs=None):
        '''Track the monitored value and decay the learning rate on a plateau.'''
        logs = logs or {}
        current = logs.get(self.monitor)
        current_lr = K.get_value(self.model.optimizer.lr)
        print("\nLearning rate:", current_lr)
        if current is None:
            warnings.warn('AdvancedLearnignRateScheduler'
                          ' requires %s available!' %
                          (self.monitor), RuntimeWarning)
            # Nothing to compare against; leave the counters untouched.
            return

        if self.monitor_op(current, self.best):
            self.best = current
            self.wait = 0
            return

        if self.wait >= self.patience:
            # The reduction is applied regardless of verbosity; only the
            # message below depends on `verbose`.
            if self.verbose > 0:
                print('\nEpoch %05d: reducing learning rate' % (epoch))
            new_lr = current_lr * self.decayRatio
            K.set_value(self.model.optimizer.lr, new_lr)
            self.wait = 0
        self.wait += 1
122
123
124
###########################################
125
## Function to plot confusion matrices  ##
126
#########################################
127
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print a confusion matrix, draw it and save the figure as 'confusion.eps'.

    With `normalize=True` every row is divided by its own sum before the
    matrix is rounded to three decimals, printed and plotted.
    """
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = cm.astype('float') / row_totals
        header = "Normalized confusion matrix"
    else:
        header = 'Confusion matrix, without normalization'
    print(header)

    cm = np.around(cm, decimals=3)
    print(cm)

    # Cells above half of the largest value get white text, the rest black.
    half_of_max = cm.max() / 2.
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            cell = cm[row, col]
            if cell > half_of_max:
                text_color = "white"
            else:
                text_color = "black"
            plt.text(col, row, cell,
                     horizontalalignment="center",
                     color=text_color)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=45)
    plt.yticks(positions, classes)
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig('confusion.eps', format='eps', dpi=1000)
159
160
161
#####################################
162
## Model definition              ##
163
## ResNet based on Rajpurkar    ##
164
################################## 
165
def ResNet_model(WINDOW_SIZE):
    """
    Build and compile the 1-D residual CNN.

    The network takes inputs of shape (WINDOW_SIZE, 1) and ends in a 4-way
    softmax. It consists of an initial conv/BN/ReLU stage, one residual
    block with pooling, and 15 further residual blocks in which pooling is
    applied on every other block and the filter count grows by 64 on every
    fourth block. The model is compiled with Adam and categorical
    cross-entropy, drawn to 'model.png' and returned.
    """
    n_features = 1
    n_classes = 4
    base_filters = 64
    conv_stride = 1
    kernel_len = 16
    pool_len = 2
    pool_stride = 2
    dropout_rate = 0.5

    def conv(tensor, n_filters):
        # Wide 'same'-padded convolution shared by all blocks.
        layer = Conv1D(filters=n_filters,
                       kernel_size=kernel_len,
                       padding='same',
                       strides=conv_stride,
                       kernel_initializer='he_normal')
        return layer(tensor)

    def bn_relu(tensor):
        normed = BatchNormalization()(tensor)
        return Activation('relu')(normed)

    def pool(tensor):
        return MaxPooling1D(pool_size=pool_len,
                            strides=pool_stride)(tensor)

    # Modelling with Functional API
    input1 = Input(shape=(WINDOW_SIZE, n_features), name='input')

    ## First convolutional block (conv, BN, relu)
    x = conv(input1, base_filters)
    x = bn_relu(x)

    ## First residual block (conv, BN, relu, dropout, conv, pool)
    left = conv(x, base_filters)
    left = bn_relu(left)
    left = Dropout(dropout_rate)(left)
    left = conv(left, base_filters)
    left = pool(left)
    # Shortcut branch is pooled as well so both branches keep the same length
    right = pool(x)
    x = keras.layers.add([left, right])

    ## Main loop
    width = 1           # filter multiplier, incremented every 4th block
    use_pool = False    # toggled after each block: pooling on every other block
    for block in range(15):
        widen = block > 0 and block % 4 == 0
        if widen:
            width += 1
            # 1x1 convolution so the shortcut matches the new filter count
            shortcut = Conv1D(filters=base_filters * width, kernel_size=1)(x)
        else:
            shortcut = x

        # Left branch: pre-activation ordering (BN, relu, dropout, conv) twice
        left = bn_relu(x)
        left = Dropout(dropout_rate)(left)
        left = conv(left, base_filters * width)
        left = bn_relu(left)
        left = Dropout(dropout_rate)(left)
        left = conv(left, base_filters * width)

        if use_pool:
            left = pool(left)
            right = pool(shortcut)
        else:
            right = shortcut    # identity shortcut

        x = keras.layers.add([left, right])
        use_pool = not use_pool

    # Final bit
    x = bn_relu(x)
    x = Flatten()(x)
    out = Dense(n_classes, activation='softmax')(x)

    model = Model(inputs=input1, outputs=out)
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    plot_model(model, to_file='model.png')
    return model
274
275
###########################################################
276
## Function to perform K-fold Crossvalidation on model  ##
277
##########################################################
278
def _per_class_f1(cm):
    """Return the F1 score of each class of a square confusion matrix.

    F1 = 2*TP / (row sum + column sum). A class that has neither true nor
    predicted samples gets 0 instead of a division-by-zero NaN.
    """
    cm = np.asarray(cm, dtype=float)
    hits = np.diag(cm)
    denom = cm.sum(axis=1) + cm.sum(axis=0)
    return np.divide(2 * hits, denom, out=np.zeros_like(denom), where=denom > 0)


def model_eval(X,y):
    """
    Train and evaluate the ResNet with repeated hold-out validation.

    For each of the `Kfold` runs (repeated `rep` times) a fresh model is
    built, a random validation subset is drawn, the model is trained with
    early stopping / learning-rate decay / best-weights checkpointing, and
    the confusion matrix of the best checkpoint on the validation subset is
    stored. All confusion matrices are written to 'xval_results.mat'.

    NOTE: the validation indices are drawn independently on every run, so
    the validation subsets of different runs may overlap (this is not a
    strict K-fold partition).

    # Arguments
        X: 2-D array (recordings x samples); a channel axis is appended here.
        y: 2-D array of one-hot labels, one row per recording, with one
            column per entry of ['A', 'N', 'O', '~'].

    # Returns
        The model of the last run. K.clear_session() has already been called
        at that point, so it is presumably not usable for further
        prediction -- TODO confirm.
    """
    batch = 64
    epochs = 20
    rep = 1         # the whole procedure can be repeated multiple times
    Kfold = 5

    # Need to add dimension for training
    X = np.expand_dims(X, axis=2)
    window_size = X.shape[1]

    # Validation recordings are drawn from the first `Ntrain` rows only
    # (8528 in the original code). Capping at the actual number of rows keeps
    # that behaviour for larger inputs and avoids out-of-range indices for
    # smaller ones. Presumably rows beyond 8528 are appended extra data that
    # should only be trained on -- TODO confirm.
    Ntrain = min(8528, X.shape[0])
    Nsamp = Ntrain // Kfold     # number of recordings to take as validation

    classes = ['A', 'N', 'O', '~']
    Nclass = len(classes)
    cvconfusion = np.zeros((Nclass, Nclass, Kfold*rep))
    counter = 0
    # repetitions of cross validation
    for r in range(rep):
        print("Rep %d"%(r+1))
        # cross validation loop
        for k in range(Kfold):
            print("Cross-validation run %d"%(k+1))
            weights_file = 'weights-best_k{}_r{}.hdf5'.format(k,r)
            # Callbacks definition
            callbacks = [
                # Early stopping definition
                EarlyStopping(monitor='val_loss', patience=3, verbose=1),
                # Decrease learning rate by 0.1 factor
                AdvancedLearnignRateScheduler(monitor='val_loss', patience=1, verbose=1, mode='auto', decayRatio=0.1),
                # Saving best model
                ModelCheckpoint(weights_file, monitor='val_loss', save_best_only=True, verbose=1),
                ]
            # Build the model from the data actually passed in, not from globals
            model = ResNet_model(window_size)

            # split train and validation sets
            idxval = np.random.choice(Ntrain, Nsamp, replace=False)
            is_train = np.ones(X.shape[0], dtype=bool)
            is_train[idxval] = False
            Xtrain, ytrain = X[is_train], y[is_train]
            Xval, yval = X[idxval], y[idxval]

            # Train model
            model.fit(Xtrain, ytrain,
                      validation_data=(Xval, yval),
                      epochs=epochs, batch_size=batch, callbacks=callbacks)

            # Evaluate best trained model
            model.load_weights(weights_file)
            ypred = np.argmax(model.predict(Xval), axis=1)
            ytrue = np.argmax(yval, axis=1)
            # Explicit labels keep the matrix Nclass x Nclass even when a
            # class is missing from both the truth and the predictions.
            fold_cm = confusion_matrix(ytrue, ypred, labels=np.arange(Nclass))
            cvconfusion[:, :, counter] = fold_cm
            F1 = _per_class_f1(fold_cm)
            for label, score in zip(classes, F1):
                print("F1 measure for {} rhythm: {:1.4f}".format(label, score))
            print("Overall F1 measure: {:1.4f}".format(np.mean(F1)))

            # Drop the graph of this run and start the next one in a new
            # session with on-demand GPU memory allocation.
            K.clear_session()
            gc.collect()
            config = tf.ConfigProto()
            config.gpu_options.allow_growth=True
            sess = tf.Session(config=config)
            K.set_session(sess)
            counter += 1
    # Saving cross validation results
    scipy.io.savemat('xval_results.mat',mdict={'cvconfusion': cvconfusion.tolist()})
    return model
346
347
###########################
348
## Function to load data ##
349
###########################
350
def loaddata(WINDOW_SIZE, filename='trainingset.mat'):
    '''
        Load training data and truncate each recording to WINDOW_SIZE samples.

        This function assumes you have downloaded and padded/truncated the
        training set into a local .mat file (by default "trainingset.mat").
        This file should contain the following structures:
            - trainset: NxM matrix of N ECG segments with length M
            - traintarget: Nx4 matrix of coded labels where each column contains
            one in case it matches ['A', 'N', 'O', '~'].

        # Arguments
            WINDOW_SIZE: number of leading columns (samples) kept per segment.
            filename: path of the .mat file to read.

        # Returns
            Tuple (X, y): the truncated `trainset` matrix and `traintarget`.

        # Raises
            ValueError: if `trainset` and `traintarget` differ in row count.
    '''
    print("Loading data training set")
    matfile = scipy.io.loadmat(filename)
    X = matfile['trainset']
    y = matfile['traintarget']

    if X.shape[0] != y.shape[0]:
        raise ValueError(
            'trainset has %d rows but traintarget has %d'
            % (X.shape[0], y.shape[0]))

    # If further data sets are available (e.g. augmented data), load them
    # here and concatenate them to X and y along axis 0.

    X = X[:, 0:WINDOW_SIZE]
    return (X, y)
374
375
376
#####################
377
# Main function   ##
378
###################
379
380
def _f1_from_confusion(cm):
    """Per-class F1 (2*TP / (row sum + column sum)) of a confusion matrix.

    Classes with an empty row and column get 0 instead of NaN.
    """
    cm = np.asarray(cm, dtype=float)
    denom = cm.sum(axis=1) + cm.sum(axis=0)
    return np.divide(2 * np.diag(cm), denom, out=np.zeros_like(denom), where=denom > 0)


config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# Register the session with Keras; otherwise the configuration above is
# never used and Keras creates its own default session.
K.set_session(sess)
seed = 7
np.random.seed(seed)

# Parameters
FS = 300
WINDOW_SIZE = 30*FS     # padding window for CNN

# Loading data
(X_train,y_train) = loaddata(WINDOW_SIZE)

# Training model
model = model_eval(X_train,y_train)

# Outputing results of cross validation
matfile = scipy.io.loadmat('xval_results.mat')
cv = matfile['cvconfusion']
classes = ['A', 'N', 'O', '~']
F1mean = np.zeros(cv.shape[2])
for j in range(cv.shape[2]):
    # F1 of every class for the j-th validation run
    F1 = _f1_from_confusion(cv[:, :, j])
    for label, score in zip(classes, F1):
        print("F1 measure for {} rhythm: {:1.4f}".format(label, score))
    F1mean[j] = np.mean(F1)
    print("mean F1 measure for: {:1.4f}".format(F1mean[j]))
print("Overall F1 : {:1.4f}".format(np.mean(F1mean)))
# Plotting confusion matrix
# F1 on the confusion matrix accumulated over all runs
cvsum = np.sum(cv,axis=2)
F1 = _f1_from_confusion(cvsum)
for label, score in zip(classes, F1):
    print("F1 measure for {} rhythm: {:1.4f}".format(label, score))
F1mean = np.mean(F1)
print("mean F1 measure for: {:1.4f}".format(F1mean))
plot_confusion_matrix(cvsum, classes,normalize=True,title='Confusion matrix')
417
418