# tensorflow_impl/cnn_tf2.py

import time
import argparse

import tensorflow as tf
import numpy as np

from tensorflow.keras.layers import Dense, Flatten, Conv1D, BatchNormalization, MaxPool1D, Dropout
from tensorflow.keras.metrics import CategoricalAccuracy

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, confusion_matrix

from utils import get_labels, get_datasets, check_processed_dir_existance


par = argparse.ArgumentParser(description="ECG Convolutional Neural Network "
                                          "implementation with TensorFlow 2.0")

par.add_argument("-lr", dest="learning_rate",
                 type=float, default=0.001,
                 help="Learning rate used by the model")

par.add_argument("-e", dest="epochs",
                 type=int, default=50,
                 help="The number of epochs the model will train for")

par.add_argument("-bs", dest="batch_size",
                 type=int, default=32,
                 help="The batch size of the model")

par.add_argument("--display-step", dest="display_step",
                 type=int, default=10,
                 help="The display step")

par.add_argument("--dropout", type=float, default=0.5,
                 help="Dropout probability")

par.add_argument("--restore", dest="restore_model",
                 action="store_true", default=False,
                 help="Restore the previously saved model")

par.add_argument("--freeze", dest="freeze",
                 action="store_true", default=False,
                 help="Freeze the model")

par.add_argument("--heart-diseases", nargs="+",
                 dest="heart_diseases",
                 default=["apnea-ecg", "svdb", "afdb"],
                 choices=["apnea-ecg", "mitdb", "nsrdb", "svdb", "afdb"],
                 help="Select the ECG diseases for the model")

par.add_argument("--verbose", dest="verbose",
                 action="store_true", default=False,
                 help="Display information about minibatches")

args = par.parse_args()
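
# Example invocation (values shown are the parser defaults above):
#   python cnn_tf2.py -lr 0.001 -e 50 -bs 32 --heart-diseases apnea-ecg svdb afdb --verbose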

# Parameters
learning_rate = args.learning_rate
epochs = args.epochs
batch_size = args.batch_size
display_step = args.display_step
dropout = args.dropout
restore_model = args.restore_model
freeze = args.freeze
heart_diseases = args.heart_diseases
verbose = args.verbose

# Network Parameters
n_inputs = 350
n_classes = len(heart_diseases)

check_processed_dir_existance()


class CNN:
    def __init__(self):
        self.datasets = get_datasets(heart_diseases, n_inputs)
        self.label_data = get_labels(self.datasets)
        self.callbacks = []

        # Initialize callbacks
        tensorboard_logs_path = "tensorboard_data/cnn/"
        tb_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_logs_path,
                                                     histogram_freq=1, write_graph=True,
                                                     embeddings_freq=1)

        # load_weights_on_restart reads the weights file at `filepath`, if it exists,
        # and loads those weights into the model before training starts
        # (experimental argument in early TF 2.x releases)
        cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath="saved_models/cnn/model.hdf5",
                                                         save_best_only=True,
                                                         save_weights_only=True,
                                                         load_weights_on_restart=restore_model)

        self.callbacks.extend([tb_callback, cp_callback])

        self.set_data()
        self.define_model()

    def set_data(self):
        dataset_len = []
        for dataset in self.datasets:
            dataset_len.append(len(dataset))

        # Hold out 10% of the data for testing, then 10% of the remainder for validation
        validation_size = 0.1

        print("Validation percentage: {}%".format(validation_size*100))
        print("Total samples: {}".format(sum(dataset_len)))
        print("Heart diseases: {}".format(', '.join(heart_diseases)))

        concat_dataset = np.concatenate(self.datasets)

        self.split_data(concat_dataset, validation_size)

        # Reshape the input to (samples, timesteps, channels) so we can feed it to the conv layers
        self.X_train = tf.reshape(self.X_train, shape=[-1, n_inputs, 1])
        self.X_test = tf.reshape(self.X_test, shape=[-1, n_inputs, 1])
        self.X_val = tf.reshape(self.X_val, shape=[-1, n_inputs, 1])

        if verbose:
            print("X_train shape: {}".format(self.X_train.shape))
            print("Y_train shape: {}".format(self.Y_train.shape))
            print("X_test shape: {}".format(self.X_test.shape))
            print("Y_test shape: {}".format(self.Y_test.shape))
            print("X_val shape: {}".format(self.X_val.shape))
            print("Y_val shape: {}".format(self.Y_val.shape))

    def define_model(self):

        inputs = tf.keras.Input(shape=(n_inputs, 1), name='input')
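
        # With the default 'valid' padding, each Conv1D(kernel_size=10) shortens the
        # sequence by 9 samples and each MaxPool1D (default pool_size=2) halves it,
        # so the 350-step input shrinks to 341 -> 170 -> 161 -> 80 -> 71 -> 35 -> 26 -> 13
        # across the four blocks below, leaving 13 * 256 = 3328 features after Flatten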

        # Four conv blocks: 64 filters of kernel size 10, then 128, 128 and 256
        x = Conv1D(64, 10, activation='relu')(inputs)
        x = MaxPool1D()(x)
        x = BatchNormalization()(x)

        x = Conv1D(128, 10, activation='relu')(x)
        x = MaxPool1D()(x)
        x = BatchNormalization()(x)

        x = Conv1D(128, 10, activation='relu')(x)
        x = MaxPool1D()(x)
        x = BatchNormalization()(x)

        x = Conv1D(256, 10, activation='relu')(x)
        x = MaxPool1D()(x)
        x = BatchNormalization()(x)

        x = Flatten()(x)
        x = Dense(1024, activation='relu', name='dense_1')(x)
        x = BatchNormalization()(x)
        x = Dropout(dropout)(x)

        x = Dense(2048, activation='relu', name='dense_2')(x)
        x = BatchNormalization()(x)
        x = Dropout(dropout)(x)

        outputs = Dense(n_classes, activation='softmax', name='predictions')(x)

        self.cnn_model = tf.keras.Model(inputs=inputs, outputs=outputs)
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        accuracy = CategoricalAccuracy()
        self.cnn_model.compile(optimizer=optimizer, loss='categorical_crossentropy',
                               metrics=[accuracy])
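        # categorical_crossentropy expects one-hot targets, which is why split_data
        # converts the label arrays back to one-hot encodings before training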

    def split_data(self, dataset, validation_size):
        """
        Shuffle, then split into training, testing and validation sets
        """

        # In order to use stratify in train_test_split we can't use one-hot encodings,
        # so we convert them to an array of labels
        label_data = np.argmax(self.label_data, axis=1)

        # Split the dataset into train and test sets
        res = train_test_split(dataset, label_data,
                               test_size=validation_size, shuffle=True,
                               stratify=label_data)

        self.X_train, self.X_test, self.Y_train, self.Y_test = res

        # Split the training set again to obtain the validation set
        res = train_test_split(self.X_train, self.Y_train,
                               test_size=validation_size, stratify=self.Y_train)

        self.X_train, self.X_val, self.Y_train, self.Y_val = res
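        # With validation_size = 0.1 this yields roughly 81% train, 9% validation
        # and 10% test of the full dataset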

        # Convert the arrays of labels back into one-hot encodings for training
        self.Y_train = tf.keras.utils.to_categorical(self.Y_train)
        self.Y_test = tf.keras.utils.to_categorical(self.Y_test)
        self.Y_val = tf.keras.utils.to_categorical(self.Y_val)

    def get_data(self):
        return (self.X_train, self.X_test, self.X_val,
                self.Y_train, self.Y_test, self.Y_val)


def main():
    # Construct the model
    model = CNN()
    X_train, X_test, X_val, Y_train, Y_test, Y_val = model.get_data()

    # Record the start time
    total_time = time.time()

    print("-"*50)
    if restore_model:
        print("Restoring model: {}".format('saved_models/cnn/model.hdf5'))

    # Train
    model.cnn_model.fit(X_train, Y_train, batch_size=batch_size,
                        epochs=epochs, validation_data=(X_val, Y_val),
                        callbacks=model.callbacks)

    print("-"*50)

    # Total training time
    print("Total training time: {0:.2f}s".format(time.time() - total_time))

    # Test
    model.cnn_model.evaluate(X_test, Y_test, batch_size=batch_size)
    print("-"*50)
    print("Testing results:")
    y_pred = model.cnn_model.predict(X_test, batch_size=batch_size)

    # The scikit-learn metrics below only accept arrays of labels, not one-hot encodings
    y_pred = np.argmax(y_pred, axis=1)
    y_true = np.argmax(Y_test, axis=1)

    # Precision and recall could also be computed as metrics in the fit or evaluate calls
    disease_indexes = list(range(n_classes))
    print("Precision: {}".format(precision_score(y_true, y_pred, average='micro')))
    print("Recall: {}".format(recall_score(y_true, y_pred, average='micro')))
    print("Confusion matrix: \n{}".format(confusion_matrix(y_true, y_pred,
                                                           labels=disease_indexes)))
    print("Indexes {} correspond to labels {}".format(disease_indexes, heart_diseases))

    print("-"*50)


if __name__ == "__main__":
    main()