|
a |
|
b/pre_proc.py |
|
|
1 |
import wfdb |
|
|
2 |
import matplotlib.pyplot as plt |
|
|
3 |
import numpy as np |
|
|
4 |
from hrv.filters import quotient, moving_median |
|
|
5 |
from scipy import interpolate |
|
|
6 |
from tqdm import tqdm |
|
|
7 |
import pickle |
|
|
8 |
import os |
|
|
9 |
# Sampling frequency of the ECG recordings in samples per second; used below
# to convert between sample indices and seconds.
FS = 100.0
|
|
10 |
|
|
|
11 |
# From https://github.com/rhenanbartels/hrv/blob/develop/hrv/classical.py |
|
|
12 |
def create_time_info(rri):
    """Return the time axis (seconds) of an RR-interval series.

    The intervals, given in milliseconds, are accumulated into elapsed time,
    converted to seconds and shifted so the first sample sits at zero.
    """
    elapsed_seconds = np.cumsum(rri) / 1000.0
    origin = elapsed_seconds[0]
    return elapsed_seconds - origin
|
|
15 |
|
|
|
16 |
def create_interp_time(rri, fs):
    """Build an evenly spaced time grid covering an RR-interval series.

    The grid starts at zero, advances in steps of ``1 / fs`` seconds and stops
    before the time of the last interval as given by ``create_time_info``.
    """
    last_beat_time = create_time_info(rri)[-1]
    step = 1 / float(fs)
    return np.arange(0, last_beat_time, step)
|
|
19 |
|
|
|
20 |
def interp_cubic_spline(rri, fs):
    """Resample an RR-interval series onto a uniform grid with a spline.

    Returns a ``(time, values)`` pair: the uniform time grid sampled at ``fs``
    Hz and the interpolating spline (``s=0``, i.e. no smoothing) evaluated on
    that grid.
    """
    beat_times = create_time_info(rri)
    uniform_times = create_interp_time(rri, fs)
    spline = interpolate.splrep(beat_times, rri, s=0)
    resampled = interpolate.splev(uniform_times, spline, der=0)
    return uniform_times, resampled
|
|
26 |
|
|
|
27 |
def interp_cubic_spline_qrs(qrs_index, qrs_amp, fs):
    """Resample per-beat QRS amplitudes onto a uniform grid with a spline.

    ``qrs_index`` holds beat positions as sample indices at the module-level
    rate ``FS``; ``fs`` is the rate of the output grid. Time is measured from
    the first beat. Returns a ``(time, values)`` pair.
    """
    beat_seconds = qrs_index / float(FS)
    relative_times = beat_seconds - beat_seconds[0]
    uniform_times = np.arange(0, relative_times[-1], 1 / float(fs))
    spline = interpolate.splrep(relative_times, qrs_amp, s=0)
    resampled = interpolate.splev(uniform_times, spline, der=0)
    return uniform_times, resampled
|
|
34 |
|
|
|
35 |
# Directory holding the record files read by this script.
data_path = './data/'

# Training records: a02-a19, b02-b04 and c02-c09.
train_data_name = (
    ['a%02d' % number for number in range(2, 20)]
    + ['b%02d' % number for number in range(2, 5)]
    + ['c%02d' % number for number in range(2, 10)]
)
# One record of each series is held out for validation and one for testing.
val_data_name = ['a01', 'b01', 'c01']
test_data_name = ['a20', 'b05', 'c10']

