Diff of /pre_proc.py [000000] .. [56ff0f]

Switch to unified view

a b/pre_proc.py
1
import wfdb
2
import matplotlib.pyplot as plt
3
import numpy as np
4
from hrv.filters import quotient, moving_median
5
from scipy import interpolate
6
from tqdm import tqdm
7
import pickle
8
import os
9
# Sampling rate of the recordings in samples per second; used below to turn
# sample indices into seconds and window lengths into sample counts.
FS = 100.0
10
11
# From https://github.com/rhenanbartels/hrv/blob/develop/hrv/classical.py
12
def create_time_info(rri):
    """Build a zero-based time axis, in seconds, for a series of RR intervals.

    The intervals are accumulated and divided by 1000 (i.e. ``rri`` is taken
    to be in milliseconds), then shifted so the first time stamp is 0.

    Args:
        rri: Sequence or array of RR intervals.

    Returns:
        numpy.ndarray: One time stamp per interval, starting at 0.0.
    """
    elapsed_sec = np.cumsum(rri) / 1000.0
    origin = elapsed_sec[0]
    return elapsed_sec - origin
15
16
def create_interp_time(rri, fs):
    """Return an evenly spaced time grid covering the RR-interval series.

    Args:
        rri: Sequence or array of RR intervals (passed to
            ``create_time_info`` to obtain the time axis).
        fs: Sampling rate of the grid; consecutive points are ``1 / fs`` apart.

    Returns:
        numpy.ndarray: Time stamps from 0 up to, but not including, the last
        time stamp of the RR-interval series.
    """
    last_beat_time = create_time_info(rri)[-1]
    step = 1 / float(fs)
    return np.arange(0, last_beat_time, step)
19
20
def interp_cubic_spline(rri, fs):
    """Resample an RR-interval series onto a uniform grid with a cubic spline.

    Args:
        rri: Sequence or array of RR intervals.
        fs: Sampling rate of the uniform output grid.

    Returns:
        tuple: ``(grid, values)`` where ``grid`` is the uniform time axis and
        ``values`` is the spline evaluated on it.
    """
    beat_times = create_time_info(rri)
    grid = create_interp_time(rri, fs)
    # s=0 makes the spline pass through every original point (no smoothing).
    spline = interpolate.splrep(beat_times, rri, s=0)
    return grid, interpolate.splev(grid, spline, der=0)
26
27
def interp_cubic_spline_qrs(qrs_index, qrs_amp, fs):
    """Resample per-beat QRS amplitudes onto a uniform grid with a cubic spline.

    Beat positions are converted to seconds with the module-level ``FS`` and
    shifted so the first beat is at time 0.

    Args:
        qrs_index: Array of beat positions in samples (must support array
            division, e.g. a numpy array).
        qrs_amp: Amplitude associated with each beat.
        fs: Sampling rate of the uniform output grid.

    Returns:
        tuple: ``(grid, values)`` where ``grid`` is the uniform time axis and
        ``values`` is the spline evaluated on it.
    """
    elapsed_sec = qrs_index / float(FS)
    beat_times = elapsed_sec - elapsed_sec[0]
    step = 1 / float(fs)
    grid = np.arange(0, beat_times[-1], step)
    # s=0 makes the spline pass through every original point (no smoothing).
    spline = interpolate.splrep(beat_times, qrs_amp, s=0)
    return grid, interpolate.splev(grid, spline, der=0)
34
35
# Directory that holds the WFDB record files.
data_path = './data/'

# Records used for training.
train_data_name = (
    ['a%02d' % n for n in range(2, 20)]      # a02 .. a19
    + ['b%02d' % n for n in range(2, 5)]     # b02 .. b04
    + ['c%02d' % n for n in range(2, 10)]    # c02 .. c09
)
# Records held out for validation and for testing.
val_data_name = ['a01', 'b01', 'c01']
test_data_name = ['a20', 'b05', 'c10']

