|
a/README.md |
|
b/README.md |
1 |
# ECG Time-Series Classification |
1 |
# ECG Time-Series Classification
|
2 |
The TensorFlow code in this project classifies a single heartbeat from an ECG recording. Three classification models were tested: a 1-D convolutional neural network (CNN); a recurrent neural network (RNN); and a Bayesian neural network (BNN) based on the CNN architecture. The CNN model is implemented in both Swift and Python; the RNN and BNN models are Python-only. |
2 |
The TensorFlow code in this project classifies a single heartbeat from an ECG recording. Three classification models were tested: a 1-D convolutional neural network (CNN); a recurrent neural network (RNN); and a Bayesian neural network (BNN) based on the CNN architecture. The CNN model is implemented in both Swift and Python; the RNN and BNN models are Python-only. |
3 |
|
3 |
|
4 |
## Data |
4 |
## Data
|
5 |
This analysis used segmented time-series data obtained from https://www.kaggle.com/shayanfazeli/heartbeat |
5 |
This analysis used segmented time-series data obtained from https://www.kaggle.com/shayanfazeli/heartbeat
|
6 |
* Time series are zero-padded, 187-element vectors containing the ECG lead II signal for one heartbeat. |
6 |
* Time series are zero-padded, 187-element vectors containing the ECG lead II signal for one heartbeat.
|
7 |
* Labels [0, ..., 4] represent normal heartbeats and 4 classes of arrhythmia ['N', 'S', 'V', 'F', 'Q']. |
7 |
* Labels [0, ..., 4] represent normal heartbeats and 4 classes of arrhythmia ['N', 'S', 'V', 'F', 'Q'].
|
8 |
* The class distribution is highly skewed. N = [90589, 2779, 7236, 803, 8039]. |
8 |
* The class distribution is highly skewed. N = [90589, 2779, 7236, 803, 8039].
|
9 |
* In `PreprocessECG.ipynb` I take 100 examples from each class for the test set, and use the remainder for the training set. Under-represented classes are upsampled to balance the class ratios for training. |
9 |
* In `PreprocessECG.ipynb` I take 100 examples from each class for the test set, and use the remainder for the training set. Under-represented classes are upsampled to balance the class ratios for training. |
10 |
|
10 |
|
11 |
Thank you to Shayan Fazeli for providing this data set. |
11 |
Thank you to Shayan Fazeli for providing this data set. |
12 |
|
12 |
|
13 |
## Models |
13 |
## Models
|
14 |
#### Convolutional Model |
14 |
#### Convolutional Model
|
15 |
* The convolutional model was taken from [Kachuee, Fazeli, & Sarrafzadeh \(2018\)](https://arxiv.org/pdf/1805.00794.pdf) |
15 |
* The convolutional model was taken from [Kachuee, Fazeli, & Sarrafzadeh \(2018\)](https://arxiv.org/pdf/1805.00794.pdf) |
16 |
|
16 |
|
17 |
Model consists of: |
17 |
Model consists of:
|
18 |
* An initial 1-D convolutional layer |
18 |
* An initial 1-D convolutional layer
|
19 |
* 5 repeated residual blocks (containing two 1-D convolutional layers with a passthrough connection and `same` padding; and a max pool layer) |
19 |
* 5 repeated residual blocks (containing two 1-D convolutional layers with a passthrough connection and `same` padding; and a max pool layer)
|
20 |
* A fully-connected layer |
20 |
* A fully-connected layer
|
21 |
* A linear layer with softmax output |
21 |
* A linear layer with softmax output
|
22 |
* No regularization was used except for early stopping |
22 |
* No regularization was used except for early stopping |
23 |
|
23 |
|
24 |
#### Recurrent Model |
24 |
#### Recurrent Model |
25 |
|
25 |
|
26 |
Model consists of: |
26 |
Model consists of:
|
27 |
* Two stacked bidirectional GRU layers (input is masked to the variable dimension of the heartbeat vector) |
27 |
* Two stacked bidirectional GRU layers (input is masked to the variable dimension of the heartbeat vector)
|
28 |
* Two fully-connected layers connected to the last output-pair of the downstream (bidirectional) GRU layer |
28 |
* Two fully-connected layers connected to the last output-pair of the downstream (bidirectional) GRU layer
|
29 |
* A linear layer with softmax output |
29 |
* A linear layer with softmax output
|
30 |
* Dropout regularization was used for the GRU layers |
30 |
* Dropout regularization was used for the GRU layers |
31 |
|
31 |
|
32 |
Since the model operates on segmented heartbeat samples, we can use a bidirectional RNN because the whole segment is available for processing at one time. It is also a more \"fair\" comparison with the CNN. |
32 |
Since the model operates on segmented heartbeat samples, we can use a bidirectional RNN because the whole segment is available for processing at one time. It is also a more \"fair\" comparison with the CNN. |
33 |
|
33 |
|
34 |
#### Bayesian Model |
34 |
#### Bayesian Model |
35 |
|
35 |
|
36 |
This model used the same network architecture as the convolutional (CNN) model above. However, the weights were stochastic, and posterior distributions of weights were trained using the Flipout method [\(Wen, Vicol, Ba, Tran, \& Grosse, 2018\)](https://arxiv.org/abs/1803.04386). |
36 |
This model used the same network architecture as the convolutional (CNN) model above. However, the weights were stochastic, and posterior distributions of weights were trained using the Flipout method [\(Wen, Vicol, Ba, Tran, \& Grosse, 2018\)](https://arxiv.org/abs/1803.04386). |
37 |
|
37 |
|
38 |
### Training |
38 |
### Training
|
39 |
The CNN and RNN models were trained for 8000 parameter updates with a mini-batch size of 200 using the Adam optimizer with exponential learning rate decay. See notebook for parameter values. The RNN model took about 10x longer to train (wall time) than the CNN model. |
39 |
The CNN and RNN models were trained for 8000 parameter updates with a mini-batch size of 200 using the Adam optimizer with exponential learning rate decay. See notebook for parameter values. The RNN model took about 10x longer to train (wall time) than the CNN model. |
40 |
|
40 |
|
41 |
The BNN model was trained for 3.5M parameter updates with a mini-batch size of 125 using the Adam optimizer with a fixed learning rate. The KL-divergence loss was annealed according to the [TensorFlow Probability example scheme](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/cifar10_bnn.py). See notebook for parameter values. |
41 |
The BNN model was trained for 3.5M parameter updates with a mini-batch size of 125 using the Adam optimizer with a fixed learning rate. The KL-divergence loss was annealed according to the [TensorFlow Probability example scheme](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/cifar10_bnn.py). See notebook for parameter values. |
42 |
|
42 |
|
43 |
## Results |
43 |
## Results
|
44 |
All models exhibited sufficient capacity to learn the training distribution with high accuracy. The error rates for all models were highest for the classes with the fewest examples. Collecting more data for the S- and F-type arrhythmias would likely increase the overall accuracy of the trained models. |
44 |
All models exhibited sufficient capacity to learn the training distribution with high accuracy. The error rates for all models were highest for the classes with the fewest examples. Collecting more data for the S- and F-type arrhythmias would likely increase the overall accuracy of the trained models. |
45 |
|
45 |
|
46 |
In contrast with [Kachuee, Fazeli, & Sarrafzadeh \(2018\)](https://arxiv.org/pdf/1805.00794.pdf), I chose to upsample the under-represented classes rather than augment data as we do not have a physiologically valid generative model for heartbeats. Kachuee _et al._ also used augmented data as part of their test set without justification and I did not. As a consequence, my test set is much smaller. That said, my results for the convolutional model appear to be consistent with theirs. |
46 |
In contrast with [Kachuee, Fazeli, & Sarrafzadeh \(2018\)](https://arxiv.org/pdf/1805.00794.pdf), I chose to upsample the under-represented classes rather than augment data as we do not have a physiologically valid generative model for heartbeats. Kachuee _et al._ also used augmented data as part of their test set without justification and I did not. As a consequence, my test set is much smaller. That said, my results for the convolutional model appear to be consistent with theirs. |
47 |
|
47 |
|
48 |
#### Convolutional Model |
48 |
#### Convolutional Model
|
49 |
``` |
49 |
```
|
50 |
class precision recall f1-score support |
50 |
class precision recall f1-score support |
51 |
|
51 |
|
52 |
0 0.88 0.98 0.92 100 |
52 |
0 0.88 0.98 0.92 100
|
53 |
1 0.98 0.91 0.94 100 |
53 |
1 0.98 0.91 0.94 100
|
54 |
2 0.91 0.97 0.94 100 |
54 |
2 0.91 0.97 0.94 100
|
55 |
3 0.98 0.87 0.92 100 |
55 |
3 0.98 0.87 0.92 100
|
56 |
4 1.00 0.99 0.99 100 |
56 |
4 1.00 0.99 0.99 100 |
57 |
|
57 |
|
58 |
micro avg 0.94 0.94 0.94 500 |
58 |
micro avg 0.94 0.94 0.94 500
|
59 |
macro avg 0.95 0.94 0.94 500 |
59 |
macro avg 0.95 0.94 0.94 500
|
60 |
weighted avg 0.95 0.94 0.94 500 |
60 |
weighted avg 0.95 0.94 0.94 500
|
61 |
``` |
61 |
```
|
62 |
Confusion Matrix |
62 |
Confusion Matrix |
63 |
|
63 |
|
64 |
 |
64 |
 |
65 |
|
65 |
|
66 |
#### Recurrent Model |
66 |
#### Recurrent Model
|
67 |
``` |
67 |
```
|
68 |
class precision recall f1-score support |
68 |
class precision recall f1-score support |
69 |
|
69 |
|
70 |
0 0.84 0.97 0.90 100 |
70 |
0 0.84 0.97 0.90 100
|
71 |
1 0.98 0.89 0.93 100 |
71 |
1 0.98 0.89 0.93 100
|
72 |
2 0.91 0.92 0.92 100 |
72 |
2 0.91 0.92 0.92 100
|
73 |
3 0.98 0.89 0.93 100 |
73 |
3 0.98 0.89 0.93 100
|
74 |
4 0.97 0.99 0.98 100 |
74 |
4 0.97 0.99 0.98 100 |
75 |
|
75 |
|
76 |
micro avg 0.93 0.93 0.93 500 |
76 |
micro avg 0.93 0.93 0.93 500
|
77 |
macro avg 0.94 0.93 0.93 500 |
77 |
macro avg 0.94 0.93 0.93 500
|
78 |
weighted avg 0.94 0.93 0.93 500 |
78 |
weighted avg 0.94 0.93 0.93 500
|
79 |
``` |
79 |
```
|
80 |
Confusion Matrix |
80 |
Confusion Matrix |
81 |
|
81 |
|
82 |
 |
82 |
 |
83 |
|
83 |
|
84 |
#### Bayesian Model |
84 |
#### Bayesian Model
|
85 |
For the Bayesian model, I obtained a Monte Carlo estimate for the most probable class. This class was then evaluated as above based on precision, recall, and the confusion matrix. |
85 |
For the Bayesian model, I obtained a Monte Carlo estimate for the most probable class. This class was then evaluated as above based on precision, recall, and the confusion matrix. |
86 |
|
86 |
|
87 |
``` |
87 |
```
|
88 |
class precision recall f1-score support |
88 |
class precision recall f1-score support |
89 |
|
89 |
|
90 |
0 0.88 0.98 0.92 100 |
90 |
0 0.88 0.98 0.92 100
|
91 |
1 0.97 0.91 0.94 100 |
91 |
1 0.97 0.91 0.94 100
|
92 |
2 0.92 0.98 0.95 100 |
92 |
2 0.92 0.98 0.95 100
|
93 |
3 0.99 0.88 0.93 100 |
93 |
3 0.99 0.88 0.93 100
|
94 |
4 1.00 0.99 0.99 100 |
94 |
4 1.00 0.99 0.99 100 |
95 |
|
95 |
|
96 |
micro avg 0.95 0.95 0.95 500 |
96 |
micro avg 0.95 0.95 0.95 500
|
97 |
macro avg 0.95 0.95 0.95 500 |
97 |
macro avg 0.95 0.95 0.95 500
|
98 |
weighted avg 0.95 0.95 0.95 500 |
98 |
weighted avg 0.95 0.95 0.95 500
|
99 |
``` |
99 |
```
|
100 |
Confusion Matrix |
100 |
Confusion Matrix |
101 |
|
101 |
|
102 |
 |
102 |
 |
103 |
|
103 |
|
104 |
## Discussion |
104 |
## Discussion
|
105 |
#### CNN versus RNN |
105 |
#### CNN versus RNN
|
106 |
The CNN model has 53,957 parameters and the RNN model has 240,293. Moreover, the serial nature of the RNN causes it to be less parallelizable than the CNN. Given that the CNN is slightly more accurate than the RNN, it provides an all-around better solution. |
106 |
The CNN model has 53,957 parameters and the RNN model has 240,293. Moreover, the serial nature of the RNN causes it to be less parallelizable than the CNN. Given that the CNN is slightly more accurate than the RNN, it provides an all-around better solution. |
107 |
|
107 |
|
108 |
#### Maximum Likelihood versus Bayesian Estimate |
108 |
#### Maximum Likelihood versus Bayesian Estimate
|
109 |
The Bayesian model has slightly better performance than the standard CNN model with its maximum likelihood estimate (MLE). However, due to the small test-set size, this difference in performance may not be statistically significant. Still, the KL-divergence term in the loss for the Bayesian model should have a regularizing effect and allow the BNN model to generalize better. |
109 |
The Bayesian model has slightly better performance than the standard CNN model with its maximum likelihood estimate (MLE). However, due to the small test-set size, this difference in performance may not be statistically significant. Still, the KL-divergence term in the loss for the Bayesian model should have a regularizing effect and allow the BNN model to generalize better. |
110 |
|
110 |
|
111 |
#### Probability Estimation |
111 |
#### Probability Estimation
|
112 |
The Bayesian neural network lore states that Bayesian networks produce better probability estimates than their standard \(maximum likelihood\) NN counterparts. We can check this by comparing the accuracy of the softmax \"probability\" estimate in the standard CNN model with the accuracy of the Monte Carlo probability estimate from the Bayesian network. |
112 |
The Bayesian neural network lore states that Bayesian networks produce better probability estimates than their standard \(maximum likelihood\) NN counterparts. We can check this by comparing the accuracy of the softmax \"probability\" estimate in the standard CNN model with the accuracy of the Monte Carlo probability estimate from the Bayesian network. |
113 |
|
113 |
|
114 |
Fraction of correct CNN classifications versus softmax \"probability\" estimate, binned by estimate value. Error bars are 68% binomial confidence limits. The one-to-one line is the expected value if the estimate were perfect. |
114 |
Fraction of correct CNN classifications versus softmax \"probability\" estimate, binned by estimate value. Error bars are 68% binomial confidence limits. The one-to-one line is the expected value if the estimate were perfect. |
115 |
|
115 |
|
116 |
 |
116 |
 |
117 |
|
117 |
|
118 |
Fraction of correct Bayesian network classifications versus Monte Carlo probability estimate, binned by estimate value. Error bars are 68% binomial confidence limits. The one-to-one line is the expected value if the estimate were perfect. |
118 |
Fraction of correct Bayesian network classifications versus Monte Carlo probability estimate, binned by estimate value. Error bars are 68% binomial confidence limits. The one-to-one line is the expected value if the estimate were perfect. |
119 |
|
119 |
|
120 |
 |
120 |

|
121 |
|
121 |
|
122 |
It is clear from the plots that the standard \(maximum likelihood\) network is estimating probability at least as well as the Bayesian network. |
122 |
It is clear from the plots that the standard \(maximum likelihood\) network is estimating probability at least as well as the Bayesian network. |
123 |
|
123 |
|
124 |
## Files |
124 |
## Files
|
125 |
* `PreprocessECG.ipynb` is a Jupyter notebook used to format and balance the data. |
125 |
* `PreprocessECG.ipynb` is a Jupyter notebook used to format and balance the data.
|
126 |
* `ClassifyECG.ipynb` is a Jupyter notebook containing the CNN and RNN classification models, as well as training and evaluation code. |
126 |
* `ClassifyECG.ipynb` is a Jupyter notebook containing the CNN and RNN classification models, as well as training and evaluation code.
|
127 |
* `BayesClassifierECG.ipynb` is a Jupyter notebook containing the Bayesian classification model, as well as training and evaluation code. |
127 |
* `BayesClassifierECG.ipynb` is a Jupyter notebook containing the Bayesian classification model, as well as training and evaluation code.
|
128 |
* `ECG.xcodeproj` is an Xcode 11 project file that builds the Swift source from the `ECG` subdirectory to train the CNN model. |
128 |
* `ECG.xcodeproj` is an Xcode 11 project file that builds the Swift source from the `ECG` subdirectory to train the CNN model. |
129 |
|
129 |
|
130 |
## Implementation Notes |
130 |
## Implementation Notes
|
131 |
* Python implementation tested with Python 3.6.7, TensorFlow 1.13.1, and TensorFlow Probability 0.6.0 |
131 |
* Python implementation tested with Python 3.6.7, TensorFlow 1.13.1, and TensorFlow Probability 0.6.0
|
132 |
* Swift implementation tested in Xcode 11 with Swift for Tensorflow toolchain 0.4.0 |
132 |
* Swift implementation tested in Xcode 11 with Swift for Tensorflow toolchain 0.4.0
|