# deeplearn-approach/train_model.py
'''
This script is used for training and cross-validating the model. The database is not
included in this repo; please download the CinC Challenge database and truncate/pad the data
into an NxM matrix, N being the number of recordings and M the window length accepted by the
network (i.e. 30 seconds).

For more information visit: https://github.com/fernandoandreotti/cinc-challenge2017

Referencing this work
Andreotti, F., Carr, O., Pimentel, M.A.F., Mahdi, A., & De Vos, M. (2017). Comparing Feature-Based
Classifiers and Convolutional Neural Networks to Detect Arrhythmia from Short Segments of ECG. In
Computing in Cardiology. Rennes (France).
--
cinc-challenge2017, version 1.0, Sept 2017
Last updated : 27-09-2017
Released under the GNU General Public License
Copyright (C) 2017 Fernando Andreotti, Oliver Carr, Marco A.F. Pimentel, Adam Mahdi, Maarten De Vos
University of Oxford, Department of Engineering Science, Institute of Biomedical Engineering
fernando.andreotti@eng.ox.ac.uk

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
'''
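
# A minimal sketch of the preparation step described above, assuming the raw
# challenge recordings are available as a list of 1-D numpy arrays `recordings`
# sampled at 300 Hz together with a one-hot label matrix `traintarget`
# (these names are illustrative, not part of this repo):
#
#   WINDOW = 30*300                              # 30 s at 300 Hz
#   trainset = np.zeros((len(recordings), WINDOW))
#   for n, rec in enumerate(recordings):
#       L = min(len(rec), WINDOW)                # truncate long recordings...
#       trainset[n, :L] = rec[:L]                # ...and zero-pad short ones
#   scipy.io.savemat('trainingset.mat', mdict={'trainset': trainset,
#                                              'traintarget': traintarget})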
|
|
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import scipy.io
import gc
import itertools
import warnings
from sklearn.metrics import confusion_matrix
import sys
sys.path.insert(0, './preparation')

# Keras imports
import keras
from keras.models import Model
from keras.layers import Input, Conv1D, Dense, Flatten, Dropout, MaxPooling1D, Activation, BatchNormalization
from keras.callbacks import EarlyStopping, ModelCheckpoint, Callback
from keras.utils import plot_model
from keras import backend as K
|
|
###################################################################
### Callback method for reducing learning rate during training ###
###################################################################
class AdvancedLearningRateScheduler(Callback):
    '''
    # Arguments
        monitor: quantity to be monitored.
        patience: number of epochs with no improvement
            after which the learning rate will be reduced.
        verbose: verbosity mode.
        mode: one of {auto, min, max}. In 'min' mode,
            the learning rate is reduced when the quantity
            monitored has stopped decreasing; in 'max'
            mode it is reduced when the quantity
            monitored has stopped increasing.
        decayRatio: factor by which the learning rate is multiplied.
    '''
    def __init__(self, monitor='val_loss', patience=0, verbose=0, mode='auto', decayRatio=0.1):
        super(AdvancedLearningRateScheduler, self).__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.wait = 0
        self.decayRatio = decayRatio

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('Mode %s is unknown, '
                          'fallback to auto mode.'
                          % (mode), RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        current_lr = K.get_value(self.model.optimizer.lr)
        print("\nLearning rate:", current_lr)
        if current is None:
            warnings.warn('AdvancedLearningRateScheduler'
                          ' requires %s available!' %
                          (self.monitor), RuntimeWarning)
            return

        if self.monitor_op(current, self.best):
            self.best = current
            self.wait = 0
        else:
            if self.wait >= self.patience:
                if self.verbose > 0:
                    print('\nEpoch %05d: reducing learning rate' % (epoch))
                assert hasattr(self.model.optimizer, 'lr'), \
                    'Optimizer must have a "lr" attribute.'
                current_lr = K.get_value(self.model.optimizer.lr)
                new_lr = current_lr * self.decayRatio
                K.set_value(self.model.optimizer.lr, new_lr)
                self.wait = 0
            self.wait += 1
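
# Typical usage, as in the callbacks list of model_eval() below:
#   AdvancedLearningRateScheduler(monitor='val_loss', patience=1, verbose=1,
#                                 mode='auto', decayRatio=0.1)
# i.e. multiply the learning rate by decayRatio once val_loss has failed to
# improve for more than `patience` consecutive epochs.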
|
|
###########################################
##  Function to plot confusion matrices  ##
###########################################
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    cm = np.around(cm, decimals=3)
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig('confusion.eps', format='eps', dpi=1000)
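
# Usage sketch (with integer class predictions ytrue/ypred, as produced in
# model_eval() below):
#   cm = confusion_matrix(ytrue, ypred, labels=range(4))
#   plot_confusion_matrix(cm, ['A', 'N', 'O', '~'], normalize=True)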
|
|
#####################################
##  Model definition               ##
##  ResNet based on Rajpurkar      ##
#####################################
def ResNet_model(WINDOW_SIZE):
    # Add CNN layers left branch (higher frequencies)
    # Parameters from paper
    INPUT_FEAT = 1
    OUTPUT_CLASS = 4    # output classes

    k = 1    # incremented on every 4th residual block
    p = True # pool toggled on every other residual block (end with 2^8 subsampling)
    convfilt = 64
    convstr = 1
    ksize = 16
    poolsize = 2
    poolstr = 2
    drop = 0.5

    # Modelling with Functional API
    #input1 = Input(shape=(None,1), name='input')
    input1 = Input(shape=(WINDOW_SIZE, INPUT_FEAT), name='input')

    ## First convolutional block (conv, BN, relu)
    x = Conv1D(filters=convfilt,
               kernel_size=ksize,
               padding='same',
               strides=convstr,
               kernel_initializer='he_normal')(input1)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    ## Second convolutional block (conv, BN, relu, dropout, conv) with residual shortcut
    # Left branch (convolutions)
    x1 = Conv1D(filters=convfilt,
                kernel_size=ksize,
                padding='same',
                strides=convstr,
                kernel_initializer='he_normal')(x)
    x1 = BatchNormalization()(x1)
    x1 = Activation('relu')(x1)
    x1 = Dropout(drop)(x1)
    x1 = Conv1D(filters=convfilt,
                kernel_size=ksize,
                padding='same',
                strides=convstr,
                kernel_initializer='he_normal')(x1)
    x1 = MaxPooling1D(pool_size=poolsize,
                      strides=poolstr)(x1)
    # Right branch: shortcut branch with pooling
    x2 = MaxPooling1D(pool_size=poolsize,
                      strides=poolstr)(x)
    # Merge both branches
    x = keras.layers.add([x1, x2])
    del x1, x2

    ## Main loop
    p = not p
    for l in range(15):

        if (l % 4 == 0) and (l > 0):  # increment k on every fourth residual block
            k += 1
            # increase depth with a 1x1 convolution, since the number of filters changes
            xshort = Conv1D(filters=convfilt*k, kernel_size=1)(x)
        else:
            xshort = x
        # Left branch (convolutions)
        # notice that the ordering of the operations has changed
        x1 = BatchNormalization()(x)
        x1 = Activation('relu')(x1)
        x1 = Dropout(drop)(x1)
        x1 = Conv1D(filters=convfilt*k,
                    kernel_size=ksize,
                    padding='same',
                    strides=convstr,
                    kernel_initializer='he_normal')(x1)
        x1 = BatchNormalization()(x1)
        x1 = Activation('relu')(x1)
        x1 = Dropout(drop)(x1)
        x1 = Conv1D(filters=convfilt*k,
                    kernel_size=ksize,
                    padding='same',
                    strides=convstr,
                    kernel_initializer='he_normal')(x1)
        if p:
            x1 = MaxPooling1D(pool_size=poolsize, strides=poolstr)(x1)

        # Right branch: shortcut connection (pool or identity)
        if p:
            x2 = MaxPooling1D(pool_size=poolsize, strides=poolstr)(xshort)
        else:
            x2 = xshort
        # Merge branches
        x = keras.layers.add([x1, x2])
        # toggle pooling for the next block
        p = not p

    # Final bit
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Flatten()(x)
    #x = Dense(1000)(x)
    #x = Dense(1000)(x)
    out = Dense(OUTPUT_CLASS, activation='softmax')(x)
    model = Model(inputs=input1, outputs=out)
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    #model.summary()
    #sequential_model_to_ascii_printout(model)
    plot_model(model, to_file='model.png')
    return model
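
# Subsampling sanity check (follows from the code above): the first residual
# block always pools, and the 15-block loop pools on alternate iterations
# (7 more times), i.e. 8 stride-2 poolings in total -- hence the "2^8" note.
# With WINDOW_SIZE = 30*300 = 9000 samples, the feature map entering Flatten()
# is therefore about 9000 // 2**8 = 35 time steps long.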
|
|
############################################################
##  Function to perform K-fold cross-validation on model  ##
############################################################
def model_eval(X, y):
    batch = 64
    epochs = 20
    rep = 1    # K-fold procedure can be repeated multiple times
    Kfold = 5
    Ntrain = 8528    # number of recordings on training set
    Nsamp = int(Ntrain/Kfold)  # number of recordings to take as validation

    # Need to add dimension for training
    X = np.expand_dims(X, axis=2)
    classes = ['A', 'N', 'O', '~']
    Nclass = len(classes)
    cvconfusion = np.zeros((Nclass, Nclass, Kfold*rep))
    cvscores = []
    counter = 0
    # repetitions of cross-validation
    for r in range(rep):
        print("Rep %d" % (r+1))
        # cross-validation loop
        for k in range(Kfold):
            print("Cross-validation run %d" % (k+1))
            # Callbacks definition
            callbacks = [
                # Early stopping definition
                EarlyStopping(monitor='val_loss', patience=3, verbose=1),
                # Decrease learning rate by 0.1 factor
                AdvancedLearningRateScheduler(monitor='val_loss', patience=1, verbose=1, mode='auto', decayRatio=0.1),
                # Saving best model
                ModelCheckpoint('weights-best_k{}_r{}.hdf5'.format(k, r), monitor='val_loss', save_best_only=True, verbose=1),
                ]
            # Load model (WINDOW_SIZE is defined globally in the main section below)
            model = ResNet_model(WINDOW_SIZE)

            # split train and validation sets
            idxval = np.random.choice(Ntrain, Nsamp, replace=False)
            idxtrain = np.invert(np.in1d(range(X.shape[0]), idxval))
            ytrain = y[np.asarray(idxtrain), :]
            Xtrain = X[np.asarray(idxtrain), :, :]
            Xval = X[np.asarray(idxval), :, :]
            yval = y[np.asarray(idxval), :]

            # Train model
            model.fit(Xtrain, ytrain,
                      validation_data=(Xval, yval),
                      epochs=epochs, batch_size=batch, callbacks=callbacks)

            # Evaluate best trained model
            model.load_weights('weights-best_k{}_r{}.hdf5'.format(k, r))
            ypred = model.predict(Xval)
            ypred = np.argmax(ypred, axis=1)
            ytrue = np.argmax(yval, axis=1)
            # fix the label set so the confusion matrix is always 4x4,
            # even if a class is absent from this validation fold
            cvconfusion[:, :, counter] = confusion_matrix(ytrue, ypred, labels=range(Nclass))
            F1 = np.zeros((4, 1))
            for i in range(4):
                F1[i] = 2*cvconfusion[i, i, counter]/(np.sum(cvconfusion[i, :, counter])+np.sum(cvconfusion[:, i, counter]))
                print("F1 measure for {} rhythm: {:1.4f}".format(classes[i], F1[i, 0]))
            cvscores.append(np.mean(F1) * 100)
            print("Overall F1 measure: {:1.4f}".format(np.mean(F1)))
            K.clear_session()
            gc.collect()
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)
            K.set_session(sess)
            counter += 1
    # Saving cross-validation results
    scipy.io.savemat('xval_results.mat', mdict={'cvconfusion': cvconfusion.tolist()})
    return model
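
# The per-class score above is the CinC challenge F1: for class i,
#   F1_i = 2*TP_i / (row_sum_i + col_sum_i)
# where row_sum_i is the number of true class-i recordings and col_sum_i the
# number of class-i predictions; this is algebraically the same as
# 2*precision*recall / (precision + recall).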
|
|
###########################
## Function to load data ##
###########################
def loaddata(WINDOW_SIZE):
    '''
    Load training/test data into workspace

    This function assumes you have downloaded and padded/truncated the
    training set into a local file named "trainingset.mat". This file should
    contain the following structures:
        - trainset: NxM matrix of N ECG segments with length M
        - traintarget: Nx4 matrix of one-hot coded labels, with a one in the
          column corresponding to the class in ['A', 'N', 'O', '~']
    '''
    print("Loading training set data")
    matfile = scipy.io.loadmat('trainingset.mat')
    X = matfile['trainset']
    y = matfile['traintarget']

    # Merging datasets
    # In case other sets are available, load them and then concatenate
    #y = np.concatenate((traintarget,augtarget),axis=0)
    #X = np.concatenate((trainset,augset),axis=0)

    X = X[:, 0:WINDOW_SIZE]
    return (X, y)
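
# Example of the one-hot convention assumed above: a recording labelled 'N'
# (second entry of ['A', 'N', 'O', '~']) is coded as the row [0, 1, 0, 0].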
|
|
#####################
##  Main function  ##
#####################

config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
seed = 7
np.random.seed(seed)

# Parameters
FS = 300
WINDOW_SIZE = 30*FS  # padding window for CNN

# Loading data
(X_train, y_train) = loaddata(WINDOW_SIZE)

# Training model
model = model_eval(X_train, y_train)

# Outputting results of cross-validation
matfile = scipy.io.loadmat('xval_results.mat')
cv = matfile['cvconfusion']
F1mean = np.zeros(cv.shape[2])
for j in range(cv.shape[2]):
    classes = ['A', 'N', 'O', '~']
    F1 = np.zeros((4, 1))
    for i in range(4):
        F1[i] = 2*cv[i, i, j]/(np.sum(cv[i, :, j])+np.sum(cv[:, i, j]))
        print("F1 measure for {} rhythm: {:1.4f}".format(classes[i], F1[i, 0]))
    F1mean[j] = np.mean(F1)
    print("Mean F1 measure for fold {}: {:1.4f}".format(j+1, F1mean[j]))
print("Overall F1: {:1.4f}".format(np.mean(F1mean)))
# Plotting confusion matrix (summed over all folds)
cvsum = np.sum(cv, axis=2)
for i in range(4):
    F1[i] = 2*cvsum[i, i]/(np.sum(cvsum[i, :])+np.sum(cvsum[:, i]))
    print("F1 measure for {} rhythm: {:1.4f}".format(classes[i], F1[i, 0]))
F1mean = np.mean(F1)
print("Mean F1 measure: {:1.4f}".format(F1mean))
plot_confusion_matrix(cvsum, classes, normalize=True, title='Confusion matrix')
|
|