"""Pre-process ECG records into per-minute samples for apnea classification.

For every record, each one-minute window (except the first) is converted into
two evenly resampled series -- the R-R intervals and the QRS amplitudes -- and
stored together with the subject's age and sex.  The resulting train,
validation and test sets are written to pickle files in the working directory.
"""
import os
import pickle

import wfdb
import matplotlib.pyplot as plt
import numpy as np
from hrv.filters import quotient, moving_median
from scipy import interpolate
from tqdm import tqdm

FS = 100.0       # sampling rate of the raw ECG, in Hz
WIN_SEC = 60     # length of one labelled window, in seconds
MARGIN = 10      # extra seconds read on each side of a window
FS_INTP = 4      # sampling rate of the interpolated series, in Hz
MAX_HR = 300.0   # heart-rate bounds (beats per minute) used to reject windows
MIN_HR = 20.0
MIN_RRI = 1.0 / (MAX_HR / 60.0) * 1000  # shortest accepted R-R interval, ms
MAX_RRI = 1.0 / (MIN_HR / 60.0) * 1000  # longest accepted R-R interval, ms


# From https://github.com/rhenanbartels/hrv/blob/develop/hrv/classical.py
def create_time_info(rri):
    """Return the time axis (seconds, starting at zero) of an R-R series.

    ``rri`` holds R-R intervals in milliseconds.
    """
    rri_time = np.cumsum(rri) / 1000.0  # make it seconds
    return rri_time - rri_time[0]  # force it to start at zero


def create_interp_time(rri, fs):
    """Return an evenly spaced time axis at ``fs`` Hz spanning the series."""
    time_rri = create_time_info(rri)
    return np.arange(0, time_rri[-1], 1 / float(fs))


def interp_cubic_spline(rri, fs):
    """Resample an R-R series at ``fs`` Hz with an interpolating cubic spline.

    Returns ``(time_axis, interpolated_rri)``.
    """
    time_rri = create_time_info(rri)
    time_rri_interp = create_interp_time(rri, fs)
    tck = interpolate.splrep(time_rri, rri, s=0)
    rri_interp = interpolate.splev(time_rri_interp, tck, der=0)
    return time_rri_interp, rri_interp


def interp_cubic_spline_qrs(qrs_index, qrs_amp, fs):
    """Resample QRS amplitudes at ``fs`` Hz with an interpolating cubic spline.

    ``qrs_index`` holds sample positions at the raw rate ``FS``; the time axis
    is shifted so that the first beat sits at zero.  Returns
    ``(time_axis, interpolated_amplitudes)``.
    """
    time_qrs = np.asarray(qrs_index) / float(FS)
    time_qrs = time_qrs - time_qrs[0]
    time_qrs_interp = np.arange(0, time_qrs[-1], 1 / float(fs))
    tck = interpolate.splrep(time_qrs, qrs_amp, s=0)
    qrs_interp = interpolate.splev(time_qrs_interp, tck, der=0)
    return time_qrs_interp, qrs_interp


data_path = './data/'
train_data_name = ['a02', 'a03', 'a04', 'a05',
                   'a06', 'a07', 'a08', 'a09', 'a10',
                   'a11', 'a12', 'a13', 'a14', 'a15',
                   'a16', 'a17', 'a18', 'a19',
                   'b02', 'b03', 'b04',
                   'c02', 'c03', 'c04', 'c05',
                   'c06', 'c07', 'c08', 'c09',
                   ]
val_data_name = ['a01', 'b01', 'c01']
test_data_name = ['a20', 'b05', 'c10']

# Per-subject metadata, one entry per record.
# NOTE(review): the 35 entries match the 35 records a01-a20, b01-b05, c01-c10
# and are assumed to be listed in exactly that order -- TODO confirm against
# the data set's additional-information file.
age = [51, 38, 54, 52, 58,
       63, 44, 51, 52, 58,
       58, 52, 51, 51, 60,
       44, 40, 52, 55, 58,
       44, 53, 53, 42, 52,
       31, 37, 39, 41, 28,
       28, 30, 42, 37, 27]
# NOTE(review): presumably 1 = male, 0 = female -- confirm.
sex = [1, 1, 1, 1, 1,
       1, 1, 1, 1, 1,
       1, 1, 1, 1, 1,
       1, 1, 1, 1, 1,
       0, 1, 1, 1, 1,
       1, 1, 1, 0, 0,
       0, 0, 1, 1, 1]

_RECORD_ORDER = (['a%02d' % i for i in range(1, 21)]
                 + ['b%02d' % i for i in range(1, 6)]
                 + ['c%02d' % i for i in range(1, 11)])

# Metadata is looked up by record name.  Indexing ``age``/``sex`` with the
# position inside a split (as was done before) hands validation and test
# records the entries of unrelated subjects.
RECORD_INFO = dict(zip(_RECORD_ORDER, zip(age, sex)))