# Per-record metadata, 35 entries each.
# NOTE(review): 35 = 20 + 5 + 10, which matches one entry per record in the
# order a01..a20, b01..b05, c01..c10 -- confirm against the dataset's
# metadata before relying on that ordering.
age = [
    51, 38, 54, 52, 58, 63, 44, 51, 52, 58,
    58, 52, 51, 51, 60, 44, 40, 52, 55, 58,
    44, 53, 53, 42, 52,
    31, 37, 39, 41, 28, 28, 30, 42, 37, 27,
]
# Binary sex code; presumably 1 = male and 0 = female -- TODO confirm.
sex = [
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    0, 1, 1, 1, 1,
    1, 1, 1, 0, 0, 0, 0, 1, 1, 1,
]
|
|
60 |
|
|
|
61 |
|
|
|
62 |
def get_qrs_amp(ecg, qrs, fs=None):
    """Return the peak ECG value found around each QRS annotation.

    For every sample index in ``qrs`` the maximum of ``ecg`` is taken over a
    window of 250 ms on either side of the annotation.

    Args:
        ecg: Indexable signal supporting slicing (e.g. a NumPy array). When it
            has more than one dimension the maximum is taken over the whole
            slice, as ``np.max`` does by default.
        qrs: Iterable of integer sample indices of the QRS annotations.
        fs: Sampling frequency of ``ecg`` in Hz. Defaults to the module-level
            ``FS`` when omitted.

    Returns:
        A list with one amplitude per entry of ``qrs`` (empty for empty input).

    Raises:
        ValueError: If a window contains no samples, e.g. an annotation lying
            beyond the end of ``ecg``.
    """
    if fs is None:
        fs = FS
    half_window = int(fs * 0.250)

    amplitudes = []
    for peak in qrs:
        # Clamp at zero: a negative start would be interpreted as an offset
        # from the END of the signal and yield an empty or unrelated slice.
        start = max(peak - half_window, 0)
        amplitudes.append(np.max(ecg[start:peak + half_window]))
    return amplitudes
|
|
71 |
|
|
|
72 |
MARGIN = 10  # seconds of extra QRS context requested on each side of a window
FS_INTP = 4  # rate (Hz) at which the per-beat series are resampled
MAX_HR = 300.0  # fastest heart rate accepted, beats per minute
MIN_HR = 20.0  # slowest heart rate accepted, beats per minute
# RR-interval bounds in milliseconds implied by the heart-rate limits above.
MIN_RRI = 1.0 / (MAX_HR / 60.0) * 1000
MAX_RRI = 1.0 / (MIN_HR / 60.0) * 1000

_WINDOW_SEC = 60  # length of one annotated window in seconds

# `age` and `sex` carry 35 values, one per record. They are assumed to be in
# record order a01..a20, b01..b05, c01..c10, so they must be looked up by
# record NAME. Indexing them with the position of a record inside a split (as
# this script used to do) pairs e.g. training record 'a02' with the first
# entry and gives every validation/test record one of the first three entries.
# NOTE(review): confirm the ordering against the dataset's metadata file.
_RECORD_ORDER = (
    ['a%02d' % number for number in range(1, 21)]
    + ['b%02d' % number for number in range(1, 6)]
    + ['c%02d' % number for number in range(1, 11)]
)
if len(age) != len(_RECORD_ORDER) or len(sex) != len(_RECORD_ORDER):
    raise ValueError(
        'age/sex must hold exactly one entry per record (%d expected, '
        'got %d and %d)' % (len(_RECORD_ORDER), len(age), len(sex)))
_META_INDEX = {name: position for position, name in enumerate(_RECORD_ORDER)}


