src/data.py

"""
The data is provided by
https://physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm

The recordings were digitized at 360 samples per second per channel with
11-bit resolution over a 10 mV range. Two or more cardiologists independently
annotated each record; disagreements were resolved to obtain the
computer-readable reference annotations for each beat (approximately 110,000
annotations in all) included with the database.

Code    Description
N       Normal beat (displayed as . by the PhysioBank ATM, LightWAVE, pschart, and psfd)
L       Left bundle branch block beat
R       Right bundle branch block beat
B       Bundle branch block beat (unspecified)
A       Atrial premature beat
a       Aberrated atrial premature beat
J       Nodal (junctional) premature beat
S       Supraventricular premature or ectopic beat (atrial or nodal)
V       Premature ventricular contraction
r       R-on-T premature ventricular contraction
F       Fusion of ventricular and normal beat
e       Atrial escape beat
j       Nodal (junctional) escape beat
n       Supraventricular escape beat (atrial or nodal)
E       Ventricular escape beat
/       Paced beat
f       Fusion of paced and normal beat
Q       Unclassifiable beat
?       Beat not classified during learning
"""

from __future__ import division, print_function
import os
import random

import h5py
import numpy as np
from scipy.signal import find_peaks
from sklearn import preprocessing
from tqdm import tqdm
from wfdb import rdann, rdrecord

from utils import *
from config import get_config
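
# A minimal sketch, not used by the pipeline below: inspect the annotation
# codes documented above for a single record with wfdb. The record name '100'
# and the 'dataset/' directory are assumptions mirroring the download step
# in main().
def peek_annotations(num='100'):
    from collections import Counter
    ann = rdann('dataset/' + num, extension='atr')
    return Counter(ann.symbol)  # e.g. how many 'N', 'V', '/', ... beats occur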


def preprocess(split):
    nums = ['100', '101', '102', '103', '104', '105', '106', '107', '108',
            '109', '111', '112', '113', '114', '115', '116', '117', '118',
            '119', '121', '122', '123', '124', '200', '201', '202', '203',
            '205', '207', '208', '209', '210', '212', '213', '214', '215',
            '217', '219', '220', '221', '222', '223', '228', '230', '231',
            '232', '233', '234']
    features = ['MLII', 'V1', 'V2', 'V4', 'V5']

    if split:
        testset = ['101', '105', '114', '118', '124', '201', '210', '217']
        trainset = [x for x in nums if x not in testset]

    def dataSaver(dataSet, datasetname, labelsname):
        classes = ['N', 'V', '/', 'A', 'F', '~']
        # Codes left out for now: 'L','R','f','j','E','a','J','Q','e','S'
        Nclass = len(classes)
        datadict, datalabel = dict(), dict()

        for feature in features:
            datadict[feature] = list()
            datalabel[feature] = list()
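
        # Labels are one-hot over the six classes above, e.g. 'V' maps to
        # [0, 1, 0, 0, 0, 0]; the final '~' slot is also used below to tag
        # the noise windows returned by add_noise().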

        def dataprocess():
            input_size = config.input_size  # relies on the module-level config set in __main__
            for num in tqdm(dataSet):
                record = rdrecord('dataset/' + num, smooth_frames=True)
                # Scale each channel to zero mean and unit variance.
                signals0 = preprocessing.scale(np.nan_to_num(record.p_signal[:, 0])).tolist()
                signals1 = preprocessing.scale(np.nan_to_num(record.p_signal[:, 1])).tolist()
                peaks, _ = find_peaks(signals0, distance=150)

                feature0, feature1 = record.sig_name[0], record.sig_name[1]

                lappend0 = datalabel[feature0].append
                lappend1 = datalabel[feature1].append
                dappend0 = datadict[feature0].append
                dappend1 = datadict[feature1].append
                # Skip the first and last peaks so every window fits inside the record.
                for peak in tqdm(peaks[1:-1]):
                    start, end = peak - input_size // 2, peak + input_size // 2
                    ann = rdann('dataset/' + num, extension='atr',
                                sampfrom=start, sampto=end,
                                return_label_elements=['symbol'])

                    def to_dict(chosenSym):
                        y = [0] * Nclass
                        y[classes.index(chosenSym)] = 1
                        lappend0(y)
                        lappend1(y)
                        dappend0(signals0[start:end])
                        dappend1(signals1[start:end])

                    annSymbol = ann.symbol
                    # Drop most "N" beats, which would otherwise dominate the dataset.
                    if len(annSymbol) == 1 and (annSymbol[0] in classes) and (annSymbol[0] != "N" or np.random.random() < 0.15):
                        to_dict(annSymbol[0])

        print("processing data...")
        dataprocess()
        noises = add_noise(config)
        for feature in ["MLII", "V1"]:
            d = np.array(datadict[feature])
            # Heuristic: a large set gets the training noise, a small one the test noise.
            if len(d) > 15 * 10**3:
                n = np.array(noises["trainset"])
            else:
                n = np.array(noises["testset"])
            datadict[feature] = np.concatenate((d, n))
            size, _ = n.shape
            l = np.array(datalabel[feature])
            noise_label = [0] * Nclass
            noise_label[-1] = 1  # the last class, '~', marks noise

            noise_label = np.array([noise_label] * size)
            datalabel[feature] = np.concatenate((l, noise_label))

        with h5py.File(datasetname, 'w') as f:
            for key, data in datadict.items():
                f.create_dataset(key, data=data)
        with h5py.File(labelsname, 'w') as f:
            for key, data in datalabel.items():
                f.create_dataset(key, data=data)

    if split:
        dataSaver(trainset, 'dataset/train.keras', 'dataset/trainlabel.keras')
        dataSaver(testset, 'dataset/test.keras', 'dataset/testlabel.keras')
    else:
        dataSaver(nums, 'dataset/targetdata.keras', 'dataset/labeldata.keras')
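

# A minimal sketch, assuming the HDF5 files written by dataSaver() above; it
# is not called anywhere in this module. Keys are the lead names used above
# ('MLII', 'V1', ...).
def load_dataset(datasetname='dataset/train.keras', labelsname='dataset/trainlabel.keras'):
    with h5py.File(datasetname, 'r') as f:
        data = {key: f[key][:] for key in f.keys()}
    with h5py.File(labelsname, 'r') as f:
        labels = {key: f[key][:] for key in f.keys()}
    return data, labels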


def main(config):
    def Downloadmitdb():
        ext = ['dat', 'hea', 'atr']
        nums = ['100', '101', '102', '103', '104', '105', '106', '107', '108',
                '109', '111', '112', '113', '114', '115', '116', '117', '118',
                '119', '121', '122', '123', '124', '200', '201', '202', '203',
                '205', '207', '208', '209', '210', '212', '213', '214', '215',
                '217', '219', '220', '221', '222', '223', '228', '230', '231',
                '232', '233', '234']
        mkdir_recursive('dataset')
        for num in tqdm(nums):
            for e in ext:
                url = "https://physionet.org/physiobank/database/mitdb/" + num + "." + e
                cmd = "cd dataset && curl -O " + url
                os.system(cmd)

    if config.downloading:
        Downloadmitdb()
    return preprocess(config.split)


if __name__ == "__main__":
    config = get_config()
    main(config)
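
# Usage sketch: running this module downloads the MIT-BIH records (when
# config.downloading is set) and writes the HDF5 datasets under dataset/.
# The exact CLI flags depend on get_config(); the spelling below is an
# assumption.
#
#   python src/data.py --downloading True --split True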