Diff of /ecg_annotation.py [000000] .. [a2673c]

# import matplotlib with pdf as backend
import matplotlib
matplotlib.use('PDF')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import wfdb
import os
import numpy as np
import math
import sys
import scipy.stats as st
import glob
from os.path import basename


import tensorflow as tf
from keras.layers import Dense, Activation, Dropout
from keras.layers import LSTM, Bidirectional  # could try TimeDistributed(Dense(...))
from keras.models import Sequential, load_model
from keras import optimizers, regularizers
from keras.layers.normalization import BatchNormalization
import keras.backend.tensorflow_backend as KTF

np.random.seed(0)

# functions
def get_ecg_data(datfile):
    ## convert a .dat/.q1c record pair into numpy arrays
    recordname = os.path.basename(datfile).split(".dat")[0]
    recordpath = os.path.dirname(datfile)
    cwd = os.getcwd()
    os.chdir(recordpath)  ## somehow it only works if you chdir into the record directory.

    annotator = 'q1c'
    annotation = wfdb.rdann(recordname, extension=annotator, sampfrom=0, sampto=None, pbdir=None)
    Lstannot = list(zip(annotation.sample, annotation.symbol, annotation.aux_note))

    FirstLstannot = min(i[0] for i in Lstannot)
    LastLstannot = max(i[0] for i in Lstannot) - 1
    print("first-last annotation:", FirstLstannot, LastLstannot)

    record = wfdb.rdsamp(recordname, sampfrom=FirstLstannot, sampto=LastLstannot)  # wfdb.showanncodes()
    annotation = wfdb.rdann(recordname, annotator, sampfrom=FirstLstannot, sampto=LastLstannot)  ## get the annotations between the first and last annotated sample.
    annotation2 = wfdb.Annotation(recordname='sel32', extension='niek', sample=(annotation.sample - FirstLstannot), symbol=annotation.symbol, aux_note=annotation.aux_note)

    Vctrecord = np.transpose(record.p_signals)
    VctAnnotationHot = np.zeros((6, len(Vctrecord[1])), dtype=np.int)
    VctAnnotationHot[5] = 1  ## class 5 = background, the inverse of the other five classes
    # print("ecg, 2 lead of shape", Vctrecord.shape)
    # print("VctAnnotationHot of shape", VctAnnotationHot.shape)
    # print('plotting extracted signal with annotation')
    # wfdb.plotrec(record, annotation=annotation2, title='Record 100 from MIT-BIH Arrhythmia Database', timeunits='seconds')

    VctAnnotations = list(zip(annotation2.sample, annotation2.symbol))  ## zip sample coordinates + annotation symbols, e.g. (, p, ), N, t
    # print(VctAnnotations)
    for i in range(len(VctAnnotations)):
        # print(VctAnnotations[i])  # print to display the annotations of an ecg
        try:
            if VctAnnotations[i][1] == "p":
                if VctAnnotations[i - 1][1] == "(":
                    pstart = VctAnnotations[i - 1][0]
                if VctAnnotations[i + 1][1] == ")":
                    pend = VctAnnotations[i + 1][0]
                if VctAnnotations[i + 3][1] == "N":
                    rpos = VctAnnotations[i + 3][0]
                    if VctAnnotations[i + 2][1] == "(":
                        qpos = VctAnnotations[i + 2][0]
                    if VctAnnotations[i + 4][1] == ")":
                        spos = VctAnnotations[i + 4][0]
                    for ii in range(0, 8):  ## search for t (sometimes the "(" for the t is missing)
                        if VctAnnotations[i + ii][1] == "t":
                            tpos = VctAnnotations[i + ii][0]
                            if VctAnnotations[i + ii + 1][1] == ")":
                                tendpos = VctAnnotations[i + ii + 1][0]
                                # print(pstart, qpos, rpos, spos, tendpos)
                                VctAnnotationHot[0][pstart:pend] = 1  # P segment
                                VctAnnotationHot[1][pend:qpos] = 1  # the "nothing" between P and Q; previously left unannotated, but categorical crossentropy probably can't deal with that
                                VctAnnotationHot[2][qpos:rpos] = 1  # QR
                                VctAnnotationHot[3][rpos:spos] = 1  # RS
                                VctAnnotationHot[4][spos:tendpos] = 1  # ST (from end of S to end of T)
                                VctAnnotationHot[5][pstart:tendpos] = 0  # background stays 1 only outside pstart:tendpos, since it was initialised to 1 above
        except IndexError:
            pass

    Vctrecord = np.transpose(Vctrecord)  # transpose to (timesteps, features)
    VctAnnotationHot = np.transpose(VctAnnotationHot)
    os.chdir(cwd)
    return Vctrecord, VctAnnotationHot
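
# Illustrative only: a minimal sketch of calling get_ecg_data on a single record (the path is an
# assumption; point it at wherever the QTDB .dat/.q1c files live):
# x, y = get_ecg_data("/path/to/qtdb/sel49.dat")
# print(x.shape, y.shape)  # expected: (timesteps, 2) signal and (timesteps, 6) one-hot labels
# # label columns: 0=P, 1=gap between P and Q, 2=QR, 3=RS, 4=ST, 5=background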


def splitseq(x, n, o):
    # split the sequence into chunks of n with overlap o; should be optimized so that remove_seq_gaps is not needed.
    upper = math.ceil(x.shape[0] / n) * n
    print("splitting on", n, "with overlap of", o, "total datapoints:", x.shape[0], "; upper:", upper)
    for i in range(0, upper, n):
        # print(i)
        if i == 0:
            padded = np.zeros((o + n + o, x.shape[1]))  ## pad with 0's at the start of the sequence
            padded[o:, :x.shape[1]] = x[i:i + n + o, :]
            xpart = padded
        else:
            xpart = x[i - o:i + n + o, :]
        if xpart.shape[0] < o + n + o:
            padded = np.zeros((o + n + o, xpart.shape[1]))  ## pad with 0's if the chunk at the end of the sequence is too short
            padded[:xpart.shape[0], :xpart.shape[1]] = xpart
            xpart = padded

        xpart = np.expand_dims(xpart, 0)  ## add one dimension, so that you get shape (samples, timesteps, features)
        try:
            xx = np.vstack((xx, xpart))
        except UnboundLocalError:  ## on init
            xx = xpart
    print("output: ", xx.shape)
    return xx
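
# Illustrative only (shapes worked out from the logic above, not from an actual run): with
# n=1000 and o=150 a (2500, 2) array is cut at i=0, 1000, 2000 and each chunk is padded to
# length o+n+o=1300, so the output stacks to (3, 1300, 2):
# xx_demo = splitseq(np.random.randn(2500, 2), 1000, 150)
# print(xx_demo.shape)  # expected: (3, 1300, 2)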


def remove_seq_gaps(x, y):
    # remove parts that are not annotated <- not ideal, but quickest for now.
    window = 150
    c = 0
    include = []
    print("filtering.")
    print("before shape x,y", x.shape, y.shape)
    for i in range(y.shape[0]):
        c = c + 1
        if c < window:
            include.append(i)
        if sum(y[i, 0:5]) > 0:  ## annotated sample: reset the counter
            c = 0
    x, y = x[include, :], y[include, :]
    print(" after shape x,y", x.shape, y.shape)
    return x, y
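
# Illustrative only (toy arrays, not real data): with window=150, samples that lie more than
# ~150 steps after the most recent annotated sample are dropped, cutting out long unannotated
# stretches between labelled beats:
# y_toy = np.zeros((1000, 6)); y_toy[:, 5] = 1   # all background...
# y_toy[100, 0], y_toy[100, 5] = 1, 0            # ...except one annotated sample at index 100
# x_toy = np.random.randn(1000, 2)
# x_cut, y_cut = remove_seq_gaps(x_toy, y_toy)
# print(x_cut.shape)  # expected: (250, 2) -- indices 0..100 plus the ~149 samples after index 100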


def normalizesignal(x):
    x = st.zscore(x, ddof=0)
    return x


def normalizesignal_array(x):
    for i in range(x.shape[0]):
        x[i] = st.zscore(x[i], axis=0, ddof=0)
    return x


def plotecg(x, y, begin, end):
    # helper to plot an ecg: lead 0 with all six label traces on top, lead 1 below
    plt.figure(1, figsize=(11.69, 8.27))
    plt.subplot(211)
    plt.plot(x[begin:end, 0])
    for lab in range(6):
        plt.plot(y[begin:end, lab])

    plt.subplot(212)
    plt.plot(x[begin:end, 1])
    plt.show()


def plotecg_validation(x, y_true, y_pred, begin, end):
    # helper to plot a validation ecg: lead 0 with the predicted labels on top, lead 1 with the true labels below
    plt.figure(1, figsize=(11.69, 8.27))
    plt.subplot(211)
    plt.plot(x[begin:end, 0])
    for lab in range(6):
        plt.plot(y_pred[begin:end, lab])

    plt.subplot(212)
    plt.plot(x[begin:end, 1])
    for lab in range(6):
        plt.plot(y_true[begin:end, lab])



def LoadDatFiles(datfiles):
    for datfile in datfiles:
        print(datfile)
        if basename(datfile).split(".", 1)[0] in exclude:
            continue
        qf = os.path.splitext(datfile)[0] + '.q1c'
        if os.path.isfile(qf):
            # print("yes", qf, datfile)
            x, y = get_ecg_data(datfile)
            x, y = remove_seq_gaps(x, y)

            x, y = splitseq(x, 1000, 150), splitseq(y, 1000, 150)  ## create equal sized numpy arrays of size n with overlap o

            x = normalizesignal_array(x)
            ## todo: add noise, shuffle leads etc.?
            try:  ## concat
                xx = np.vstack((xx, x))
                yy = np.vstack((yy, y))
            except NameError:  ## if xx does not exist yet (on init)
                xx = x
                yy = y
    return xx, yy


def unison_shuffled_copies(a, b):
    # shuffle two arrays with the same permutation, so samples and labels stay aligned
    assert len(a) == len(b)
    p = np.random.permutation(len(a))
    return a[p], b[p]


def get_session(gpu_fraction=0.8):
    # allocate only a fraction of the gpu memory for this session.
    num_threads = os.environ.get('OMP_NUM_THREADS')
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
    if num_threads:
        return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, intra_op_parallelism_threads=int(num_threads)))
    else:
        return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))


def getmodel():
    model = Sequential()
    model.add(Dense(32, W_regularizer=regularizers.l2(l=0.01), input_shape=(seqlength, features)))
    model.add(Bidirectional(LSTM(32, return_sequences=True)))  ### bidirectional ---><---
    model.add(Dropout(0.2))
    model.add(BatchNormalization())
    model.add(Dense(64, activation='relu', W_regularizer=regularizers.l2(l=0.01)))
    model.add(Dropout(0.2))
    model.add(BatchNormalization())
    model.add(Dense(dimout, activation='softmax'))
    adam = optimizers.adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
    model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])
    print(model.summary())
    return model
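
# Illustrative only (shapes inferred from the layer definitions, not from a run): the Dense layers
# are applied per timestep, so with seqlength=1300, features=2 and dimout=6 the tensors flow as
# roughly (None, 1300, 2) -> Dense(32) -> (None, 1300, 32) -> Bidirectional(LSTM(32)) ->
# (None, 1300, 64) -> Dense(64) -> Dense(6, softmax) -> (None, 1300, 6), i.e. one class
# distribution per timestep, matching the (samples, timesteps, 6) labels built above.
# m = getmodel()
# print(m.output_shape)  # expected: (None, seqlength, 6)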
256
257
##################################################################
258
##################################################################
259
qtdbpath=sys.argv[1] ## first argument = qtdb database from physionet. 
260
perct=0.81 #percentage training
261
percv=0.19 #percentage validation
262
263
exclude = set()
264
exclude.update(["sel35","sel36","sel37","sel50","sel102","sel104","sel221","sel232", "sel310"])# no P annotated:
265
##################################################################
266
# datfile=qtdbpath+"sel49.dat"  ## single ECG to test if loading works.  
267
# x,y=get_ecg_data(datfile)
268
# print(x.shape,y.shape)
269
# # for i in range(y.shape[0]): #Invert QT-label to actually represent QT. Does give overlapping labels
270
# #     y[i][0] = 1 - y[i][0]
271
# plotecg(x,y,0,y.shape[0]) ## plot all
272
# x,y=remove_seq_gaps(x,y) ## remove 'annotation gaps'
273
# plotecg(x,y,0,y.shape[0]) ## plot all
274
# x,y=splitseq(x,750,150),splitseq(y,750,150) ## create equal sized numpy arrays of n size and overlap of o 
275
# exit()
276
##################################################################

# load data
datfiles = glob.glob(qtdbpath + "*.dat")
xxt, yyt = LoadDatFiles(datfiles[:round(len(datfiles) * perct)])  # training data
xxt, yyt = unison_shuffled_copies(xxt, yyt)  ### shuffle
xxv, yyv = LoadDatFiles(datfiles[-round(len(datfiles) * percv):])  ## validation data
seqlength = xxt.shape[1]
features = xxt.shape[2]
dimout = yyt.shape[2]
print("xxv/validation shape: {}, Seqlength: {}, Features: {}".format(xxv.shape[0], seqlength, features))
# # plot validation ecgs
# with PdfPages('ecgs_xxv.pdf') as pdf:
#     for i in range(xxv.shape[0]):
#         print(i)
#         plotecg(xxv[i,:,:], yyv[i,:,:], 0, yyv.shape[1])
#         pdf.savefig()
#         plt.close()

# call keras/tensorflow and build the lstm model
KTF.set_session(get_session())
with tf.device('/cpu:0'):  ## currently pinned to the cpu; point this at e.g. '/gpu:0' to use a gpu
    if not os.path.isfile('model.h5'):
        model = getmodel()  # build the model
        model.fit(xxt, yyt, batch_size=32, epochs=100, verbose=1)  # train the model
        model.save('model.h5')

    model = load_model('model.h5')
    score, acc = model.evaluate(xxv, yyv, batch_size=4, verbose=1)
    print('Test score: {} , Test accuracy: {}'.format(score, acc))

    # predict
    yy_predicted = model.predict(xxv)

    # turn the per-timestep probabilities into hard one-hot labels (argmax per timestep).
    for i in range(yyv.shape[0]):
        b = np.zeros_like(yy_predicted[i, :, :])
        b[np.arange(len(yy_predicted[i, :, :])), yy_predicted[i, :, :].argmax(1)] = 1
        yy_predicted[i, :, :] = b

    # plot:
    with PdfPages('ecg.pdf') as pdf:
        for i in range(xxv.shape[0]):
            print(i)
            plotecg_validation(xxv[i, :, :], yyv[i, :, :], yy_predicted[i, :, :], 0, yy_predicted.shape[1])  # top = predicted, bottom = true
            pdf.savefig()
            plt.close()

    # plotecg(xxv[1,:,:], yyv[1,:,:], 0, yyv.shape[1]) ## plot first validation seq