def _extract_window(qrs_ann, qrs_amp, rri_ms, apn_symbols):
    """Turn the beats of one window into zero-mean series plus a label.

    Returns ``(rri_intp, qrs_intp, label)``, or ``None`` when the window is
    rejected: too few beats, an RR interval outside [MIN_RRI, MAX_RRI], or an
    interpolated RRI series that does not span the full window.
    """
    rri_filt = moving_median(rri_ms)
    if not (len(rri_filt) > 5
            and np.min(rri_filt) >= MIN_RRI
            and np.max(rri_filt) <= MAX_RRI):
        return None

    time_intp, rri_intp = interp_cubic_spline(rri_filt, FS_INTP)
    qrs_time_intp, qrs_intp = interp_cubic_spline_qrs(qrs_ann, qrs_amp, FS_INTP)

    # Keep the central window only; the MARGIN seconds on each side were read
    # purely to give the spline context at the window edges.
    window_end = _WINDOW_SEC + MARGIN
    rri_intp = rri_intp[(time_intp >= MARGIN) & (time_intp < window_end)]
    qrs_intp = qrs_intp[(qrs_time_intp >= MARGIN) & (qrs_time_intp < window_end)]

    # NOTE(review): only the RRI series length is validated, as before; the
    # amplitude series is not checked separately.
    if len(rri_intp) != FS_INTP * _WINDOW_SEC:
        return None

    rri_intp = rri_intp - np.mean(rri_intp)
    qrs_intp = qrs_intp - np.mean(qrs_intp)

    symbol = apn_symbols[0]
    if symbol == 'N':  # Normal
        label = 0.0
    elif symbol == 'A':  # Apnea
        label = 1.0
    else:
        label = 2.0
    return rri_intp, qrs_intp, label


def _build_split(record_names):
    """Extract per-window features and labels for the given records.

    Returns ``(inputs, labels)``. Each input is
    ``[rri_intp, qrs_intp, age, sex]`` and each label is 0.0, 1.0 or 2.0.
    Windows whose processing raises are dropped (best effort, as before), but
    they are now counted and reported instead of being silently ignored.
    """
    inputs = []
    labels = []
    # Integer sample counts: annotation positions are sample indices.
    window_samples = int(_WINDOW_SEC * FS)
    margin_samples = int(MARGIN * FS)

    for record_name in record_names:
        print(record_name)
        record_path = os.path.join(data_path, record_name)
        meta_index = _META_INDEX[record_name]
        win_num = len(wfdb.rdann(record_path, 'apn').symbol)
        signals, _ = wfdb.rdsamp(record_path)

        failed = 0
        # Window 0 is not processed: the leading margin would start at a
        # negative sample index.
        for index in tqdm(range(1, win_num)):
            samp_from = index * window_samples
            samp_to = samp_from + window_samples

            qrs_ann = wfdb.rdann(record_path, 'qrs',
                                 sampfrom=samp_from - margin_samples,
                                 sampto=samp_to + margin_samples).sample
            apn_ann = wfdb.rdann(record_path, 'apn',
                                 sampfrom=samp_from,
                                 sampto=samp_to - 1).symbol

            qrs_amp = get_qrs_amp(signals, qrs_ann)
            rri_ms = np.diff(qrs_ann).astype('float') / FS * 1000.0

            try:
                extracted = _extract_window(qrs_ann, qrs_amp, rri_ms, apn_ann)
            except Exception:
                # Best effort: a window that cannot be filtered/interpolated is
                # skipped. KeyboardInterrupt/SystemExit are no longer caught.
                failed += 1
                continue
            if extracted is None:
                continue

            rri_intp, qrs_intp, label = extracted
            inputs.append([rri_intp, qrs_intp, age[meta_index], sex[meta_index]])
            labels.append(label)

        if failed:
            print('%s: %d window(s) dropped after a processing error'
                  % (record_name, failed))
    return inputs, labels


def _dump_pickle(obj, filename):
    """Serialize ``obj`` to ``filename`` with pickle."""
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)


train_input_array, train_label_array = _build_split(train_data_name)
_dump_pickle(train_input_array, 'train_input.pickle')
_dump_pickle(train_label_array, 'train_label.pickle')

val_input_array, val_label_array = _build_split(val_data_name)
_dump_pickle(val_input_array, 'val_input.pickle')
_dump_pickle(val_label_array, 'val_label.pickle')

test_input_array, test_label_array = _build_split(test_data_name)
_dump_pickle(test_input_array, 'test_input.pickle')
_dump_pickle(test_label_array, 'test_label.pickle')