--- a
+++ b/deeplearn-approach/train_model.py
@@ -0,0 +1,418 @@
+'''
+This script is used for training and cross-validating the model. The database is not
+included in this repo; please download the CinC Challenge database and truncate/pad the
+data into an NxM matrix array, where N is the number of recordings and M is the window
+accepted by the network (i.e. 30 seconds). An illustrative sketch of this preprocessing
+step is included after loaddata() below.
+
+
+For more information visit: https://github.com/fernandoandreotti/cinc-challenge2017
+ 
+ Referencing this work
+   Andreotti, F., Carr, O., Pimentel, M.A.F., Mahdi, A., & De Vos, M. (2017). Comparing Feature Based
+   Classifiers and Convolutional Neural Networks to Detect Arrhythmia from Short Segments of ECG. In
+   Computing in Cardiology. Rennes (France).
+--
+ cinc-challenge2017, version 1.0, Sept 2017
+ Last updated : 27-09-2017
+ Released under the GNU General Public License
+ Copyright (C) 2017  Fernando Andreotti, Oliver Carr, Marco A.F. Pimentel, Adam Mahdi, Maarten De Vos
+ University of Oxford, Department of Engineering Science, Institute of Biomedical Engineering
+ fernando.andreotti@eng.ox.ac.uk
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program.  If not, see <http://www.gnu.org/licenses/>.
+'''
+
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import numpy as np
+import scipy.io
+import gc
+import itertools
+from sklearn.metrics import confusion_matrix
+import sys
+sys.path.insert(0, './preparation')
+
+# Keras imports
+import keras
+from keras.models import Model
+from keras.layers import Input, Conv1D, Dense, Flatten, Dropout, MaxPooling1D, Activation, BatchNormalization
+from keras.callbacks import EarlyStopping, ModelCheckpoint
+from keras.utils import plot_model
+from keras import backend as K
+import warnings
+from keras.callbacks import Callback
+
+###################################################################
+### Callback method for reducing learning rate during training  ###
+###################################################################
+class AdvancedLearningRateScheduler(Callback):
+    '''
+   # Arguments
+       monitor: quantity to be monitored.
+       patience: number of epochs with no improvement
+           after which training will be stopped.
+       verbose: verbosity mode.
+       mode: one of {auto, min, max}. In 'min' mode,
+           training will stop when the quantity
+           monitored has stopped decreasing; in 'max'
+           mode it will stop when the quantity
+           monitored has stopped increasing.
+   '''
+    def __init__(self, monitor='val_loss', patience=0, verbose=0, mode='auto', decayRatio=0.1):
+        super(AdvancedLearningRateScheduler, self).__init__()
+        self.monitor = monitor
+        self.patience = patience
+        self.verbose = verbose
+        self.wait = 0
+        self.decayRatio = decayRatio
+ 
+        if mode not in ['auto', 'min', 'max']:
+            warnings.warn('Mode %s is unknown, '
+                          'fallback to auto mode.'
+                          % (mode), RuntimeWarning)
+            mode = 'auto'
+ 
+        if mode == 'min':
+            self.monitor_op = np.less
+            self.best = np.Inf
+        elif mode == 'max':
+            self.monitor_op = np.greater
+            self.best = -np.Inf
+        else:
+            if 'acc' in self.monitor:
+                self.monitor_op = np.greater
+                self.best = -np.Inf
+            else:
+                self.monitor_op = np.less
+                self.best = np.Inf
+ 
+    def on_epoch_end(self, epoch, logs={}):
+        current = logs.get(self.monitor)
+        current_lr = K.get_value(self.model.optimizer.lr)
+        print("\nLearning rate:", current_lr)
+        if current is None:
+            warnings.warn('AdvancedLearningRateScheduler'
+                          ' requires %s available!' %
+                          (self.monitor), RuntimeWarning)
+            return
+
+        if self.monitor_op(current, self.best):
+            self.best = current
+            self.wait = 0
+        else:
+            if self.wait >= self.patience:
+                if self.verbose > 0:
+                    print('\nEpoch %05d: reducing learning rate' % (epoch))
+                assert hasattr(self.model.optimizer, 'lr'), \
+                    'Optimizer must have a "lr" attribute.'
+                current_lr = K.get_value(self.model.optimizer.lr)
+                new_lr = current_lr * self.decayRatio
+                K.set_value(self.model.optimizer.lr, new_lr)
+                self.wait = 0
+            self.wait += 1
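+
+# Usage sketch (illustration only): the callback plugs into Keras' fit() like
+# any other; Xtr/ytr/Xva/yva stand for any compiled model's training and
+# validation arrays.
+#
+#   lr_cb = AdvancedLearningRateScheduler(monitor='val_loss', patience=1,
+#                                         verbose=1, mode='auto', decayRatio=0.1)
+#   model.fit(Xtr, ytr, validation_data=(Xva, yva), callbacks=[lr_cb])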
+
+
+##########################################
+## Function to plot confusion matrices  ##
+##########################################
+def plot_confusion_matrix(cm, classes,
+                          normalize=False,
+                          title='Confusion matrix',
+                          cmap=plt.cm.Blues):
+    """
+    This function prints and plots the confusion matrix.
+    Normalization can be applied by setting `normalize=True`.
+    """
+    if normalize:
+        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+        print("Normalized confusion matrix")
+    else:
+        print('Confusion matrix, without normalization')
+    cm = np.around(cm, decimals=3)
+    print(cm)
+
+    # Draw the image first, then overlay the cell values on top of it
+    plt.imshow(cm, interpolation='nearest', cmap=cmap)
+    plt.title(title)
+    plt.colorbar()
+    tick_marks = np.arange(len(classes))
+    plt.xticks(tick_marks, classes, rotation=45)
+    plt.yticks(tick_marks, classes)
+
+    thresh = cm.max() / 2.
+    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+        plt.text(j, i, cm[i, j],
+                 horizontalalignment="center",
+                 color="white" if cm[i, j] > thresh else "black")
+
+    plt.ylabel('True label')
+    plt.xlabel('Predicted label')
+    plt.tight_layout()
+    plt.savefig('confusion.eps', format='eps', dpi=1000)
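+
+# Quick usage sketch with made-up numbers (illustration only): a toy 4x4
+# confusion matrix for the challenge classes, shown normalized per true class.
+#
+#   cm_demo = np.array([[50,  2,  3,  0],
+#                       [ 4, 60,  5,  1],
+#                       [ 3,  6, 55,  2],
+#                       [ 0,  1,  2, 20]])
+#   plot_confusion_matrix(cm_demo, ['A', 'N', 'O', '~'], normalize=True)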
+
+
+##################################
+## Model definition             ##
+## ResNet based on Rajpurkar    ##
+##################################
+def ResNet_model(WINDOW_SIZE):
+    # Add CNN layers left branch (higher frequencies)
+    # Parameters from paper
+    INPUT_FEAT = 1
+    OUTPUT_CLASS = 4    # output classes
+
+    k = 1    # increment every 4th residual block
+    p = True # pool toggle every other residual block (end with 2^8)
+    convfilt = 64
+    convstr = 1
+    ksize = 16
+    poolsize = 2
+    poolstr  = 2
+    drop = 0.5
+    
+    # Modelling with Functional API
+    #input1 = Input(shape=(None,1), name='input')
+    input1 = Input(shape=(WINDOW_SIZE,INPUT_FEAT), name='input')
+    
+    ## First convolutional block (conv,BN, relu)
+    x = Conv1D(filters=convfilt,
+               kernel_size=ksize,
+               padding='same',
+               strides=convstr,
+               kernel_initializer='he_normal')(input1)                
+    x = BatchNormalization()(x)        
+    x = Activation('relu')(x)  
+    
+    ## Second convolutional block (conv, BN, relu, dropout, conv) with residual net
+    # Left branch (convolutions)
+    x1 = Conv1D(filters=convfilt,
+               kernel_size=ksize,
+               padding='same',
+               strides=convstr,
+               kernel_initializer='he_normal')(x)      
+    x1 = BatchNormalization()(x1)    
+    x1 = Activation('relu')(x1)
+    x1 = Dropout(drop)(x1)
+    x1 = Conv1D(filters=convfilt,
+               kernel_size=ksize,
+               padding='same',
+               strides=convstr,
+               kernel_initializer='he_normal')(x1)
+    x1 = MaxPooling1D(pool_size=poolsize,
+                      strides=poolstr)(x1)
+    # Right branch, shortcut branch pooling
+    x2 = MaxPooling1D(pool_size=poolsize,
+                      strides=poolstr)(x)
+    # Merge both branches
+    x = keras.layers.add([x1, x2])
+    del x1,x2
+    
+    ## Main loop
+    p = not p 
+    for l in range(15):
+        
+        if (l%4 == 0) and (l>0): # increment k on every fourth residual block
+            k += 1
+            # match shortcut depth with a 1x1 convolution when the filter count changes
+            xshort = Conv1D(filters=convfilt*k, kernel_size=1)(x)
+        else:
+            xshort = x
+        # Left branch (convolutions)
+        # notice the ordering of the operations has changed (pre-activation: BN -> ReLU -> conv)
+        x1 = BatchNormalization()(x)
+        x1 = Activation('relu')(x1)
+        x1 = Dropout(drop)(x1)
+        x1 = Conv1D(filters=convfilt*k,
+               kernel_size=ksize,
+               padding='same',
+               strides=convstr,
+               kernel_initializer='he_normal')(x1)        
+        x1 = BatchNormalization()(x1)
+        x1 = Activation('relu')(x1)
+        x1 = Dropout(drop)(x1)
+        x1 = Conv1D(filters=convfilt*k,
+               kernel_size=ksize,
+               padding='same',
+               strides=convstr,
+               kernel_initializer='he_normal')(x1)        
+        if p:
+            x1 = MaxPooling1D(pool_size=poolsize,strides=poolstr)(x1)                
+
+        # Right branch: shortcut connection
+        if p:
+            x2 = MaxPooling1D(pool_size=poolsize,strides=poolstr)(xshort)
+        else:
+            x2 = xshort  # pool or identity            
+        # Merging branches
+        x = keras.layers.add([x1, x2])
+        # change parameters
+        p = not p # toggle pooling
+
+    
+    # Final bit    
+    x = BatchNormalization()(x)
+    x = Activation('relu')(x) 
+    x = Flatten()(x)
+    #x = Dense(1000)(x)
+    #x = Dense(1000)(x)
+    out = Dense(OUTPUT_CLASS, activation='softmax')(x)
+    model = Model(inputs=input1, outputs=out)
+    model.compile(optimizer='adam',
+                  loss='categorical_crossentropy',
+                  metrics=['accuracy'])
+    #model.summary()
+    #sequential_model_to_ascii_printout(model)
+    plot_model(model, to_file='model.png')
+    return model
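+
+# Minimal instantiation sketch (illustration only): build the network for the
+# 30 s window used in this script and inspect it layer by layer.
+#
+#   model = ResNet_model(30*300)   # 30 s at 300 Hz sampling
+#   model.summary()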
+
+##########################################################
+## Function to perform K-fold cross-validation on model ##
+##########################################################
+def model_eval(X,y):
+    batch = 64
+    epochs = 20
+    rep = 1         # the K-fold procedure can be repeated multiple times
+    Kfold = 5
+    Ntrain = 8528   # number of recordings in the training set
+    Nsamp = int(Ntrain/Kfold) # number of recordings used for validation
+   
+    # Need to add dimension for training
+    X = np.expand_dims(X, axis=2)
+    classes = ['A', 'N', 'O', '~']
+    Nclass = len(classes)
+    cvconfusion = np.zeros((Nclass,Nclass,Kfold*rep))
+    cvscores = []       
+    counter = 0
+    # repetitions of cross validation
+    for r in range(rep):
+        print("Rep %d"%(r+1))
+        # cross validation loop
+        for k in range(Kfold):
+            print("Cross-validation run %d"%(k+1))
+            # Callbacks definition
+            callbacks = [
+                # Early stopping definition
+                EarlyStopping(monitor='val_loss', patience=3, verbose=1),
+                # Decrease learning rate by a factor of 0.1
+                AdvancedLearningRateScheduler(monitor='val_loss', patience=1, verbose=1, mode='auto', decayRatio=0.1),
+                # Saving best model
+                ModelCheckpoint('weights-best_k{}_r{}.hdf5'.format(k,r), monitor='val_loss', save_best_only=True, verbose=1),
+                ]
+            # Load model
+            model = ResNet_model(WINDOW_SIZE)
+            
+            # split train and validation sets
+            idxval = np.random.choice(Ntrain, Nsamp, replace=False)
+            idxtrain = np.invert(np.in1d(range(X.shape[0]), idxval))  # use the local X, not the global X_train
+            ytrain = y[np.asarray(idxtrain),:]
+            Xtrain = X[np.asarray(idxtrain),:,:]
+            Xval = X[np.asarray(idxval),:,:]
+            yval = y[np.asarray(idxval),:]
+            
+            # Train model
+            model.fit(Xtrain, ytrain,
+                      validation_data=(Xval, yval),
+                      epochs=epochs, batch_size=batch, callbacks=callbacks)
+            
+            # Evaluate best trained model
+            model.load_weights('weights-best_k{}_r{}.hdf5'.format(k,r))
+            ypred = model.predict(Xval)
+            ypred = np.argmax(ypred,axis=1)
+            ytrue = np.argmax(yval,axis=1)
+            cvconfusion[:,:,counter] = confusion_matrix(ytrue, ypred, labels=np.arange(4))  # fixed 4x4 shape
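+            # Per-class F1 from the confusion matrix:
+            #   F1_i = 2*TP_i / (row_i sum + column_i sum) = 2*TP / (2*TP + FP + FN)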
+            F1 = np.zeros((4,1))
+            for i in range(4):
+                F1[i]=2*cvconfusion[i,i,counter]/(np.sum(cvconfusion[i,:,counter])+np.sum(cvconfusion[:,i,counter]))
+                print("F1 measure for {} rhythm: {:1.4f}".format(classes[i],F1[i,0]))            
+            cvscores.append(np.mean(F1)* 100)
+            print("Overall F1 measure: {:1.4f}".format(np.mean(F1)))            
+            K.clear_session()
+            gc.collect()
+            config = tf.ConfigProto()
+            config.gpu_options.allow_growth=True            
+            sess = tf.Session(config=config)
+            K.set_session(sess)
+            counter += 1
+    # Saving cross validation results 
+    scipy.io.savemat('xval_results.mat',mdict={'cvconfusion': cvconfusion.tolist()})  
+    return model
+
+###########################
+## Function to load data ##
+###########################
+def loaddata(WINDOW_SIZE):    
+    '''
+        Load training/test data into workspace
+        
+        This function assumes you have downloaded and padded/truncated the 
+        training set into a local file named "trainingset.mat". This file should 
+        contain the following structures:
+            - trainset: NxM matrix of N ECG segments with length M
+            - traintarget: Nx4 matrix of one-hot encoded labels, with columns
+            corresponding to the classes ['A', 'N', 'O', '~'].
+        
+    '''
+    print("Loading data training set")        
+    matfile = scipy.io.loadmat('trainingset.mat')
+    X = matfile['trainset']
+    y = matfile['traintarget']
+    
+    # Merging datasets
+    # In case other sets are available, load them and then concatenate, e.g.:
+    #y = np.concatenate((traintarget,augtarget),axis=0)
+    #X = np.concatenate((trainset,augset),axis=0)
+
+    X = X[:,0:WINDOW_SIZE]
+    return (X, y)
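+
+
+##########################################################
+## Illustrative sketch: building "trainingset.mat"      ##
+##########################################################
+# This helper is a sketch only and is never called by this script. It assumes
+# the raw CinC 2017 recordings store each signal under the key 'val' and that
+# labels come from a REFERENCE.csv with rows like "A00001,N"; adapt paths and
+# names to your local copy of the database.
+def build_trainingset_sketch(datadir='training2017', window_size=30*300):
+    import csv
+    import os
+    classes = ['A', 'N', 'O', '~']
+    with open(os.path.join(datadir, 'REFERENCE.csv')) as f:
+        refs = [row for row in csv.reader(f) if row]
+    trainset = np.zeros((len(refs), window_size))
+    traintarget = np.zeros((len(refs), len(classes)))
+    for n, (record, label) in enumerate(refs):
+        sig = scipy.io.loadmat(os.path.join(datadir, record + '.mat'))['val'].flatten()
+        sig = sig[:window_size]             # truncate long recordings
+        trainset[n, :len(sig)] = sig        # zero-pad short recordings
+        traintarget[n, classes.index(label)] = 1
+    scipy.io.savemat('trainingset.mat',
+                     mdict={'trainset': trainset, 'traintarget': traintarget})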
+
+
+####################
+## Main function  ##
+####################
+
+config = tf.ConfigProto(allow_soft_placement=True)
+config.gpu_options.allow_growth = True
+sess = tf.Session(config=config)
+seed = 7
+np.random.seed(seed)
+
+# Parameters
+FS = 300
+WINDOW_SIZE = 30*FS     # padding window for CNN
+
+# Loading data
+(X_train,y_train) = loaddata(WINDOW_SIZE)
+
+# Training model
+model = model_eval(X_train,y_train)
+
+# Outputting results of cross-validation
+matfile = scipy.io.loadmat('xval_results.mat')
+cv = matfile['cvconfusion']
+F1mean = np.zeros(cv.shape[2])
+for j in range(cv.shape[2]):
+    classes = ['A', 'N', 'O', '~']
+    F1 = np.zeros((4,1))
+    for i in range(4):
+        F1[i]=2*cv[i,i,j]/(np.sum(cv[i,:,j])+np.sum(cv[:,i,j]))        
+        print("F1 measure for {} rhythm: {:1.4f}".format(classes[i],F1[i,0]))
+    F1mean[j] = np.mean(F1)
+    print("mean F1 measure for: {:1.4f}".format(F1mean[j]))
+print("Overall F1 : {:1.4f}".format(np.mean(F1mean)))
+# Plotting confusion matrix
+cvsum = np.sum(cv,axis=2)
+F1 = np.zeros((4,1))
+for i in range(4):
+    F1[i] = 2*cvsum[i,i]/(np.sum(cvsum[i,:])+np.sum(cvsum[:,i]))
+    print("F1 measure for {} rhythm: {:1.4f}".format(classes[i],F1[i,0]))
+F1mean = np.mean(F1)
+print("Mean F1 measure: {:1.4f}".format(F1mean))
+plot_confusion_matrix(cvsum, classes, normalize=True, title='Confusion matrix')
+
+