a b/deeplearn-approach/predict.py
1
'''
2
This script downloads one random recording from the PhysioNet/CinC Challenge 2017 training set and uses a pre-trained Residual Network (ResNet) model to predict its class.
3
4
For more information visit: https://github.com/fernandoandreotti/cinc-challenge2017
5
 
6
 Referencing this work
7
   Andreotti, F., Carr, O., Pimentel, M.A.F., Mahdi, A., & De Vos, M. (2017). Comparing Feature Based 
8
   Classifiers and Convolutional Neural Networks to Detect Arrhythmia from Short Segments of ECG. In 
9
   Computing in Cardiology. Rennes (France).
10
11
--
12
 cinc-challenge2017, version 1.0, Sept 2017
13
 Last updated : 27-09-2017
14
 Released under the GNU General Public License
15
16
 Copyright (C) 2017  Fernando Andreotti, Oliver Carr, Marco A.F. Pimentel, Adam Mahdi, Maarten De Vos
17
 University of Oxford, Department of Engineering Science, Institute of Biomedical Engineering
18
 fernando.andreotti@eng.ox.ac.uk
19
   
20
 This program is free software: you can redistribute it and/or modify
21
 it under the terms of the GNU General Public License as published by
22
 the Free Software Foundation, either version 3 of the License, or
23
 (at your option) any later version.
24
 
25
 This program is distributed in the hope that it will be useful,
26
 but WITHOUT ANY WARRANTY; without even the implied warranty of
27
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
28
 GNU General Public License for more details.
29
 
30
 You should have received a copy of the GNU General Public License
31
 along with this program.  If not, see <http://www.gnu.org/licenses/>.
32
'''
33
34
35
# Download some random waveform from challenge database
from random import randint
import urllib.request

# Training records are named A00001, A00002, ... -- there is no A00000, so the
# index is drawn from 1..999 (the records kept in the A00 folder).
record = "A{:05d}".format(randint(1, 999))
# NOTE(review): this is the legacy PhysioBank URL layout; if it stops resolving,
# the same file should be under
# https://physionet.org/files/challenge-2017/1.0.0/training/A00/ -- confirm.
urlfile = "https://www.physionet.org/physiobank/database/challenge/2017/training/A00/{}.mat".format(record)
print('Downloading record {} ..'.format(record))
# Without a target filename, urlretrieve copies the file to a temporary location
# and returns its path; the .mat file is read later by scipy, so no file handle
# needs to be opened here.
local_filename, headers = urllib.request.urlretrieve(urlfile)
# Load data
# Read the downloaded MATLAB file and take the matrix stored under 'val'
# (the samples used by the preprocessing step below).
import scipy.io

mat_data = scipy.io.loadmat(local_filename)  # dict: MATLAB variable name -> array
data = mat_data['val']
# Parameters
FS = 300  # sampling rate of the recordings, samples per second (assumed 300 Hz)
maxlen = FS * 30  # network input length: 30 seconds' worth of samples
# Output labels, in the order used to index the model's prediction vector.
# NOTE(review): presumably AF / Normal / Other / Noisy per the CinC 2017 labels;
# the order must match the model's output units -- confirm.
classes = list("ANO~")
# Preprocessing data
# Crop the recording to at most `maxlen` samples, remove its mean, scale it to
# unit variance, zero-pad it to exactly `maxlen` samples and add the batch and
# channel axes expected by the network.
print("Preprocessing recording ..")
import numpy as np

data = np.nan_to_num(data)  # removing NaNs and Infs
data = data[0, 0:maxlen]  # first row of the loaded matrix, cropped to maxlen samples
data = data - np.mean(data)
data_std = np.std(data)
if data_std > 0:
    # A flat (constant) recording has zero standard deviation; dividing by it
    # would turn every sample into NaN and make the prediction meaningless, so
    # such a signal is left as all zeros instead.
    data = data / data_std
X = np.zeros((1, maxlen))
X[0, :len(data)] = data  # zero-pad recordings shorter than maxlen
data = np.expand_dims(X, axis=2)  # (1, maxlen) -> (1, maxlen, 1), as required by Keras
del X
# Load and apply model
# The pre-trained network is read from disk and applied to the single
# preprocessed recording; the highest-scoring class is reported.
print("Loading model")
from keras.models import load_model
model = load_model('ResNet_30s_34lay_16conv.hdf5')  # path is relative to the working directory

print("Applying model ..")
prob = model.predict(data)  # class scores; prob[0, i] is read below as the score of classes[i]
# argmax over the flattened array: with a batch of a single row this is the
# column index of the most likely class.
ann = np.argmax(prob)
print(
    "Record {} classified as {} with {:3.1f}% certainty".format(
        record,
        classes[ann],
        100 * prob[0, ann],
    )
)
# Visualising output of first 16 convolutions for some layers
from keras import backend as K
import matplotlib.pyplot as plt

Np = 1000  # number of input samples shown in the plots

plt.plot(data[0, 0:Np, 0])
plt.title('Input signal')
#plt.savefig('layinput.eps', format='eps', dpi=1000) # saving?

# The name -> layer lookup does not change between iterations, so build it once.
layer_dict = {layer.name: layer for layer in model.layers}

for l in range(1, 34):
    # Plot the first 16 filter outputs of layer 'conv1d_<l>'.
    # NOTE(review): assumes the model's Conv1D layers kept Keras' default names
    # conv1d_1 ... conv1d_33 and that each has at least 16 filters.
    layer_name = 'conv1d_{}'.format(l)
    layer_output = layer_dict[layer_name].output

    # K.learning_phase() is a flag that indicates whether the network is in the
    # training or the predict phase; it allows layers (e.g. Dropout) to be
    # applied only during training. Passing 0 below selects the predict phase.
    get_layer_output = K.function([model.layers[0].input, K.learning_phase()],
                                  [layer_output])
    filtout = get_layer_output([data, 0])[0]
    # The layer output can be shorter than the input; scale the window so it
    # covers the same stretch of signal as the Np input samples plotted above.
    Npnew = int(Np * filtout.shape[1] / data.shape[1])

    fig, ax = plt.subplots(nrows=4, ncols=4, sharex='col', sharey='row')
    # ax.flat walks the 4x4 grid row by row, so panel k shows filter k.
    for count, col in enumerate(ax.flat):
        col.plot(range(Npnew), filtout[0, 0:Npnew, count], linewidth=1.0, color='olive')
    plt.suptitle('Layer {}'.format(l))
    #plt.savefig('layoutput{}.eps'.format(l), format='eps', dpi=1000) # saving?

# Without this call a non-interactive run (python predict.py) exits without
# ever displaying the figures created above.
plt.show()