Diff of /src/data.py [000000] .. [a378de]

Switch to side-by-side view

--- a
+++ b/src/data.py
@@ -0,0 +1,142 @@
+"""
+The data is provided by 
+https://physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm
+
+The recordings were digitized at 360 samples per second per channel with 11-bit resolution over a 10 mV range.
+Two or more cardiologists independently annotated each record; disagreements were resolved to obtain the computer-readable
+reference annotations for each beat (approximately 110,000 annotations in all) included with the database.
+
+    Code		Description
+    N		Normal beat (displayed as . by the PhysioBank ATM, LightWAVE, pschart, and psfd)
+    L		Left bundle branch block beat
+    R		Right bundle branch block beat
+    B		Bundle branch block beat (unspecified)
+    A		Atrial premature beat
+    a		Aberrated atrial premature beat
+    J		Nodal (junctional) premature beat
+    S		Supraventricular premature or ectopic beat (atrial or nodal)
+    V		Premature ventricular contraction
+    r		R-on-T premature ventricular contraction
+    F		Fusion of ventricular and normal beat
+    e		Atrial escape beat
+    j		Nodal (junctional) escape beat
+    n		Supraventricular escape beat (atrial or nodal)
+    E		Ventricular escape beat
+    /		Paced beat
+    f		Fusion of paced and normal beat
+    Q		Unclassifiable beat
+    ?		Beat not classified during learning
+"""
+
+from __future__ import division, print_function
+import os
+from tqdm import tqdm
+import numpy as np
+import random
+import h5py
+from utils import *
+from config import get_config
+
def preprocess( split ):
    """Slice the MIT-BIH Arrhythmia Database records into labeled beat windows.

    For each WFDB record under ``dataset/``, standardizes both channels,
    detects R-peak candidates, extracts a window of ``config.input_size``
    samples centered on each peak, one-hot encodes the single beat
    annotation inside the window, appends synthetic noise windows
    (class ``'~'``), and writes the per-lead data/label dictionaries to
    HDF5 files.

    Args:
        split: if truthy, write separate train/test files using a fixed
            record split; otherwise write one target/label file pair
            covering all records.

    NOTE(review): reads the module-level ``config`` created in the
    ``__main__`` guard — this function must run after ``get_config()``.
    """
    # Third-party readers imported lazily so the module itself imports
    # without wfdb/sklearn/scipy installed; hoisted here (was inside the
    # per-record loop) so the lookups happen once.
    from wfdb import rdrecord, rdann
    from sklearn import preprocessing
    from scipy.signal import find_peaks

    # All 48 MIT-BIH record names.
    nums = ['100','101','102','103','104','105','106','107','108','109','111','112','113','114','115','116','117','118','119','121','122','123','124','200','201','202','203','205','207','208','209','210','212','213','214','215','217','219','220','221','222','223','228','230','231','232','233','234']
    # Lead names occurring across the records; each gets its own dataset key.
    features = ['MLII', 'V1', 'V2', 'V4', 'V5']

    if split :
        # Fixed held-out records for the test split.
        testset = ['101', '105','114','118', '124', '201', '210' , '217']
        trainset = [x for x in nums if x not in testset]

    def dataSaver(dataSet, datasetname, labelsname):
        """Process the given records and write one HDF5 file for data, one for labels."""
        # Beat classes kept for training; '~' is reserved for the noise windows
        # appended below. (Excluded: 'L','R','f','j','E','a','J','Q','e','S')
        classes = ['N','V','/','A','F','~']
        Nclass = len(classes)
        datadict, datalabel = dict(), dict()

        for feature in features:
            datadict[feature] = list()
            datalabel[feature] = list()

        def dataprocess():
            input_size = config.input_size
            for num in tqdm(dataSet):
                record = rdrecord('dataset/'+ num, smooth_frames= True)
                # Standardize each channel to zero mean / unit variance;
                # NaNs are zeroed first so scaling cannot propagate them.
                signals0 = preprocessing.scale(np.nan_to_num(record.p_signal[:,0])).tolist()
                signals1 = preprocessing.scale(np.nan_to_num(record.p_signal[:,1])).tolist()
                # R-peak candidates: 150-sample minimum spacing
                # (~0.42 s at the database's 360 Hz sampling rate).
                peaks, _ = find_peaks(signals0, distance=150)

                feature0, feature1 = record.sig_name[0], record.sig_name[1]

                # Bind the append methods once per record (hot-loop binding;
                # the original declared these `global` with a typo —
                # `lppened0` — which this version removes: the names are only
                # used in this enclosing scope).
                lappend0 = datalabel[feature0].append
                lappend1 = datalabel[feature1].append
                dappend0 = datadict[feature0].append
                dappend1 = datadict[feature1].append

                # Skip the first and last peak so the window never runs off
                # either edge of the signal.
                for peak in tqdm(peaks[1:-1]):
                    start, end = peak-input_size//2, peak+input_size//2
                    ann = rdann('dataset/'+ num, extension='atr', sampfrom = start, sampto = end, return_label_elements=['symbol'])

                    def to_dict(chosenSym):
                        # One-hot label plus the raw window for both leads.
                        y = [0]*Nclass
                        y[classes.index(chosenSym)] = 1
                        lappend0(y)
                        lappend1(y)
                        dappend0(signals0[start:end])
                        dappend1(signals1[start:end])

                    annSymbol = ann.symbol
                    # Keep only windows with exactly one known annotation;
                    # down-sample 'N' (keep ~15%) to rebalance the dataset.
                    if len(annSymbol) == 1 and (annSymbol[0] in classes) and (annSymbol[0] != "N" or np.random.random()<0.15):
                        to_dict(annSymbol[0])

        print("processing data...")
        dataprocess()
        noises = add_noise(config)
        for feature in ["MLII", "V1"]:
            d = np.array(datadict[feature])
            # Heuristic split detection: >15k beats means this is the
            # training split, so it receives the training noise pool.
            if len(d) > 15*10**3:
                n = np.array(noises["trainset"])
            else:
                n = np.array(noises["testset"])
            datadict[feature] = np.concatenate((d, n))
            size, _ = n.shape
            l = np.array(datalabel[feature])
            # Noise windows are labeled with the last class ('~').
            noise_label = [0]*Nclass
            noise_label[-1] = 1
            noise_label = np.array([noise_label] * size)
            datalabel[feature] = np.concatenate((l, noise_label))

        # Persist: one HDF5 dataset per lead name in each file.
        with h5py.File(datasetname, 'w') as f:
            for key, data in datadict.items():
                f.create_dataset(key, data=data)
        with h5py.File(labelsname, 'w') as f:
            for key, data in datalabel.items():
                f.create_dataset(key, data=data)

    if split:
        dataSaver(trainset, 'dataset/train.keras', 'dataset/trainlabel.keras')
        dataSaver(testset, 'dataset/test.keras', 'dataset/testlabel.keras')
    else:
        dataSaver(nums, 'dataset/targetdata.keras', 'dataset/labeldata.keras')
