Diff of /src/data.py [000000] .. [a378de]

b/src/data.py
"""
The data is provided by
https://physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm

The recordings were digitized at 360 samples per second per channel with 11-bit resolution over a 10 mV range.
Two or more cardiologists independently annotated each record; disagreements were resolved to obtain the computer-readable
reference annotations for each beat (approximately 110,000 annotations in all) included with the database.

    Code        Description
    N       Normal beat (displayed as . by the PhysioBank ATM, LightWAVE, pschart, and psfd)
    L       Left bundle branch block beat
    R       Right bundle branch block beat
    B       Bundle branch block beat (unspecified)
    A       Atrial premature beat
    a       Aberrated atrial premature beat
    J       Nodal (junctional) premature beat
    S       Supraventricular premature or ectopic beat (atrial or nodal)
    V       Premature ventricular contraction
    r       R-on-T premature ventricular contraction
    F       Fusion of ventricular and normal beat
    e       Atrial escape beat
    j       Nodal (junctional) escape beat
    n       Supraventricular escape beat (atrial or nodal)
    E       Ventricular escape beat
    /       Paced beat
    f       Fusion of paced and normal beat
    Q       Unclassifiable beat
    ?       Beat not classified during learning
"""

from __future__ import division, print_function

import os
import random

import h5py
import numpy as np
from scipy.signal import find_peaks
from sklearn import preprocessing
from tqdm import tqdm
from wfdb import rdann, rdrecord

from utils import *
from config import get_config

def preprocess(split):
    nums = ['100','101','102','103','104','105','106','107','108','109','111','112','113','114','115','116','117','118','119','121','122','123','124','200','201','202','203','205','207','208','209','210','212','213','214','215','217','219','220','221','222','223','228','230','231','232','233','234']
    features = ['MLII', 'V1', 'V2', 'V4', 'V5']

    if split:
        testset = ['101', '105', '114', '118', '124', '201', '210', '217']
        trainset = [x for x in nums if x not in testset]

    def dataSaver(dataSet, datasetname, labelsname):
        classes = ['N','V','/','A','F','~']  # also available: 'L','R','f','j','E','a','J','Q','e','S'
        Nclass = len(classes)
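        # For reference: classes.index('V') == 1, so a premature ventricular beat is
        # labelled [0, 1, 0, 0, 0, 0] by to_dict() below, and '~' (noise) maps to the
        # last position, which is also used for the augmented noise windows further down.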
        datadict, datalabel = dict(), dict()

        for feature in features:
            datadict[feature] = list()
            datalabel[feature] = list()

        def dataprocess():
          # NOTE: relies on the module-level `config` created in the __main__ block
          input_size = config.input_size
          for num in tqdm(dataSet):
            record = rdrecord('dataset/' + num, smooth_frames=True)
            # z-score normalise each channel, replacing NaNs with zeros first
            signals0 = preprocessing.scale(np.nan_to_num(record.p_signal[:,0])).tolist()
            signals1 = preprocessing.scale(np.nan_to_num(record.p_signal[:,1])).tolist()
            # rough R-peak detection on the first channel
            peaks, _ = find_peaks(signals0, distance=150)

            feature0, feature1 = record.sig_name[0], record.sig_name[1]

            lappend0 = datalabel[feature0].append
            lappend1 = datalabel[feature1].append
            dappend0 = datadict[feature0].append
            dappend1 = datadict[feature1].append

            # skip the first and last peaks so every window fits inside the record
            for peak in tqdm(peaks[1:-1]):
              start, end = peak - input_size//2, peak + input_size//2
              ann = rdann('dataset/' + num, extension='atr', sampfrom=start, sampto=end, return_label_elements=['symbol'])

              def to_dict(chosenSym):
                # one-hot label for the chosen beat class, plus the matching signal window
                y = [0]*Nclass
                y[classes.index(chosenSym)] = 1
                lappend0(y)
                lappend1(y)
                dappend0(signals0[start:end])
                dappend1(signals1[start:end])

              annSymbol = ann.symbol
              # keep windows that contain exactly one known beat; subsample "N" beats
              # (keep roughly 15%) so they do not dominate the dataset
              if len(annSymbol) == 1 and (annSymbol[0] in classes) and (annSymbol[0] != "N" or np.random.random() < 0.15):
                to_dict(annSymbol[0])
        print("processing data...")
93
        dataprocess()
94
        noises = add_noise(config)
95
        for feature in ["MLII", "V1"]: 
96
            d = np.array(datadict[feature])
97
            if len(d) > 15*10**3:
98
                n = np.array(noises["trainset"])
99
            else:
100
                n = np.array(noises["testset"]) 
101
            datadict[feature]=np.concatenate((d,n))
102
            size, _  = n.shape 
103
            l = np.array(datalabel[feature])
104
            noise_label = [0]*Nclass
105
            noise_label[-1] = 1
106
            
107
            noise_label = np.array([noise_label] * size) 
108
            datalabel[feature] = np.concatenate((l, noise_label))
109
110
        with h5py.File(datasetname, 'w') as f:
111
            for key, data in datadict.items():
112
                f.create_dataset(key, data=data)
113
        with h5py.File(labelsname, 'w') as f:
114
            for key, data in datalabel.items():
115
                f.create_dataset(key, data=data)        
116
117
    if split:
118
        dataSaver(trainset, 'dataset/train.keras', 'dataset/trainlabel.keras')
119
        dataSaver(testset, 'dataset/test.keras', 'dataset/testlabel.keras')
120
    else:
121
        dataSaver(nums, 'dataset/targetdata.keras', 'dataset/labeldata.keras')
122
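
# Sketch of how the files written above can be read back (assumes the split files
# exist; the HDF5 dataset names are the lead names, e.g. 'MLII' and 'V1'):
#
#   import h5py
#   with h5py.File('dataset/train.keras', 'r') as f, h5py.File('dataset/trainlabel.keras', 'r') as g:
#       X = f['MLII'][:]      # windows of length config.input_size
#       y = g['MLII'][:]      # one-hot labels over ['N','V','/','A','F','~']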

def main(config):
    def Downloadmitdb():
        ext = ['dat', 'hea', 'atr']
        nums = ['100','101','102','103','104','105','106','107','108','109','111','112','113','114','115','116','117','118','119','121','122','123','124','200','201','202','203','205','207','208','209','210','212','213','214','215','217','219','220','221','222','223','228','230','231','232','233','234']
        for num in tqdm(nums):
            for e in ext:
                url = "https://physionet.org/physiobank/database/mitdb/"
                url = url + num + "." + e
                mkdir_recursive('dataset')
                cmd = "cd dataset && curl -O " + url
                os.system(cmd)

    if config.downloading:
        Downloadmitdb()
    return preprocess(config.split)

if __name__ == "__main__":
    config = get_config()
    main(config)
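
# Usage sketch (assumes the script is run from the repository root; the exact CLI
# flags are defined in config.get_config and may be named differently there):
#
#   python src/data.py
#
# With config.downloading set, main() fetches the MIT-BIH records into dataset/
# via curl; preprocess(config.split) then writes either the train/test HDF5 files
# or a single targetdata/labeldata pair.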