# Per-subject metadata, one entry per record (35 in total).
# NOTE(review): presumably ordered a01..a20, b01..b05, c01..c10 -- confirm
# against the dataset's subject information before relying on the grouping.
age = (
    [51, 38, 54, 52, 58, 63, 44, 51, 52, 58,
     58, 52, 51, 51, 60, 44, 40, 52, 55, 58]     # presumably a01 .. a20
    + [44, 53, 53, 42, 52]                       # presumably b01 .. b05
    + [31, 37, 39, 41, 28, 28, 30, 42, 37, 27]   # presumably c01 .. c10
)
# Binary sex code; the meaning of 1 vs 0 is not stated in this file
# (presumably 1 = male, 0 = female -- TODO confirm).
sex = (
    [1] * 20                                     # presumably a01 .. a20
    + [0, 1, 1, 1, 1]                            # presumably b01 .. b05
    + [1, 1, 1, 0, 0, 0, 0, 1, 1, 1]             # presumably c01 .. c10
)
60
61
62
def get_qrs_amp(ecg, qrs, fs=None):
    """Return the peak ECG value in a +/-250 ms window around each QRS position.

    Args:
        ecg: ECG signal indexed by sample along its first axis.
        qrs: Iterable of QRS positions, as sample indices into ``ecg``.
        fs: Sampling rate of ``ecg`` in samples per second. Defaults to the
            module-level ``FS`` when omitted.

    Returns:
        list: One maximum per entry of ``qrs``, in the same order.
    """
    if fs is None:
        fs = FS
    half_window = int(fs * 0.250)

    qrs_amp = []
    for curr_qrs in qrs:
        curr_qrs = int(curr_qrs)
        # Clamp the start at 0: a negative start would be interpreted as an
        # offset from the end of the signal and typically yield an empty
        # slice, making np.max raise for beats near the start of the record.
        start = max(curr_qrs - half_window, 0)
        qrs_amp.append(np.max(ecg[start:curr_qrs + half_window]))

    return qrs_amp
71
72
# --- Windowing / filtering parameters ---------------------------------------
MARGIN = 10                               # seconds of context kept on each side of a window
FS_INTP = 4                               # rate (Hz) of the interpolated output series
MAX_HR = 300.0                            # beats per minute
MIN_HR = 20.0                             # beats per minute
MIN_RRI = 1.0 / (MAX_HR / 60.0) * 1000    # shortest accepted RR interval, in ms
MAX_RRI = 1.0 / (MIN_HR / 60.0) * 1000    # longest accepted RR interval, in ms
WINDOW_SEC = 60                           # length of one labelled window, in seconds

# Row of every record in the module-level ``age`` / ``sex`` tables.
# The tables hold one entry per record, not one per entry of a split list, so
# they must be looked up by record name rather than by the loop position
# inside train/val/test.
# NOTE(review): assumes the tables are ordered a01..a20, b01..b05, c01..c10
# (35 entries) -- confirm against the dataset's subject information.
_RECORD_ORDER = (
    ['a%02d' % n for n in range(1, 21)]
    + ['b%02d' % n for n in range(1, 6)]
    + ['c%02d' % n for n in range(1, 11)]
)
_RECORD_ROW = {name: row for row, name in enumerate(_RECORD_ORDER)}

if len(age) != len(_RECORD_ORDER) or len(sex) != len(_RECORD_ORDER):
    raise ValueError(
        'age/sex tables must have one entry per record (%d expected, got %d/%d)'
        % (len(_RECORD_ORDER), len(age), len(sex)))

# Annotation symbol -> numeric label; any other symbol maps to 2.0.
_LABELS = {'N': 0.0, 'A': 1.0}


