{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "BayesClassifierECG-v1.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": [
        "x2Q8Cy8HZ8dB"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "metadata": {
        "id": "DfByttJIWbgQ",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Bayesian Classification for ECG Time-Series\n",
        "\n",
        "> Copyright 2019 Dave Fernandes. All Rights Reserved.\n",
        "> \n",
        "> Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "> you may not use this file except in compliance with the License.\n",
        "> You may obtain a copy of the License at\n",
        ">\n",
        "> http://www.apache.org/licenses/LICENSE-2.0\n",
        ">  \n",
        "> Unless required by applicable law or agreed to in writing, software\n",
        "> distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "> See the License for the specific language governing permissions and\n",
        "> limitations under the License."
      ]
    },
    {
      "metadata": {
        "id": "MTd4aHhYWhdN",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Overview\n",
        "This notebook classifies time series of segmented heartbeats from ECG lead II recordings. Either of two models (CNN or RNN) can be selected for training and evaluation.\n",
        "- Data for this analysis should be prepared using the `PreprocessECG.ipynb` notebook from this project.\n",
        "- Original data is from: https://www.kaggle.com/shayanfazeli/heartbeat"
      ]
    },
    {
      "metadata": {
        "id": "hjmdX-HbWdeI",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow.keras.layers as layers\n",
        "import tensorflow_probability as tfp\n",
        "from tensorflow_probability import distributions as tfd\n",
        "import matplotlib.pyplot as plt\n",
        "import pickle\n",
        "\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "TRAIN_SET = '/content/drive/My Drive/Colab Notebooks/Data/train_set.pickle'\n",
        "TEST_SET = '/content/drive/My Drive/Colab Notebooks/Data/test_set.pickle'\n",
        "\n",
        "BATCH_SIZE = 125\n",
        "\n",
        "# Load the preprocessed splits created by PreprocessECG.ipynb\n",
        "with open(TEST_SET, 'rb') as file:\n",
        "    test_set = pickle.load(file)\n",
        "    x_test = test_set['x'].astype('float32')\n",
        "    y_test = test_set['y'].astype('int32')\n",
        "\n",
        "with open(TRAIN_SET, 'rb') as file:\n",
        "    train_set = pickle.load(file)\n",
        "    x_train = train_set['x'].astype('float32')\n",
        "    y_train = train_set['y'].astype('int32')\n",
        "\n",
        "TRAIN_COUNT = len(y_train)\n",
        "TEST_COUNT = len(y_test)\n",
        "BATCHES_PER_EPOCH = TRAIN_COUNT // BATCH_SIZE\n",
        "TEST_BATCHES_PER_EPOCH = TEST_COUNT // BATCH_SIZE\n",
        "INPUT_SIZE = np.shape(x_train)[1]\n",
        "\n",
        "print('Train count:', TRAIN_COUNT, 'Test count:', TEST_COUNT)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "x2Q8Cy8HZ8dB",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Input Datasets"
      ]
    },
    {
      "metadata": {
        "id": "uiTIHzr5Wsmn",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def combined_dataset(features, labels):\n",
        "    assert features.shape[0] == labels.shape[0]\n",
        "    dataset = tf.data.Dataset.from_tensor_slices((np.expand_dims(features, axis=-1), labels))\n",
        "    return dataset\n",
        "\n",
        "# For training\n",
        "def train_input_fn():\n",
        "    dataset = combined_dataset(x_train, y_train)\n",
        "    return dataset.shuffle(TRAIN_COUNT, reshuffle_each_iteration=True).repeat().batch(BATCH_SIZE, drop_remainder=True).prefetch(1)\n",
        "\n",
        "# For evaluation and metrics\n",
        "def eval_input_fn():\n",
        "    dataset = combined_dataset(x_test, y_test)\n",
        "    return dataset.repeat().batch(BATCH_SIZE).prefetch(1)\n",
        "\n",
        "training_batches = train_input_fn()\n",
        "training_iterator = tf.compat.v1.data.make_one_shot_iterator(training_batches)\n",
        "heldout_iterator = tf.compat.v1.data.make_one_shot_iterator(eval_input_fn())\n",
        "\n",
        "# Combine these into a feedable iterator that can switch between training\n",
        "# and validation inputs.\n",
        "handle = tf.compat.v1.placeholder(tf.string, shape=[])\n",
        "feedable_iterator = tf.compat.v1.data.Iterator.from_string_handle(handle, training_batches.output_types, training_batches.output_shapes)\n",
        "series, labels = feedable_iterator.get_next()"
      ],
      "execution_count": 0,
      "outputs": []
    },
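    {
      "metadata": {
        "id": "added_datasets_check_md",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The next cell is an optional sanity check (added, not part of the original notebook): it pulls a single training batch through the feedable iterator to confirm the tensor shapes before the model is built. The throwaway session and variable names are illustrative; since the dataset repeats, consuming one batch here does not affect training."
      ]
    },
    {
      "metadata": {
        "id": "added_datasets_check_code",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Optional sanity check: fetch one batch and print its shape.\n",
        "# One-shot iterator state is per-session, so this is self-contained.\n",
        "with tf.compat.v1.Session() as check_sess:\n",
        "    check_handle = check_sess.run(training_iterator.string_handle())\n",
        "    x_batch, y_batch = check_sess.run([series, labels], feed_dict={handle: check_handle})\n",
        "    print('series:', x_batch.shape, 'labels:', y_batch.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },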
    {
      "metadata": {
        "id": "PnGsC48GaGTk",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Define the model\n",
        "#### Bayesian CNN Model\n",
        "* The convolutional architecture is taken from: https://arxiv.org/pdf/1805.00794.pdf\n",
        "\n",
        "The model consists of:\n",
        "* An initial 1-D convolutional layer\n",
        "* 5 repeated residual blocks (`same` padding)\n",
        "* A fully-connected layer\n",
        "* A linear layer with softmax output\n",
        "* Flipout layers are used in place of standard convolutional and dense layers"
      ]
    },
    {
      "metadata": {
        "id": "Q8XdxTVYaO1q",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "KL_ANNEALING = 30\n",
        "\n",
        "MODEL_PATH = '/content/drive/My Drive/Colab Notebooks/Models/BayesianCNN/BNN.tfmodel'\n",
        "\n",
        "def kernel_prior(dtype, shape, name, trainable, add_variable_fn):\n",
        "    return tfp.layers.default_multivariate_normal_fn(dtype, shape, name, trainable, add_variable_fn)\n",
        "#     return tfd.Horseshoe(scale=5.0)\n",
        "\n",
        "#     mix = 0.75\n",
        "#     mixture = tfd.Mixture(name=name,\n",
        "#         cat=tfd.Categorical(probs=[mix, 1. - mix]),\n",
        "#         components=[tfd.Normal(loc=0., scale=1.), tfd.Normal(loc=0., scale=7.)])\n",
        "#     return mixture\n",
        "\n",
        "def conv_unit(unit, input_layer):\n",
        "    s = '_' + str(unit)\n",
        "    layer = tfp.layers.Convolution1DFlipout(name='Conv1' + s, filters=32, kernel_size=5, strides=1, padding='same', activation='relu', kernel_prior_fn=kernel_prior)(input_layer)\n",
        "    layer = tfp.layers.Convolution1DFlipout(name='Conv2' + s, filters=32, kernel_size=5, strides=1, padding='same', activation=None, kernel_prior_fn=kernel_prior)(layer)\n",
        "    layer = layers.Add(name='ResidualSum' + s)([layer, input_layer])\n",
        "    layer = layers.Activation(\"relu\", name='Act' + s)(layer)\n",
        "    layer = layers.MaxPooling1D(name='MaxPool' + s, pool_size=5, strides=2)(layer)\n",
        "    return layer\n",
        "\n",
        "def model_fn(input_shape):\n",
        "    time_series = tf.keras.layers.Input(shape=input_shape, dtype='float32')\n",
        "\n",
        "    current_layer = tfp.layers.Convolution1DFlipout(filters=32, kernel_size=5, strides=1, kernel_prior_fn=kernel_prior)(time_series)\n",
        "\n",
        "    for i in range(5):\n",
        "        current_layer = conv_unit(i + 1, current_layer)\n",
        "\n",
        "    current_layer = layers.Flatten()(current_layer)\n",
        "    current_layer = tfp.layers.DenseFlipout(32, name='FC1', activation='relu', kernel_prior_fn=kernel_prior)(current_layer)\n",
        "    logits = tfp.layers.DenseFlipout(5, name='Output', kernel_prior_fn=kernel_prior)(current_layer)\n",
        "\n",
        "    model = tf.keras.Model(inputs=time_series, outputs=logits, name='bayes_cnn_model')\n",
        "    return model\n",
        "\n",
        "# Step counter used to anneal the KL term\n",
        "t = tf.compat.v1.Variable(0.0)\n",
        "\n",
        "# Compute the negative Evidence Lower Bound (ELBO) loss\n",
        "def loss_fn(labels, logits):\n",
        "    labels_distribution = tfd.Categorical(logits=logits)\n",
        "\n",
        "    # Perform KL annealing. The optimal number of annealing steps\n",
        "    # depends on the dataset and architecture.\n",
        "    kl_regularizer = t / (KL_ANNEALING * BATCHES_PER_EPOCH)\n",
        "\n",
        "    # Compute the -ELBO as the loss. The KL term is annealed from 0 to 1\n",
        "    # over the number of epochs given by KL_ANNEALING.\n",
        "    log_likelihood = labels_distribution.log_prob(labels)\n",
        "    neg_log_likelihood = -tf.reduce_mean(input_tensor=log_likelihood)\n",
        "    kl = sum(model.losses) / len(x_train) * tf.minimum(1.0, kl_regularizer)\n",
        "    return neg_log_likelihood + kl, kl, kl_regularizer, labels_distribution\n",
        "\n",
        "model = model_fn([INPUT_SIZE, 1])\n",
        "model.summary()"
      ],
      "execution_count": 0,
      "outputs": []
    },
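    {
      "metadata": {
        "id": "added_elbo_note_md",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "For reference (an added note, not part of the original notebook): `loss_fn` above minimizes the negative evidence lower bound with an annealed KL weight,\n",
        "\n",
        "$$\\mathcal{L} = -\\frac{1}{B}\\sum_{i=1}^{B}\\log p(y_i \\mid x_i, w) \\; + \\; \\beta_t\\,\\frac{\\mathrm{KL}\\left[q(w)\\,\\|\\,p(w)\\right]}{N}$$\n",
        "\n",
        "where $B$ is the batch size, $N$ is the number of training examples, and $\\beta_t = \\min(1,\\; t / (\\mathrm{KL\\_ANNEALING} \\cdot \\mathrm{BATCHES\\_PER\\_EPOCH}))$ ramps from 0 to 1 as the step counter $t$ advances. The KL divergence itself is accumulated in `model.losses` by the Flipout layers."
      ]
    },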
    {
      "metadata": {
        "id": "1h96cjFraUo_",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Train model"
      ]
    },
    {
      "metadata": {
        "id": "9uHtgxoLynkU",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "INITIAL_LEARNING_RATE = 0.0001\n",
        "EPOCHS = 100\n",
        "\n",
        "assert EPOCHS > 0\n",
        "\n",
        "logits = model(series)\n",
        "loss, kl, kl_reg, labels_distribution = loss_fn(labels, logits)\n",
        "\n",
        "# Build metrics for evaluation. Predictions are formed from a single forward\n",
        "# pass of the probabilistic layers. They are cheap but noisy predictions.\n",
        "predictions = tf.argmax(input=logits, axis=1)\n",
        "with tf.compat.v1.name_scope(\"train\"):\n",
        "    train_accuracy, train_accuracy_update_op = tf.compat.v1.metrics.accuracy(labels=labels, predictions=predictions)\n",
        "    opt = tf.compat.v1.train.AdamOptimizer(INITIAL_LEARNING_RATE)\n",
        "    train_op = opt.minimize(loss)\n",
        "    update_step_op = tf.compat.v1.assign(t, t + 1)\n",
        "\n",
        "with tf.compat.v1.name_scope(\"valid\"):\n",
        "    valid_accuracy, valid_accuracy_update_op = tf.compat.v1.metrics.accuracy(labels=labels, predictions=predictions)\n",
        "\n",
        "init_op = tf.group(tf.compat.v1.global_variables_initializer(), tf.compat.v1.local_variables_initializer())\n",
        "\n",
        "stream_vars_valid = [\n",
        "    v for v in tf.compat.v1.local_variables() if \"valid/\" in v.name\n",
        "]\n",
        "reset_valid_op = tf.compat.v1.variables_initializer(stream_vars_valid)\n",
        "\n",
        "with tf.compat.v1.Session() as sess:\n",
        "    sess.run(init_op)\n",
        "\n",
        "    # Run the training loop\n",
        "    train_handle = sess.run(training_iterator.string_handle())\n",
        "    heldout_handle = sess.run(heldout_iterator.string_handle())\n",
        "    training_steps = EPOCHS * BATCHES_PER_EPOCH\n",
        "\n",
        "    for step in range(training_steps):\n",
        "        _ = sess.run([train_op, train_accuracy_update_op, update_step_op], feed_dict={handle: train_handle})\n",
        "\n",
        "        # Print training progress five times per epoch\n",
        "        if step % (BATCHES_PER_EPOCH // 5) == 0:\n",
        "            loss_value, accuracy_value, kl_value, kl_reg_value = sess.run([loss, train_accuracy, kl, kl_reg], feed_dict={handle: train_handle})\n",
        "            print(\"   Loss: {:.3f} Accuracy: {:.3f} KL: {:.3f} KL-reg: {:.3f}\".format(loss_value, accuracy_value, kl_value, kl_reg_value))\n",
        "\n",
        "        if (step + 1) % BATCHES_PER_EPOCH == 0:\n",
        "            # Calculate validation accuracy at the end of each epoch\n",
        "            for _ in range(TEST_BATCHES_PER_EPOCH):\n",
        "                sess.run(valid_accuracy_update_op, feed_dict={handle: heldout_handle})\n",
        "\n",
        "            valid_value = sess.run(valid_accuracy, feed_dict={handle: heldout_handle})\n",
        "            print(\"Epoch: {:>3d} Validation Accuracy: {:.3f}\".format((step + 1) // BATCHES_PER_EPOCH, valid_value))\n",
        "\n",
        "            sess.run(reset_valid_op)\n",
        "\n",
        "    model.save_weights(MODEL_PATH)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Ml9cWK8j0vEU",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Evaluate model"
      ]
    },
    {
      "metadata": {
        "id": "p40qKgiIIVNd",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "NUM_MONTE_CARLO = 1000\n",
        "\n",
        "model.load_weights(MODEL_PATH)\n",
        "\n",
        "mc_counts = np.zeros((TEST_COUNT, 5))\n",
        "x = np.expand_dims(x_test, -1)\n",
        "sample_index = np.arange(TEST_COUNT)\n",
        "\n",
        "# Each forward pass draws fresh weights in the Flipout layers, so the\n",
        "# argmax votes below sample from the posterior predictive distribution.\n",
        "for i in range(NUM_MONTE_CARLO):\n",
        "    y_pred = np.argmax(model.predict(x), axis=1)\n",
        "    mc_counts[sample_index, y_pred] += 1\n",
        "\n",
        "y_pred = np.argmax(mc_counts, axis=1)\n",
        "y_prob = mc_counts[sample_index, y_pred] / NUM_MONTE_CARLO\n",
        "\n",
        "y_prob_correct = y_prob[y_pred == y_test]\n",
        "y_prob_mis = y_prob[y_pred != y_test]"
      ],
      "execution_count": 0,
      "outputs": []
    },
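    {
      "metadata": {
        "id": "added_mc_softmax_md",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "An alternative aggregation (added sketch, not part of the original analysis): instead of counting argmax votes, average the predictive softmax probabilities over the Monte Carlo samples. The sample count and variable names below are illustrative, and a smaller number of passes is used for speed."
      ]
    },
    {
      "metadata": {
        "id": "added_mc_softmax_code",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Average softmax probabilities over stochastic forward passes.\n",
        "# The model outputs logits, so the softmax is applied here explicitly.\n",
        "N_SAMPLES = 50  # illustrative; fewer than NUM_MONTE_CARLO for speed\n",
        "\n",
        "mean_probs = np.zeros((TEST_COUNT, 5))\n",
        "for _ in range(N_SAMPLES):\n",
        "    logits_mc = model.predict(x)\n",
        "    e = np.exp(logits_mc - logits_mc.max(axis=1, keepdims=True))\n",
        "    mean_probs += e / e.sum(axis=1, keepdims=True)\n",
        "mean_probs /= N_SAMPLES\n",
        "\n",
        "y_pred_mean = np.argmax(mean_probs, axis=1)\n",
        "print('Agreement with vote-based predictions:', np.mean(y_pred_mean == y_pred))"
      ],
      "execution_count": 0,
      "outputs": []
    },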
    {
      "metadata": {
        "id": "dWh_nMZN9J0N",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Check probability estimates"
      ]
    },
    {
      "metadata": {
        "id": "K0lRvUS9gQIk",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "from astropy.stats import binom_conf_interval\n",
        "\n",
        "_, _, _ = plt.hist(y_prob, 10, (0, 1))\n",
        "plt.xlabel('Belief')\n",
        "plt.ylabel('Count')\n",
        "plt.title('All Predictions')\n",
        "plt.show()\n",
        "\n",
        "n_all, bins = np.histogram(y_prob, 10, (0, 1))\n",
        "n_correct, _ = np.histogram(y_prob_correct, 10, (0, 1))\n",
        "f_bins = 0.5 * (bins[:-1] + bins[1:])\n",
        "\n",
        "# Drop empty bins, then compute the observed accuracy in each bin\n",
        "n_correct = n_correct[n_all > 0]\n",
        "n_total = n_all[n_all > 0]\n",
        "f_correct = n_correct / n_total\n",
        "f_bins = f_bins[n_all > 0]\n",
        "\n",
        "# Binomial confidence intervals on the per-bin accuracy\n",
        "lower_bound, upper_bound = binom_conf_interval(n_correct, n_total)\n",
        "error_bars = np.array([f_correct - lower_bound, upper_bound - f_correct])\n",
        "\n",
        "plt.plot([0., 1.], [0., 1.])\n",
        "plt.errorbar(f_bins, f_correct, yerr=error_bars, fmt='o')\n",
        "plt.xlabel('Monte Carlo Probability')\n",
        "plt.ylabel('Frequency')\n",
        "plt.title('Correct Predictions')\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ggvEGluM9TEw",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Compute metrics"
      ]
    },
    {
      "metadata": {
        "id": "TeV6NtwnP0Lm",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "import sklearn.metrics as skm\n",
        "import seaborn\n",
        "\n",
        "# Classification report\n",
        "report = skm.classification_report(y_test, y_pred)\n",
        "print(report)\n",
        "\n",
        "# Confusion matrix\n",
        "cm = skm.confusion_matrix(y_test, y_pred)\n",
        "seaborn.heatmap(cm, annot=True, annot_kws={\"size\": 16})"
      ],
      "execution_count": 0,
      "outputs": []
    },
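    {
      "metadata": {
        "id": "added_cm_norm_md",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Optional follow-up (added sketch, not part of the original notebook): if the heartbeat classes are imbalanced, a row-normalized confusion matrix can be easier to read than raw counts. `cm` is reused from the cell above."
      ]
    },
    {
      "metadata": {
        "id": "added_cm_norm_code",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Normalize each row so cells show the fraction of true-class examples\n",
        "# assigned to each predicted class (the diagonal is per-class recall).\n",
        "cm_norm = cm / cm.sum(axis=1, keepdims=True)\n",
        "seaborn.heatmap(cm_norm, annot=True, fmt='.2f')"
      ],
      "execution_count": 0,
      "outputs": []
    },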
    {
      "metadata": {
        "id": "MnQeNtCkV3Uv",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}