--- /dev/null
+++ b/tensorflow_impl/cnn_tf2.py
@@ -0,0 +1,241 @@
+import time
+import argparse
+
+import tensorflow as tf
+import numpy as np
+
+from tensorflow.keras.layers import Dense, Flatten, Conv1D, BatchNormalization, MaxPool1D, Dropout
+from tensorflow.keras.metrics import CategoricalAccuracy
+
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import precision_score, recall_score, confusion_matrix
+
+from utils import get_labels, get_datasets, check_processed_dir_existance
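+# get_labels/get_datasets/check_processed_dir_existance are local helpers from
+# utils.py (assumed to load and validate the preprocessed ECG data)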
+
+
+par = argparse.ArgumentParser(description="ECG Convolutional Neural Network "
+                                          "implementation with TensorFlow 2.0")
+
+par.add_argument("-lr", dest="learning_rate",
+                 type=float, default=0.001,
+                 help="Learning rate used by the model")
+
+par.add_argument("-e", dest="epochs",
+                 type=int, default=50,
+                 help="The number of epochs the model will train for")
+
+par.add_argument("-bs", dest="batch_size",
+                 type=int, default=32,
+                 help="The batch size of the model")
+
+par.add_argument("--display-step", dest="display_step",
+                 type=int, default=10,
+                 help="The display step")
+
+par.add_argument("--dropout", type=float, default=0.5,
+                 help="Dropout probability")
+
+par.add_argument("--restore", dest="restore_model",
+                 action="store_true", default=False,
+                 help="Restore the model previously saved")
+
+par.add_argument("--freeze", dest="freeze",
+                 action="store_true", default=False,
+                 help="Freezes the model")
+
+par.add_argument("--heart-diseases", nargs="+",
+                 dest="heart_diseases",
+                 default=["apnea-ecg", "svdb", "afdb"],
+                 choices=["apnea-ecg", "mitdb", "nsrdb", "svdb", "afdb"],
+                 help="Select the ECG diseases for the model")
+
+par.add_argument("--verbose", dest="verbose",
+                 action="store_true", default=False,
+                 help="Display information about minibatches")
+
+args = par.parse_args()
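+
+# Example invocation (values are illustrative):
+#   python cnn_tf2.py -lr 0.001 -e 50 -bs 32 --heart-diseases apnea-ecg svdb afdb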
+
+# Parameters
+learning_rate = args.learning_rate
+epochs = args.epochs
+batch_size = args.batch_size
+display_step = args.display_step
+dropout = args.dropout
+restore_model = args.restore_model
+freeze = args.freeze
+heart_diseases = args.heart_diseases
+verbose = args.verbose
+
+# Network Parameters
+n_inputs = 350
+n_classes = len(heart_diseases)
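+# each input is assumed to be a fixed-length ECG segment of n_inputs samples;
+# the model outputs one class per selected disease dataset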
+
+check_processed_dir_existance()
+
+
+class CNN:
+    def __init__(self):
+        self.datasets = get_datasets(heart_diseases, n_inputs)
+        self.label_data = get_labels(self.datasets)
+        self.callbacks = []
+
+        # Initialize callbacks
+        tensorboard_logs_path = "tensorboard_data/cnn/"
+        tb_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_logs_path,
+                                                     histogram_freq=1, write_graph=True,
+                                                     embeddings_freq=1)
+
+        # With load_weights_on_restart, the callback loads the weights from
+        # filepath back into the model (if the file exists) before training
+        cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath="saved_models/cnn/model.hdf5",
+                                                         save_best_only=True,
+                                                         save_weights_only=True,
+                                                         load_weights_on_restart=restore_model)
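+        # save_best_only monitors val_loss by default, so only the checkpoint
+        # with the lowest validation loss is kept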
+
+        self.callbacks.extend([tb_callback, cp_callback])
+
+        self.set_data()
+        self.define_model()
+
+    def set_data(self):
+        dataset_len = []
+        for dataset in self.datasets:
+            dataset_len.append(len(dataset))
+
+        # hold out 10% of the data for testing, then 10% of the remaining
+        # training data for validation
+        validation_size = 0.1
+
+        print("Hold-out split (test and validation): {:.0f}%".format(validation_size * 100))
+        print("Total samples: {}".format(sum(dataset_len)))
+        print("Heart diseases: {}".format(', '.join(heart_diseases)))
+
+        concat_dataset = np.concatenate(self.datasets)
+
+        self.split_data(concat_dataset, validation_size)
+
+        # Reshape to (batch, n_inputs, 1): Conv1D expects 3-D
+        # (batch, steps, channels) input
+        self.X_train = tf.reshape(self.X_train, shape=[-1, n_inputs, 1])
+        self.X_test = tf.reshape(self.X_test, shape=[-1, n_inputs, 1])
+        self.X_val = tf.reshape(self.X_val, shape=[-1, n_inputs, 1])
+
+        if verbose:
+            print("X_train shape: {}".format(self.X_train.shape))
+            print("Y_train shape: {}".format(self.Y_train.shape))
+            print("X_test shape: {}".format(self.X_test.shape))
+            print("Y_test shape: {}".format(self.Y_test.shape))
+            print("X_val shape: {}".format(self.X_val.shape))
+            print("Y_val shape: {}".format(self.Y_val.shape))
+
+    def define_model(self):
+
+        inputs = tf.keras.Input(shape=(n_inputs, 1), name='input')
+
+        # 64 filters, kernel size 10
+        x = Conv1D(64, 10, activation='relu')(inputs)
+        x = MaxPool1D()(x)
+        x = BatchNormalization()(x)
+
+        x = Conv1D(128, 10, activation='relu')(x)
+        x = MaxPool1D()(x)
+        x = BatchNormalization()(x)
+
+        x = Conv1D(128, 10, activation='relu')(x)
+        x = MaxPool1D()(x)
+        x = BatchNormalization()(x)
+
+        x = Conv1D(256, 10, activation='relu')(x)
+        x = MaxPool1D()(x)
+        x = BatchNormalization()(x)
+
+        x = Flatten()(x)
+        x = Dense(1024, activation='relu', name='dense_1')(x)
+        x = BatchNormalization()(x)
+        x = Dropout(dropout)(x)
+
+        x = Dense(2048, activation='relu', name='dense_2')(x)
+        x = BatchNormalization()(x)
+        x = Dropout(dropout)(x)
+
+        outputs = Dense(n_classes, activation='softmax', name='predictions')(x)
+
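+        # softmax outputs pair with the categorical_crossentropy loss below,
+        # which expects one-hot encoded targets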
+        self.cnn_model = tf.keras.Model(inputs=inputs, outputs=outputs)
+        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
+        accuracy = CategoricalAccuracy()
+        self.cnn_model.compile(optimizer=optimizer, loss='categorical_crossentropy',
+                               metrics=[accuracy])
+
+    def split_data(self, dataset, validation_size):
+        """
+        Shuffle, then split into training, testing and validation sets
+        """
+
+        # train_test_split's stratify option does not accept one-hot encodings,
+        # so we convert them to an array of integer class labels
+        label_data = np.argmax(self.label_data, axis=1)
+
+        # Splitting the dataset into train and test datasets
+        res = train_test_split(dataset, label_data,
+                               test_size=validation_size, shuffle=True,
+                               stratify=label_data)
+
+        self.X_train, self.X_test, self.Y_train, self.Y_test = res
+
+        # Split the training set again to carve out the validation set
+        res = train_test_split(self.X_train, self.Y_train,
+                               test_size=validation_size, stratify=self.Y_train)
+
+        self.X_train, self.X_val, self.Y_train, self.Y_val = res
+
+        # Convert the integer labels back into one-hot encodings for training
+        self.Y_train = tf.keras.utils.to_categorical(self.Y_train)
+        self.Y_test = tf.keras.utils.to_categorical(self.Y_test)
+        self.Y_val = tf.keras.utils.to_categorical(self.Y_val)
+
+    def get_data(self):
+        return (self.X_train, self.X_test, self.X_val,
+                self.Y_train, self.Y_test, self.Y_val)
+
+
+def main():
+    # Construct model
+    model = CNN()
+    X_train, X_test, X_val, Y_train, Y_test, Y_val = model.get_data()
+
+    # Record training start time
+    start_time = time.time()
+
+    print("-"*50)
+    if restore_model:
+        print("Restoring model: {}".format('saved_models/cnn/model.hdf5'))
+
+    # Train
+    model.cnn_model.fit(X_train, Y_train, batch_size=batch_size,
+                        epochs=epochs, validation_data=(X_val, Y_val),
+                        callbacks=model.callbacks)
+
+    print("-"*50)
+
+    # Total training time
+    print("Total training time: {0:.2f}s".format(time.time() - total_time))
+
+    # Test
+    model.cnn_model.evaluate(X_test, Y_test, batch_size=batch_size)
+    print("-"*50)
+    print("Testing results:")
+    y_pred = model.cnn_model.predict(X_test, batch_size=batch_size)
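+    # predict returns a row of per-class softmax probabilities for each sample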
+
+    # The following scikit-learn metrics expect arrays of labels, not one-hot encodings
+    y_pred = np.argmax(y_pred, axis=1)
+    y_true = np.argmax(Y_test, axis=1)
+
+    # Precision and recall could also be computed as Keras metrics passed to
+    # compile() and reported by fit/evaluate
+    print("Precision: {}".format(precision_score(y_true, y_pred, average='micro')))
+    print("Recall: {}".format(recall_score(y_true, y_pred, average='micro')))
+    print("Confusion matrix: \n{}".format(confusion_matrix(y_true, y_pred, labels=[0,1,2])))
+    disease_indexes = list(range(len(heart_diseases)))
+    print("Indexes {} correspond to labels {}".format(disease_indexes, [x for x in heart_diseases]))
+
+    print("-"*50)
+
+
+if __name__ == "__main__":
+    main()