--- a
+++ b/utils.py
@@ -0,0 +1,142 @@
+import wfdb
+import pywt
+import seaborn
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import confusion_matrix
+from sklearn.model_selection import train_test_split
+
+
+def denoise(data):
+    """Denoise a 1-D signal by wavelet thresholding (db5 wavelet, 9 levels).
+
+    The two finest detail bands (cD1, cD2) are zeroed and the remaining detail
+    bands (cD9..cD3) are soft-thresholded; the approximation band cA9 is left
+    untouched. Returns the signal rebuilt by the inverse wavelet transform.
+    """
+    bands = pywt.wavedec(data=data, wavelet='db5', level=9)  # [cA9, cD9, ..., cD1]
+    finest = bands[-1]
+    # threshold = sigma * sqrt(2 ln n), sigma = median(|cD1|) / 0.6745; taken before cD1 is zeroed
+    threshold = (np.median(np.abs(finest)) / 0.6745) * (np.sqrt(2 * np.log(len(finest))))
+    bands[-1].fill(0)
+    bands[-2].fill(0)
+    for level in range(1, len(bands) - 2):
+        bands[level] = pywt.threshold(bands[level], threshold)
+    return pywt.waverec(coeffs=bands, wavelet='db5')
+
+
+def get_data_set(number, X_data, Y_data):
+    """Load one ECG record, denoise it, and append its labelled beats in place.
+
+    number: record name (string) read from the 'ecg_data/' directory.
+    X_data: list that receives one 300-sample window per kept beat, cut from
+        the denoised MLII signal (99 samples before the annotated position to
+        200 after it).
+    Y_data: list that receives the matching class index, N/A/V/L/R -> 0..4.
+
+    Annotations with any other symbol are discarded. Windows that would run
+    past either end of the signal are discarded too, so every appended window
+    has exactly 300 samples and X_data stays aligned with Y_data.
+    """
+    ecgClassSet = ['N', 'A', 'V', 'L', 'R']
+
+    # load the ecg data record
+    print("loading the ecg data of No." + number)
+    record = wfdb.rdrecord('ecg_data/' + number, channel_names=['MLII'])
+    rdata = denoise(data=record.p_signal.flatten())
+
+    # get the positions of R-wave and the corresponding labels
+    annotation = wfdb.rdann('ecg_data/' + number, 'atr')
+    Rlocation = annotation.sample
+    Rclass = annotation.symbol
+
+    # remove the unstable data at the beginning (10 beats) and the end (5 beats)
+    for i in range(10, len(Rclass) - 5):
+        if Rclass[i] not in ecgClassSet:
+            continue
+        lo, hi = Rlocation[i] - 99, Rlocation[i] + 201
+        if lo < 0 or hi > len(rdata):
+            continue  # a clipped slice would be shorter than 300 samples
+        X_data.append(rdata[lo:hi])
+        Y_data.append(ecgClassSet.index(Rclass[i]))
+
+
+def load_data(ratio, random_seed):
+    """Collect the beats of every listed record and split them into train/test.
+
+    ratio is forwarded to train_test_split as test_size and random_seed as
+    random_state. Returns the tuple (X_train, X_test, y_train, y_test).
+    """
+    record_ids = ['100', '101', '103', '105', '106', '107', '108', '109', '111', '112', '113', '114', '115',
+                  '116', '117', '119', '121', '122', '123', '124', '200', '201', '202', '203', '205', '208',
+                  '210', '212', '213', '214', '215', '217', '219', '220', '221', '222', '223', '228', '230',
+                  '231', '232', '233', '234']
+    beats, classes = [], []
+    for record_id in record_ids:
+        get_data_set(record_id, beats, classes)
+    samples = np.array(beats).reshape(-1, 300)  # one 300-sample row per beat
+    labels = np.array(classes).reshape(-1)
+    return tuple(train_test_split(samples, labels, test_size=ratio, random_state=random_seed))
+
+
+def plot_heat_map(y_test, y_pred):
+    """Plot the confusion matrix of y_test against y_pred as an annotated heat map.
+
+    Cells show raw counts. The figure is saved to 'confusion_matrix.png' in
+    the working directory and then displayed.
+    """
+    con_mat = confusion_matrix(y_test, y_pred)
+    plt.figure(figsize=(8, 8))
+    seaborn.heatmap(con_mat, annot=True, fmt='.20g', cmap='Blues')
+    # span exactly the rows of the matrix instead of assuming five classes
+    plt.ylim(0, con_mat.shape[0])
+    plt.xlabel('Predicted labels')
+    plt.ylabel('True labels')
+    plt.title('Confusion Matrix')
+    plt.savefig('confusion_matrix.png')
+    plt.show()
+
+
+def plot_history_tf(history):
+    """Plot accuracy and loss curves from history.history, one figure per metric.
+
+    Uses the 'accuracy'/'val_accuracy' and 'loss'/'val_loss' series. Each 8x8
+    figure is saved ('accuracy.png', then 'loss.png') and then shown.
+    """
+    panels = (
+        ('accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy', 'accuracy.png'),
+        ('loss', 'val_loss', 'Model Loss', 'Loss', 'loss.png'),
+    )
+    for train_key, val_key, title, y_label, out_file in panels:
+        plt.figure(figsize=(8, 8))
+        plt.plot(history.history[train_key])
+        plt.plot(history.history[val_key])
+        plt.title(title)
+        plt.ylabel(y_label)
+        plt.xlabel('Epoch')
+        plt.legend(['Train', 'Test'], loc='upper left')
+        plt.savefig(out_file)
+        plt.show()
+
+
+def plot_history_torch(history):
+    """Plot accuracy and loss curves from a history mapping, one figure per metric.
+
+    Uses the 'train_acc'/'test_acc' and 'train_loss'/'test_loss' series. Each
+    8x8 figure is saved ('accuracy.png', then 'loss.png') and then shown.
+    """
+    panels = (
+        ('train_acc', 'test_acc', 'Model Accuracy', 'Accuracy', 'accuracy.png'),
+        ('train_loss', 'test_loss', 'Model Loss', 'Loss', 'loss.png'),
+    )
+    for train_key, test_key, title, y_label, out_file in panels:
+        plt.figure(figsize=(8, 8))
+        plt.plot(history[train_key])
+        plt.plot(history[test_key])
+        plt.title(title)
+        plt.ylabel(y_label)
+        plt.xlabel('Epoch')
+        plt.legend(['Train', 'Test'], loc='upper left')
+        plt.savefig(out_file)
+        plt.show()
+