def get_qrs_amp(ecg, qrs):
    """Return the peak ECG value within +/-250 ms of every QRS position.

    ``ecg`` is indexed by sample number and ``qrs`` holds sample positions.
    The search window is clamped at the start of the signal so that an early
    beat cannot produce a negative (wrapping) slice start.
    """
    interval = int(FS * 0.250)
    qrs_amp = []
    for curr_qrs in qrs:
        start = max(curr_qrs - interval, 0)
        qrs_amp.append(np.max(ecg[start:curr_qrs + interval]))
    return qrs_amp


def _label_from_symbol(symbol):
    """Map an apnea annotation symbol to its numeric class label."""
    if symbol == 'N':  # Normal
        return 0.0
    if symbol == 'A':  # Apnea
        return 1.0
    return 2.0


def _window_features(qrs_ann, qrs_amp):
    """Return ``(rri_intp, qrs_intp)`` for one window, or None to skip it.

    Both series are resampled at ``FS_INTP`` Hz, cut to the ``WIN_SEC`` seconds
    following the leading margin and mean-centred.  A window is skipped when
    it has too few beats, an implausible R-R interval, or does not yield
    exactly ``FS_INTP * WIN_SEC`` interpolated R-R samples.

    NOTE(review): each time axis starts at the first beat used for that series,
    not at the start of the margin-extended window -- confirm this is intended.
    """
    rri = np.diff(qrs_ann)
    rri_ms = rri.astype('float') / FS * 1000.0
    rri_filt = moving_median(rri_ms)

    plausible = (len(rri_filt) > 5
                 and np.min(rri_filt) >= MIN_RRI
                 and np.max(rri_filt) <= MAX_RRI)
    if not plausible:
        return None

    time_intp, rri_intp = interp_cubic_spline(rri_filt, FS_INTP)
    qrs_time_intp, qrs_intp = interp_cubic_spline_qrs(qrs_ann, qrs_amp, FS_INTP)
    rri_intp = rri_intp[(time_intp >= MARGIN) & (time_intp < (WIN_SEC + MARGIN))]
    qrs_intp = qrs_intp[(qrs_time_intp >= MARGIN)
                        & (qrs_time_intp < (WIN_SEC + MARGIN))]

    if len(rri_intp) != (FS_INTP * WIN_SEC):
        return None

    return rri_intp - np.mean(rri_intp), qrs_intp - np.mean(qrs_intp)


def build_dataset(record_names):
    """Build the inputs and labels for the given records.

    Returns ``(inputs, labels)``.  Every input is
    ``[rri_intp, qrs_intp, age, sex]`` and every label is 0.0 (normal),
    1.0 (apnea) or 2.0 (any other annotation symbol).
    """
    inputs = []
    labels = []
    margin_samples = int(MARGIN * FS)
    win_samples = int(WIN_SEC * FS)

    for name in record_names:
        print(name)
        record_path = os.path.join(data_path, name)
        subject_age, subject_sex = RECORD_INFO[name]
        win_num = len(wfdb.rdann(record_path, 'apn').symbol)
        signals, _ = wfdb.rdsamp(record_path)
        skipped = 0

        # The first window is left out because it has no leading margin.
        for index in tqdm(range(1, win_num)):
            samp_from = index * win_samples
            samp_to = samp_from + win_samples

            qrs_ann = wfdb.rdann(record_path, 'qrs',
                                 sampfrom=samp_from - margin_samples,
                                 sampto=samp_to + margin_samples).sample
            apn_ann = wfdb.rdann(record_path, 'apn',
                                 sampfrom=samp_from,
                                 sampto=samp_to - 1).symbol

            qrs_amp = get_qrs_amp(signals, qrs_ann)

            try:
                features = _window_features(qrs_ann, qrs_amp)
                if features is None:
                    continue
                label = _label_from_symbol(apn_ann[0])
            except Exception:
                # Best effort: a window that cannot be filtered, interpolated
                # or labelled is dropped, but it is counted instead of being
                # silently ignored.
                skipped += 1
                continue

            rri_intp, qrs_intp = features
            inputs.append([rri_intp, qrs_intp, subject_age, subject_sex])
            labels.append(label)

        if skipped:
            print('%s: skipped %d window(s) after processing errors'
                  % (name, skipped))

    return inputs, labels


def _dump(obj, filename):
    """Pickle ``obj`` to ``filename``."""
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)


train_input_array, train_label_array = build_dataset(train_data_name)
_dump(train_input_array, 'train_input.pickle')
_dump(train_label_array, 'train_label.pickle')

val_input_array, val_label_array = build_dataset(val_data_name)
_dump(val_input_array, 'val_input.pickle')
_dump(val_label_array, 'val_label.pickle')

test_input_array, test_label_array = build_dataset(test_data_name)
_dump(test_input_array, 'test_input.pickle')
_dump(test_label_array, 'test_label.pickle')