def _window_features(qrs_ann, qrs_amp, rri_ms, apn_ann):
    """Turn the annotations of one window into zero-mean input series.

    Args:
        qrs_ann: QRS sample positions covering the window plus its margins.
        qrs_amp: Peak amplitude for each entry of ``qrs_ann``.
        rri_ms: RR intervals derived from ``qrs_ann``, in milliseconds.
        apn_ann: Apnea annotation symbols found inside the window.

    Returns:
        tuple or None: ``(rri_series, qrs_amp_series, label)``, or ``None``
        when the window is rejected (too few or implausible RR intervals,
        incomplete interpolated window, or no apnea annotation).
    """
    rri_filt = moving_median(rri_ms)

    # Kept as ``not (... and ...)`` so NaN values reject the window, and so
    # np.min/np.max are never evaluated on a series that is too short.
    if not (len(rri_filt) > 5
            and np.min(rri_filt) >= MIN_RRI
            and np.max(rri_filt) <= MAX_RRI):
        return None

    time_intp, rri_intp = interp_cubic_spline(rri_filt, FS_INTP)
    qrs_time_intp, qrs_intp = interp_cubic_spline_qrs(qrs_ann, qrs_amp, FS_INTP)

    # Drop the leading/trailing margins, keeping only the central window.
    rri_intp = rri_intp[(time_intp >= MARGIN) & (time_intp < WINDOW_SEC + MARGIN)]
    qrs_intp = qrs_intp[(qrs_time_intp >= MARGIN)
                        & (qrs_time_intp < WINDOW_SEC + MARGIN)]

    # NOTE(review): only the RR series length is validated; the amplitude
    # series is not checked and could in principle have a different length.
    if len(rri_intp) != FS_INTP * WINDOW_SEC:
        return None
    if len(apn_ann) == 0:
        return None

    rri_intp = rri_intp - np.mean(rri_intp)
    qrs_intp = qrs_intp - np.mean(qrs_intp)
    return rri_intp, qrs_intp, _LABELS.get(apn_ann[0], 2.0)


def _build_split(record_names):
    """Extract the per-window inputs and labels for a list of records.

    Args:
        record_names: Record names located under ``data_path``.

    Returns:
        tuple: ``(inputs, labels)``; each input is
        ``[rri_series, qrs_amp_series, age, sex]`` and each label is a float.

    Raises:
        KeyError: If a record name has no row in the age/sex tables.
    """
    inputs, labels = [], []
    margin_samples = int(MARGIN * FS)
    window_samples = int(WINDOW_SEC * FS)

    for name in record_names:
        print(name)
        if name not in _RECORD_ROW:
            raise KeyError('no age/sex entry for record %r' % name)
        row = _RECORD_ROW[name]
        record_path = os.path.join(data_path, name)

        win_num = len(wfdb.rdann(record_path, 'apn').symbol)
        signals, _fields = wfdb.rdsamp(record_path)

        failed = 0
        last_error = None
        # Window 0 is skipped: it has no preceding margin.
        for minute in tqdm(range(1, win_num)):
            samp_from = minute * window_samples
            samp_to = samp_from + window_samples

            qrs_ann = wfdb.rdann(record_path, 'qrs',
                                 sampfrom=samp_from - margin_samples,
                                 sampto=samp_to + margin_samples).sample
            apn_ann = wfdb.rdann(record_path, 'apn',
                                 sampfrom=samp_from, sampto=samp_to - 1).symbol

            # The positions are used directly as indices into the full signal.
            qrs_amp = get_qrs_amp(signals, qrs_ann)
            rri_ms = np.diff(qrs_ann).astype('float') / FS * 1000.0

            try:
                features = _window_features(qrs_ann, qrs_amp, rri_ms, apn_ann)
            except Exception as err:
                # Best effort: a window that cannot be filtered/interpolated
                # is dropped, but counted so the loss is visible.
                failed += 1
                last_error = err
                continue

            if features is None:
                continue
            rri_intp, qrs_intp, label = features
            inputs.append([rri_intp, qrs_intp, age[row], sex[row]])
            labels.append(label)

        if failed:
            print('  skipped %d window(s) after processing errors (last: %r)'
                  % (failed, last_error))

    return inputs, labels


def _dump(obj, filename):
    """Pickle ``obj`` into ``filename``."""
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)


train_input_array, train_label_array = _build_split(train_data_name)
_dump(train_input_array, 'train_input.pickle')
_dump(train_label_array, 'train_label.pickle')

val_input_array, val_label_array = _build_split(val_data_name)
_dump(val_input_array, 'val_input.pickle')
_dump(val_label_array, 'val_label.pickle')

test_input_array, test_label_array = _build_split(test_data_name)
_dump(test_input_array, 'test_input.pickle')
_dump(test_label_array, 'test_label.pickle')