+
def main(config):
    """Entry point: optionally download the MIT-BIH records, then preprocess.

    Args:
        config: parsed options object; reads ``config.downloading`` (fetch the
            raw records first) and ``config.split`` (train/test vs single file).

    Returns:
        Whatever ``preprocess`` returns (currently ``None``).
    """
    def Downloadmitdb():
        # Fetch the .dat/.hea/.atr triplet for all 48 records into ./dataset.
        ext = ['dat', 'hea', 'atr']
        nums = ['100','101','102','103','104','105','106','107','108','109','111','112','113','114','115','116','117','118','119','121','122','123','124','200','201','202','203','205','207','208','209','210','212','213','214','215','217','219','220','221','222','223','228','230','231','232','233','234']
        # Create the target directory once, not once per downloaded file.
        mkdir_recursive('dataset')
        base = "https://physionet.org/physiobank/database/mitdb/"
        for num in tqdm(nums):
            for e in ext:
                url = base + num + "." + e
                # NOTE(review): os.system with a shell string. All parts are
                # fixed literals here, but subprocess.run([...], shell=False)
                # would be safer if the record list ever becomes user input.
                cmd = "cd dataset && curl -O " + url
                os.system(cmd)

    if config.downloading:
        Downloadmitdb()
    return preprocess(config.split)
+
if __name__=="__main__":
    # Parse command-line options (see config.get_config) and run the
    # download/preprocess pipeline; `config` is also read as a module
    # global by preprocess().
    config = get_config()
    main(config)