|
tensorflow_impl/cnn_tf2.py
|
|
import time
import argparse

import tensorflow as tf
import numpy as np

from tensorflow.keras.layers import Dense, Flatten, Conv1D, BatchNormalization, MaxPool1D, Dropout
from tensorflow.keras.metrics import CategoricalAccuracy

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, confusion_matrix

from utils import get_labels, get_datasets, check_processed_dir_existance
|
|
|
par = argparse.ArgumentParser(description="ECG Convolutional "
                              "Neural Network implementation with Tensorflow 2.0")

par.add_argument("-lr", dest="learning_rate",
                 type=float, default=0.001,
                 help="Learning rate used by the model")

par.add_argument("-e", dest="epochs",
                 type=int, default=50,
                 help="The number of epochs the model will train for")

par.add_argument("-bs", dest="batch_size",
                 type=int, default=32,
                 help="The batch size of the model")

par.add_argument("--display-step", dest="display_step",
                 type=int, default=10,
                 help="The display step")

par.add_argument("--dropout", type=float, default=0.5,
                 help="Dropout probability")

par.add_argument("--restore", dest="restore_model",
                 action="store_true", default=False,
                 help="Restore the previously saved model")

par.add_argument("--freeze", dest="freeze",
                 action="store_true", default=False,
                 help="Freeze the model")

par.add_argument("--heart-diseases", nargs="+",
                 dest="heart_diseases",
                 default=["apnea-ecg", "svdb", "afdb"],
                 choices=["apnea-ecg", "mitdb", "nsrdb", "svdb", "afdb"],
                 help="Select the ECG diseases for the model")

par.add_argument("--verbose", dest="verbose",
                 action="store_true", default=False,
                 help="Display information about minibatches")

args = par.parse_args()
|
|
|
# Parameters
learning_rate = args.learning_rate
epochs = args.epochs
batch_size = args.batch_size
display_step = args.display_step
dropout = args.dropout
restore_model = args.restore_model
freeze = args.freeze
heart_diseases = args.heart_diseases
verbose = args.verbose

# Network Parameters
n_inputs = 350
n_classes = len(heart_diseases)

check_processed_dir_existance()
|
|
|
class CNN:
    def __init__(self):
        self.datasets = get_datasets(heart_diseases, n_inputs)
        self.label_data = get_labels(self.datasets)
        self.callbacks = []

        # Initialize callbacks
        tensorboard_logs_path = "tensorboard_data/cnn/"
        tb_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_logs_path,
                                                     histogram_freq=1, write_graph=True,
                                                     embeddings_freq=1)

        # load_weights_on_restart reads the weights file at filepath, if it exists,
        # and loads those weights into the model before training starts
        cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath="saved_models/cnn/model.hdf5",
                                                         save_best_only=True,
                                                         save_weights_only=True,
                                                         load_weights_on_restart=restore_model)

        self.callbacks.extend([tb_callback, cp_callback])

        self.set_data()
        self.define_model()
|
|
|
    def set_data(self):
        dataset_len = []
        for dataset in self.datasets:
            dataset_len.append(len(dataset))

        # Validation on 10% of the training data
        validation_size = 0.1

        print("Validation percentage: {}%".format(validation_size * 100))
        print("Total samples: {}".format(sum(dataset_len)))
        print("Heart diseases: {}".format(', '.join(heart_diseases)))

        concat_dataset = np.concatenate(self.datasets)

        self.split_data(concat_dataset, validation_size)

        # Reshape the input so that we can feed it to the conv layer
        self.X_train = tf.reshape(self.X_train, shape=[-1, n_inputs, 1])
        self.X_test = tf.reshape(self.X_test, shape=[-1, n_inputs, 1])
        self.X_val = tf.reshape(self.X_val, shape=[-1, n_inputs, 1])

        if verbose:
            print("X_train shape: {}".format(self.X_train.shape))
            print("Y_train shape: {}".format(self.Y_train.shape))
            print("X_test shape: {}".format(self.X_test.shape))
            print("Y_test shape: {}".format(self.Y_test.shape))
            print("X_val shape: {}".format(self.X_val.shape))
            print("Y_val shape: {}".format(self.Y_val.shape))
|
|
|
    def define_model(self):

        inputs = tf.keras.Input(shape=(n_inputs, 1), name='input')

        # 64 filters, kernel size of 10
        x = Conv1D(64, 10, activation='relu')(inputs)
        x = MaxPool1D()(x)
        x = BatchNormalization()(x)

        x = Conv1D(128, 10, activation='relu')(x)
        x = MaxPool1D()(x)
        x = BatchNormalization()(x)

        x = Conv1D(128, 10, activation='relu')(x)
        x = MaxPool1D()(x)
        x = BatchNormalization()(x)

        x = Conv1D(256, 10, activation='relu')(x)
        x = MaxPool1D()(x)
        x = BatchNormalization()(x)
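        # Shape check (derived from the layers above, assuming the default 'valid'
        # padding and pool size of 2): the time dimension shrinks
        # 350 -> 341 -> 170 -> 161 -> 80 -> 71 -> 35 -> 26 -> 13,
        # so the Flatten below yields 13 * 256 = 3328 features.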
|
|
        x = Flatten()(x)
        x = Dense(1024, activation='relu', name='dense_1')(x)
        x = BatchNormalization()(x)
        x = Dropout(dropout)(x)

        x = Dense(2048, activation='relu', name='dense_2')(x)
        x = BatchNormalization()(x)
        x = Dropout(dropout)(x)

        outputs = Dense(n_classes, activation='softmax', name='predictions')(x)

        self.cnn_model = tf.keras.Model(inputs=inputs, outputs=outputs)
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        accuracy = CategoricalAccuracy()
        self.cnn_model.compile(optimizer=optimizer, loss='categorical_crossentropy',
                               metrics=[accuracy])
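        # Uncomment to inspect the resulting architecture layer by layer:
        # self.cnn_model.summary()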
|
|
    def split_data(self, dataset, validation_size):
        """
        Shuffle, then split into training, testing and validation sets
        """

        # In order to use stratify in train_test_split we can't use one-hot encodings,
        # so we convert them to an array of labels
        label_data = np.argmax(self.label_data, axis=1)

        # Split the dataset into train and test sets
        res = train_test_split(dataset, label_data,
                               test_size=validation_size, shuffle=True,
                               stratify=label_data)

        self.X_train, self.X_test, self.Y_train, self.Y_test = res

        # Further split the training set to obtain the validation set
        res = train_test_split(self.X_train, self.Y_train,
                               test_size=validation_size, stratify=self.Y_train)

        self.X_train, self.X_val, self.Y_train, self.Y_val = res
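        # Note on the effective split (derived from the two calls above): with
        # validation_size = 0.1 the data ends up roughly 81% train, 9% validation
        # (10% of the remaining 90%) and 10% test.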
|
|
        # Convert the arrays of labels back into one-hot encodings for training
        self.Y_train = tf.keras.utils.to_categorical(self.Y_train)
        self.Y_test = tf.keras.utils.to_categorical(self.Y_test)
        self.Y_val = tf.keras.utils.to_categorical(self.Y_val)

    def get_data(self):
        return (self.X_train, self.X_test, self.X_val,
                self.Y_train, self.Y_test, self.Y_val)
|
|
def main():
    # Construct the model
    model = CNN()
    X_train, X_test, X_val, Y_train, Y_test, Y_val = model.get_data()

    # Set the start time
    total_time = time.time()

    print("-" * 50)
    if restore_model:
        print("Restoring model: {}".format('saved_models/cnn/model.hdf5'))

    # Train
    model.cnn_model.fit(X_train, Y_train, batch_size=batch_size,
                        epochs=epochs, validation_data=(X_val, Y_val),
                        callbacks=model.callbacks)

    print("-" * 50)

    # Total training time
    print("Total training time: {0:.2f}s".format(time.time() - total_time))

    # Test
    model.cnn_model.evaluate(X_test, Y_test, batch_size=batch_size)
    print("-" * 50)
    print("Testing results:")
    y_pred = model.cnn_model.predict(X_test, batch_size=batch_size)

    # The following scikit-learn metrics only accept arrays of labels, not one-hot encodings
    y_pred = np.argmax(y_pred, axis=1)
    y_true = np.argmax(Y_test, axis=1)

    # Precision and recall could also be computed as metrics in the fit or evaluate calls
    print("Precision: {}".format(precision_score(y_true, y_pred, average='micro')))
    print("Recall: {}".format(recall_score(y_true, y_pred, average='micro')))

    disease_indexes = list(range(len(heart_diseases)))
    print("Confusion matrix: \n{}".format(confusion_matrix(y_true, y_pred, labels=disease_indexes)))
    print("Indexes {} correspond to labels {}".format(disease_indexes, heart_diseases))

    print("-" * 50)


if __name__ == "__main__":
    main()
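# Example invocation (hypothetical flag values, using the arguments defined above):
#   python cnn_tf2.py -lr 0.001 -e 50 -bs 32 --heart-diseases apnea-ecg svdb afdb --verbose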