--- a
+++ b/pre_proc.py
@@ -0,0 +1,233 @@
+import wfdb
+import matplotlib.pyplot as plt
+import numpy as np
+from hrv.filters import quotient, moving_median
+from scipy import interpolate
+from tqdm import tqdm
+import pickle
+import os
+FS = 100.0
+
+# From https://github.com/rhenanbartels/hrv/blob/develop/hrv/classical.py
+def create_time_info(rri):
+    elapsed = np.cumsum(rri) / 1000.0  # running total of the intervals, ms -> s
+    return elapsed - elapsed[0]  # re-reference so the first entry sits at t = 0
+
+def create_interp_time(rri, fs):
+    last_time = create_time_info(rri)[-1]  # time of the final interval (s)
+    return np.arange(0, last_time, 1 / float(fs))  # uniform grid, step 1/fs, end excluded
+
+def interp_cubic_spline(rri, fs):
+    """Spline-interpolate `rri` onto a uniform time grid sampled at `fs` Hz."""
+    beat_times = create_time_info(rri)
+    grid = create_interp_time(rri, fs)
+    spline = interpolate.splrep(beat_times, rri, s=0)  # s=0: pass through every point
+    return grid, interpolate.splev(grid, spline, der=0)
+
+def interp_cubic_spline_qrs(qrs_index, qrs_amp, fs):
+    """Spline-interpolate `qrs_amp` onto a uniform time grid sampled at `fs` Hz."""
+    # Sample indices -> seconds uses the module-level rate FS, not the `fs` argument.
+    beat_times = qrs_index / float(FS)
+    beat_times = beat_times - beat_times[0]  # first beat at t = 0
+    grid = np.arange(0, beat_times[-1], 1/float(fs))
+    return grid, interpolate.splev(grid, interpolate.splrep(beat_times, qrs_amp, s=0), der=0)
+
+data_path = './data/'  # directory the wfdb records and annotations are read from
+train_data_name = ['a02', 'a03', 'a04', 'a05',  # record names of the training split
+             'a06', 'a07', 'a08', 'a09', 'a10',
+             'a11', 'a12', 'a13', 'a14', 'a15',
+             'a16', 'a17', 'a18', 'a19',
+             'b02', 'b03', 'b04',
+             'c02', 'c03', 'c04', 'c05',
+             'c06', 'c07', 'c08', 'c09',
+             ]
+val_data_name = ['a01', 'b01', 'c01']  # record names of the validation split
+test_data_name = ['a20','b05','c10']  # record names of the test split
+age = [51, 38, 54, 52, 58,  # per-record age; 35 values vs. 29/3/3 names in the splits
+       63, 44, 51, 52, 58,  # NOTE(review): presumably ordered a01..a20, b01..b05,
+       58, 52, 51, 51, 60,  # c01..c10 - confirm against the dataset's metadata
+       44, 40, 52, 55, 58,
+       44, 53, 53, 42, 52,
+       31, 37, 39, 41, 28,
+       28, 30, 42, 37, 27]
+sex = [1, 1, 1, 1, 1,  # per-record flag, parallel to `age`
+       1, 1, 1, 1, 1,  # NOTE(review): presumably 1 = male, 0 = female - confirm
+       1, 1, 1, 1, 1,
+       1, 1, 1, 1, 1,
+       0, 1, 1, 1, 1,
+       1, 1, 1, 0, 0,
+       0, 0, 1, 1, 1]
+
+
+def get_qrs_amp(ecg, qrs):
+    """Return, for each index in `qrs`, the max of `ecg` in a +/-250 ms window.
+
+    The window half-width is int(FS * 0.250) samples; the result is a list
+    with one maximum per entry of `qrs`.
+    """
+    half_win = int(FS * 0.250)
+    # Clamp the start at 0: a negative start would be read as "from the end".
+    return [np.max(ecg[max(p - half_win, 0):p + half_win]) for p in qrs]
+
+MARGIN = 10  # seconds of extra signal read on each side of a 60 s window
+FS_INTP = 4  # rate (Hz) of the uniform grid the series are interpolated onto
+MAX_HR = 300.0  # fastest heart rate accepted (beats per minute)
+MIN_HR = 20.0  # slowest heart rate accepted (beats per minute)
+MIN_RRI = 1.0 / (MAX_HR / 60.0) * 1000  # shortest accepted RR interval (ms)
+MAX_RRI = 1.0 / (MIN_HR / 60.0) * 1000  # longest accepted RR interval (ms)
+# Build the training set: one sample per usable 60 s window of each record.
+train_input_array = []
+train_label_array = []
+# age/sex hold 35 entries (one per record across all three splits), so they are
+# looked up by record name instead of by position inside this split's list.
+# NOTE(review): assumes age/sex are listed in sorted record-name order - confirm.
+record_order = sorted(train_data_name + val_data_name + test_data_name)
+win_len = int(60 * FS)  # samples per 60 s window
+margin_len = int(MARGIN * FS)  # extra samples read on each side of a window
+for record_name in train_data_name:
+    print (record_name)
+    record_path = os.path.join(data_path, record_name)
+    meta_index = record_order.index(record_name)
+    win_num = len(wfdb.rdann(record_path, 'apn').symbol)
+    signals, fields = wfdb.rdsamp(record_path)
+    for index in tqdm(range(1, win_num)):
+        samp_from = index * win_len  # integer sample bounds
+        samp_to = samp_from + win_len
+        qrs_ann = wfdb.rdann(record_path, 'qrs', sampfrom=samp_from - margin_len,
+                             sampto=samp_to + margin_len).sample
+        apn_ann = wfdb.rdann(record_path, 'apn', sampfrom=samp_from, sampto=samp_to-1).symbol
+        qrs_amp = get_qrs_amp(signals, qrs_ann)
+        rri_ms = np.diff(qrs_ann).astype('float') / FS * 1000.0  # samples -> ms
+        try:
+            rri_filt = moving_median(rri_ms)
+            if not (len(rri_filt) > 5 and np.min(rri_filt) >= MIN_RRI and np.max(rri_filt) <= MAX_RRI):
+                continue  # too few intervals, or some outside [MIN_RRI, MAX_RRI]
+            time_intp, rri_intp = interp_cubic_spline(rri_filt, FS_INTP)
+            qrs_time_intp, qrs_intp = interp_cubic_spline_qrs(qrs_ann, qrs_amp, FS_INTP)
+            # Drop the margins: keep interpolated times in [MARGIN, 60 + MARGIN).
+            rri_intp = rri_intp[(time_intp >= MARGIN) & (time_intp < (60+MARGIN))]
+            qrs_intp = qrs_intp[(qrs_time_intp >= MARGIN) & (qrs_time_intp < (60 + MARGIN))]
+            if len(rri_intp) != (FS_INTP * 60):
+                continue  # interpolated series does not span the full 60 s
+            rri_intp = rri_intp - np.mean(rri_intp)
+            qrs_intp = qrs_intp - np.mean(qrs_intp)
+            if apn_ann[0] == 'N': # Normal
+                label = 0.0
+            elif apn_ann[0] == 'A': # Apnea
+                label = 1.0
+            else:
+                label = 2.0
+            train_input_array.append([rri_intp, qrs_intp, age[meta_index], sex[meta_index]])
+            train_label_array.append(label)
+        except Exception:
+            # Best effort: skip a failing window (Exception only, so Ctrl-C still works).
+            continue
+with open('train_input.pickle','wb') as f:
+    pickle.dump(train_input_array, f)
+with open('train_label.pickle','wb') as f:
+    pickle.dump(train_label_array, f)
+
+
+# Build the validation set: one sample per usable 60 s window of each record.
+val_input_array = []
+val_label_array = []
+# age/sex hold 35 entries (one per record across all three splits), so they are
+# looked up by record name instead of by position inside this split's list.
+# NOTE(review): assumes age/sex are listed in sorted record-name order - confirm.
+record_order = sorted(train_data_name + val_data_name + test_data_name)
+win_len = int(60 * FS)  # samples per 60 s window
+margin_len = int(MARGIN * FS)  # extra samples read on each side of a window
+for record_name in val_data_name:
+    print (record_name)
+    record_path = os.path.join(data_path, record_name)
+    meta_index = record_order.index(record_name)
+    win_num = len(wfdb.rdann(record_path, 'apn').symbol)
+    signals, fields = wfdb.rdsamp(record_path)
+    for index in tqdm(range(1, win_num)):
+        samp_from = index * win_len  # integer sample bounds
+        samp_to = samp_from + win_len
+        qrs_ann = wfdb.rdann(record_path, 'qrs', sampfrom=samp_from - margin_len,
+                             sampto=samp_to + margin_len).sample
+        apn_ann = wfdb.rdann(record_path, 'apn', sampfrom=samp_from, sampto=samp_to-1).symbol
+        qrs_amp = get_qrs_amp(signals, qrs_ann)
+        rri_ms = np.diff(qrs_ann).astype('float') / FS * 1000.0  # samples -> ms
+        try:
+            rri_filt = moving_median(rri_ms)
+            if not (len(rri_filt) > 5 and np.min(rri_filt) >= MIN_RRI and np.max(rri_filt) <= MAX_RRI):
+                continue  # too few intervals, or some outside [MIN_RRI, MAX_RRI]
+            time_intp, rri_intp = interp_cubic_spline(rri_filt, FS_INTP)
+            qrs_time_intp, qrs_intp = interp_cubic_spline_qrs(qrs_ann, qrs_amp, FS_INTP)
+            # Drop the margins: keep interpolated times in [MARGIN, 60 + MARGIN).
+            rri_intp = rri_intp[(time_intp >= MARGIN) & (time_intp < (60+MARGIN))]
+            qrs_intp = qrs_intp[(qrs_time_intp >= MARGIN) & (qrs_time_intp < (60 + MARGIN))]
+            if len(rri_intp) != (FS_INTP * 60):
+                continue  # interpolated series does not span the full 60 s
+            rri_intp = rri_intp - np.mean(rri_intp)
+            qrs_intp = qrs_intp - np.mean(qrs_intp)
+            if apn_ann[0] == 'N': # Normal
+                label = 0.0
+            elif apn_ann[0] == 'A': # Apnea
+                label = 1.0
+            else:
+                label = 2.0
+            val_input_array.append([rri_intp, qrs_intp, age[meta_index], sex[meta_index]])
+            val_label_array.append(label)
+        except Exception:
+            # Best effort: skip a failing window (Exception only, so Ctrl-C still works).
+            continue
+with open('val_input.pickle','wb') as f:
+    pickle.dump(val_input_array, f)
+with open('val_label.pickle','wb') as f:
+    pickle.dump(val_label_array, f)
+
+# Build the test set: one sample per usable 60 s window of each record.
+test_input_array = []
+test_label_array = []
+# age/sex hold 35 entries (one per record across all three splits), so they are
+# looked up by record name instead of by position inside this split's list.
+# NOTE(review): assumes age/sex are listed in sorted record-name order - confirm.
+record_order = sorted(train_data_name + val_data_name + test_data_name)
+win_len = int(60 * FS)  # samples per 60 s window
+margin_len = int(MARGIN * FS)  # extra samples read on each side of a window
+for record_name in test_data_name:
+    print (record_name)
+    record_path = os.path.join(data_path, record_name)
+    meta_index = record_order.index(record_name)
+    win_num = len(wfdb.rdann(record_path, 'apn').symbol)
+    signals, fields = wfdb.rdsamp(record_path)
+    for index in tqdm(range(1, win_num)):
+        samp_from = index * win_len  # integer sample bounds
+        samp_to = samp_from + win_len
+        qrs_ann = wfdb.rdann(record_path, 'qrs', sampfrom=samp_from - margin_len,
+                             sampto=samp_to + margin_len).sample
+        apn_ann = wfdb.rdann(record_path, 'apn', sampfrom=samp_from, sampto=samp_to-1).symbol
+        qrs_amp = get_qrs_amp(signals, qrs_ann)
+        rri_ms = np.diff(qrs_ann).astype('float') / FS * 1000.0  # samples -> ms
+        try:
+            rri_filt = moving_median(rri_ms)
+            if not (len(rri_filt) > 5 and np.min(rri_filt) >= MIN_RRI and np.max(rri_filt) <= MAX_RRI):
+                continue  # too few intervals, or some outside [MIN_RRI, MAX_RRI]
+            time_intp, rri_intp = interp_cubic_spline(rri_filt, FS_INTP)
+            qrs_time_intp, qrs_intp = interp_cubic_spline_qrs(qrs_ann, qrs_amp, FS_INTP)
+            # Drop the margins: keep interpolated times in [MARGIN, 60 + MARGIN).
+            rri_intp = rri_intp[(time_intp >= MARGIN) & (time_intp < (60+MARGIN))]
+            qrs_intp = qrs_intp[(qrs_time_intp >= MARGIN) & (qrs_time_intp < (60 + MARGIN))]
+            if len(rri_intp) != (FS_INTP * 60):
+                continue  # interpolated series does not span the full 60 s
+            rri_intp = rri_intp - np.mean(rri_intp)
+            qrs_intp = qrs_intp - np.mean(qrs_intp)
+            if apn_ann[0] == 'N': # Normal
+                label = 0.0
+            elif apn_ann[0] == 'A': # Apnea
+                label = 1.0
+            else:
+                label = 2.0
+            test_input_array.append([rri_intp, qrs_intp, age[meta_index], sex[meta_index]])
+            test_label_array.append(label)
+        except Exception:
+            # Best effort: skip a failing window (Exception only, so Ctrl-C still works).
+            continue
+with open('test_input.pickle','wb') as f:
+    pickle.dump(test_input_array, f)
+with open('test_label.pickle','wb') as f:
+    pickle.dump(test_label_array, f)
\ No newline at end of file