--- a
+++ b/ecg_annotation.py
@@ -0,0 +1,324 @@
+# Import matplotlib with the non-interactive PDF backend (must be set before pyplot).
+import matplotlib
+matplotlib.use('PDF')
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_pdf import PdfPages
+
+import glob
+import math
+import os
+import sys
+from os.path import basename
+
+import numpy as np
+import scipy.stats as st
+import wfdb
+
+import tensorflow as tf
+from keras.layers import Dense, Activation, Dropout
+from keras.layers import LSTM, Bidirectional  # could also try TimeDistributed(Dense(...))
+from keras.models import Sequential, load_model
+from keras import optimizers, regularizers
+from keras.layers.normalization import BatchNormalization
+import keras.backend.tensorflow_backend as KTF
+
+np.random.seed(0)
+
+
+# functions
+def get_ecg_data(datfile):
+    ## Convert a .dat/.q1c record pair into numpy arrays (signal + one-hot labels).
+    recordname = os.path.basename(datfile).split(".dat")[0]
+    recordpath = os.path.dirname(datfile)
+    cwd = os.getcwd()
+    os.chdir(recordpath)  ## wfdb reads relative to the record directory, so chdir first.
+
+    annotator = 'q1c'
+    annotation = wfdb.rdann(recordname, extension=annotator, sampfrom=0, sampto=None, pbdir=None)
+    Lstannot = list(zip(annotation.sample, annotation.symbol, annotation.aux_note))
+
+    FirstLstannot = min(i[0] for i in Lstannot)
+    LastLstannot = max(i[0] for i in Lstannot) - 1
+    print("first-last annotation:", FirstLstannot, LastLstannot)
+
+    record = wfdb.rdsamp(recordname, sampfrom=FirstLstannot, sampto=LastLstannot)  # wfdb.showanncodes()
+    annotation = wfdb.rdann(recordname, annotator, sampfrom=FirstLstannot, sampto=LastLstannot)  ## annotations between the first and last annotated sample.
+    # Re-anchor the annotation samples to the start of the extracted segment.
+    annotation2 = wfdb.Annotation(recordname=recordname, extension='niek',
+                                  sample=(annotation.sample - FirstLstannot),
+                                  symbol=annotation.symbol, aux_note=annotation.aux_note)
+
+    Vctrecord = np.transpose(record.p_signals)
+    VctAnnotationHot = np.zeros((6, len(Vctrecord[1])), dtype=int)
+    VctAnnotationHot[5] = 1  ## class 5 = background, the inverse of the other classes
+    # print("ecg, 2 lead of shape", Vctrecord.shape)
+    # print("VctAnnotationHot of shape", VctAnnotationHot.shape)
+    # print('plotting extracted signal with annotation')
+    # wfdb.plotrec(record, annotation=annotation2, title='Record 100 from MIT-BIH Arrhythmia Database', timeunits='seconds')
+
+    VctAnnotations = list(zip(annotation2.sample, annotation2.symbol))  ## (sample, symbol) pairs, e.g. (N), (t)
+    # print(VctAnnotations)
+    for i in range(len(VctAnnotations)):
+        # print(VctAnnotations[i])  # print to display the annotations of an ECG
+        try:
+            if VctAnnotations[i][1] == "p":
+                if VctAnnotations[i - 1][1] == "(":
+                    pstart = VctAnnotations[i - 1][0]
+                if VctAnnotations[i + 1][1] == ")":
+                    pend = VctAnnotations[i + 1][0]
+                if VctAnnotations[i + 3][1] == "N":
+                    rpos = VctAnnotations[i + 3][0]
+                    if VctAnnotations[i + 2][1] == "(":
+                        qpos = VctAnnotations[i + 2][0]
+                    if VctAnnotations[i + 4][1] == ")":
+                        spos = VctAnnotations[i + 4][0]
+                        for ii in range(0, 8):  ## search for t (sometimes the "(" before the t is missing)
+                            if VctAnnotations[i + ii][1] == "t":
+                                tpos = VctAnnotations[i + ii][0]
+                                if VctAnnotations[i + ii + 1][1] == ")":
+                                    tendpos = VctAnnotations[i + ii + 1][0]
+                                    # print(pstart, qpos, rpos, spos, tendpos)
+                                    VctAnnotationHot[0][pstart:pend] = 1     # P wave
+                                    VctAnnotationHot[1][pend:qpos] = 1       # PQ segment ("nothing" between P and Q); previously left unannotated, but categorical labels cannot deal with gaps
+                                    VctAnnotationHot[2][qpos:rpos] = 1       # QR
+                                    VctAnnotationHot[3][rpos:spos] = 1       # RS
+                                    VctAnnotationHot[4][spos:tendpos] = 1    # ST (from end of S to end of T)
+                                    VctAnnotationHot[5][pstart:tendpos] = 0  # unset background inside the beat; everything outside stays 1
+        except IndexError:
+            pass
+
+    Vctrecord = np.transpose(Vctrecord)  # transpose to (timesteps, features)
+    VctAnnotationHot = np.transpose(VctAnnotationHot)
+    os.chdir(cwd)
+    return Vctrecord, VctAnnotationHot
+
+
+def splitseq(x, n, o):
+    ## Split a sequence into chunks of n timesteps with an overlap of o on both sides;
+    ## should be optimized so that remove_seq_gaps is no longer needed.
+    upper = math.ceil(x.shape[0] / n) * n
+    print("splitting on", n, "with overlap of", o, "total datapoints:", x.shape[0], "; upper:", upper)
+    for i in range(0, upper, n):
+        if i == 0:
+            padded = np.zeros((o + n + o, x.shape[1]))  ## pad with zeros at the start of the sequence
+            padded[o:, :x.shape[1]] = x[i:i + n + o, :]
+            xpart = padded
+        else:
+            xpart = x[i - o:i + n + o, :]
+        if xpart.shape[0] < o + n + o:  ## last chunk is short: pad with zeros at the end of the sequence
+            padded = np.zeros((o + n + o, xpart.shape[1]))
+            padded[:xpart.shape[0], :xpart.shape[1]] = xpart
+            xpart = padded
+
+        xpart = np.expand_dims(xpart, 0)  ## add a leading dimension, giving shape (samples, timesteps, features)
+        try:
+            xx = np.vstack((xx, xpart))
+        except UnboundLocalError:  ## on init
+            xx = xpart
+    print("output: ", xx.shape)
+    return xx
+
+
+def remove_seq_gaps(x, y):
+    ## Remove stretches that are not annotated <- not ideal, but the quickest fix for now.
+    window = 150
+    c = 0
+    include = []
+    print("filtering.")
+    print("before shape x,y", x.shape, y.shape)
+    for i in range(y.shape[0]):
+        c = c + 1
+        if c < window:
+            include.append(i)
+        if sum(y[i, 0:5]) > 0:  ## any non-background label resets the counter
+            c = 0
+    x, y = x[include, :], y[include, :]
+    print(" after shape x,y", x.shape, y.shape)
+    return x, y
+
+
+def normalizesignal(x):
+    x = st.zscore(x, ddof=0)
+    return x
+
+
+def normalizesignal_array(x):
+    for i in range(x.shape[0]):
+        x[i] = st.zscore(x[i], axis=0, ddof=0)
+    return x
+
+
+def plotecg(x, y, begin, end):
+    ## Helper to plot an ECG: lead 0 plus all label channels on top, lead 1 at the bottom.
+    plt.figure(1, figsize=(11.69, 8.27))
+    plt.subplot(211)
+    plt.plot(x[begin:end, 0])
+    for ch in range(y.shape[1]):
+        plt.plot(y[begin:end, ch])
+
+    plt.subplot(212)
+    plt.plot(x[begin:end, 1])
+    plt.show()
+
+
+def plotecg_validation(x, y_true, y_pred, begin, end):
+    ## Helper to plot a validation ECG: predicted labels on top (over lead 0),
+    ## true labels at the bottom (over lead 1).
+    plt.figure(1, figsize=(11.69, 8.27))
+    plt.subplot(211)
+    plt.plot(x[begin:end, 0])
+    for ch in range(y_pred.shape[1]):
+        plt.plot(y_pred[begin:end, ch])
+
+    plt.subplot(212)
+    plt.plot(x[begin:end, 1])
+    for ch in range(y_true.shape[1]):
+        plt.plot(y_true[begin:end, ch])
+
+
+def LoaddDatFiles(datfiles):
+    ## Load all .dat files that have a matching .q1c annotation file;
+    ## records listed in the module-level `exclude` set are skipped.
+    for datfile in datfiles:
+        print(datfile)
+        if basename(datfile).split(".", 1)[0] in exclude:
+            continue
+        qf = os.path.splitext(datfile)[0] + '.q1c'
+        if os.path.isfile(qf):
+            # print("yes", qf, datfile)
+            x, y = get_ecg_data(datfile)
+            x, y = remove_seq_gaps(x, y)
+            x, y = splitseq(x, 1000, 150), splitseq(y, 1000, 150)  ## create equally sized arrays of n timesteps with overlap o
+            x = normalizesignal_array(x)
+            ## TODO: add noise, shuffle leads, etc.?
+            try:  ## concatenate onto the running arrays
+                xx = np.vstack((xx, x))
+                yy = np.vstack((yy, y))
+            except NameError:  ## xx/yy do not exist yet (first record)
+                xx = x
+                yy = y
+    return xx, yy
+
+
+def unison_shuffled_copies(a, b):
+    assert len(a) == len(b)
+    p = np.random.permutation(len(a))
+    return a[p], b[p]
+
+
+def get_session(gpu_fraction=0.8):
+    ## Allocate only a fraction of the GPU memory.
+    num_threads = os.environ.get('OMP_NUM_THREADS')
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
+    if num_threads:
+        return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))
+    else:
+        return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
+
+
+def getmodel():
+    ## Build the bidirectional LSTM; relies on the module-level seqlength, features and dimout.
+    model = Sequential()
+    model.add(Dense(32, kernel_regularizer=regularizers.l2(l=0.01), input_shape=(seqlength, features)))
+    model.add(Bidirectional(LSTM(32, return_sequences=True)))  ### bidirectional ---><---
+    model.add(Dropout(0.2))
+    model.add(BatchNormalization())
+    model.add(Dense(64, activation='relu', kernel_regularizer=regularizers.l2(l=0.01)))
+    model.add(Dropout(0.2))
+    model.add(BatchNormalization())
+    model.add(Dense(dimout, activation='softmax'))
+    adam = optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
+    model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])
+    model.summary()
+    return model
+
+
+##################################################################
+##################################################################
+qtdbpath = sys.argv[1]  ## first argument = path to the QT Database (qtdb) from PhysioNet
+perct = 0.81  # fraction of records used for training
+percv = 0.19  # fraction of records used for validation
+
+exclude = set()
+exclude.update(["sel35", "sel36", "sel37", "sel50", "sel102", "sel104", "sel221", "sel232", "sel310"])  # records without an annotated P wave
+##################################################################
+# datfile=qtdbpath+"sel49.dat"  ## single ECG to test whether loading works.
+# x,y=get_ecg_data(datfile)
+# print(x.shape,y.shape)
+# # for i in range(y.shape[0]):  # Invert QT-label to actually represent QT. Does give overlapping labels.
+# #     y[i][0] = 1 - y[i][0]
+# plotecg(x,y,0,y.shape[0])  ## plot all
+# x,y=remove_seq_gaps(x,y)  ## remove 'annotation gaps'
+# plotecg(x,y,0,y.shape[0])  ## plot all
+# x,y=splitseq(x,750,150),splitseq(y,750,150)  ## create equally sized arrays of n timesteps with overlap o
+# exit()
+##################################################################
+
+# load data
+datfiles = glob.glob(qtdbpath + "*.dat")
+xxt, yyt = LoaddDatFiles(datfiles[:round(len(datfiles) * perct)])  # training data
+xxt, yyt = unison_shuffled_copies(xxt, yyt)  ### shuffle training samples
+xxv, yyv = LoaddDatFiles(datfiles[-round(len(datfiles) * percv):])  ## validation data
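+
+# Optional sanity check (a minimal sketch, assuming the labels built above are one-hot
+# per timestep except at the zero-padded chunk edges created by splitseq, so the
+# reported fraction is expected to be slightly below 1.0):
+for name, arr in (("train", yyt), ("validation", yyv)):
+    onehot_frac = (arr.sum(axis=2) == 1).mean()
+    print("fraction of {} timesteps with exactly one label: {:.3f}".format(name, onehot_frac))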
+seqlength = xxt.shape[1]
+features = xxt.shape[2]
+dimout = yyt.shape[2]
+print("xxv/validation shape: {}, Seqlength: {}, Features: {}".format(xxv.shape[0], seqlength, features))
+# # plot the validation ECGs
+# with PdfPages('ecgs_xxv.pdf') as pdf:
+#     for i in range(xxv.shape[0]):
+#         print(i)
+#         plotecg(xxv[i,:,:], yyv[i,:,:], 0, yyv.shape[1])
+#         pdf.savefig()
+#         plt.close()
+
+# set up the Keras/TensorFlow session and build the LSTM model
+KTF.set_session(get_session())
+with tf.device('/cpu:0'):  # change to '/gpu:0' to train on the GPU
+    if not os.path.isfile('model.h5'):
+        model = getmodel()  # build the model
+        model.fit(xxt, yyt, batch_size=32, epochs=100, verbose=1)  # train the model
+        model.save('model.h5')
+
+    model = load_model('model.h5')
+    score, acc = model.evaluate(xxv, yyv, batch_size=4, verbose=1)
+    print('Test score: {} , Test accuracy: {}'.format(score, acc))
+
+    # predict
+    yy_predicted = model.predict(xxv)
+
+    # turn the per-timestep probabilities into one-hot predictions (argmax per timestep)
+    for i in range(yyv.shape[0]):
+        b = np.zeros_like(yy_predicted[i, :, :])
+        b[np.arange(len(yy_predicted[i, :, :])), yy_predicted[i, :, :].argmax(1)] = 1
+        yy_predicted[i, :, :] = b
+
+    # plot: top = predicted, bottom = true
+    with PdfPages('ecg.pdf') as pdf:
+        for i in range(xxv.shape[0]):
+            print(i)
+            plotecg_validation(xxv[i, :, :], yyv[i, :, :], yy_predicted[i, :, :], 0, yy_predicted.shape[1])
+            pdf.savefig()
+            plt.close()
+
+    # plotecg(xxv[1,:,:], yyv[1,:,:], 0, yyv.shape[1])  ## plot the first validation sequence
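+
+# Usage sketch (assumed invocation, inferred from how qtdbpath is used above): the single
+# command-line argument is the directory containing the PhysioNet QT Database .dat/.q1c
+# files, with a trailing slash because the files are globbed as qtdbpath + "*.dat", e.g.:
+#   python ecg_annotation.py /path/to/qtdb/
+
+# Optional extra: per-class timestep accuracy on the validation set, a minimal numpy
+# sketch that assumes only the yyv labels and the one-hot yy_predicted computed above.
+true_idx = yyv.argmax(axis=2).ravel()    # flatten to (samples * timesteps,)
+pred_idx = yy_predicted.argmax(axis=2).ravel()
+for cls in range(dimout):
+    mask = true_idx == cls
+    if mask.any():
+        print("class {}: {} timesteps, accuracy {:.3f}".format(cls, mask.sum(), (pred_idx[mask] == cls).mean()))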