Diff of /README.md [000000] .. [c49678]

Switch to unified view

a b/README.md
1
# ECG Time-Series Classification
2
The TensorFlow code in this project classifies a single heartbeat from an ECG recording. Three classification models were tested: a 1-D convolutional neural network (CNN); a recurrent neural network (RNN); and a Bayesian neural network (BNN) based on the CNN architecture. The CNN model is implemented in both Swift and Python; the RNN and BNN models are Python-only.
3
4
## Data
5
This analysis used segmented time-series data obtained from https://www.kaggle.com/shayanfazeli/heartbeat
6
* Time series are zero-padded, 187-element vectors containing the ECG lead II signal for one heartbeat.
7
* Labels [0, ..., 4] represent normal heartbeats and 4 classes of arrhythmia ['N', 'S', 'V', 'F', 'Q'].
8
* The class distribution is highly skewed. N = [90589, 2779, 7236, 803, 8039].
9
* In `PreprocessECG.ipynb` I take 100 examples from each class for the test set, and use the remainder for the training set. Under-represented classes are upsampled to balance the class ratios for training.
10
11
Thank you to Shayan Fazeli for providing this data set.
12
13
## Models
14
#### Convolutional Model
15
* The convolutional model was taken from [Kachuee, Fazeli, & Sarrafzadeh \(2018\)](https://arxiv.org/pdf/1805.00794.pdf)
16
17
Model consists of:
18
* An initial 1-D convolutional layer
19
* 5 repeated residual blocks (containing two 1-D convolutional layers with a passthrough connection and `same` padding; and a max pool layer)
20
* A fully-connected layer
21
* A linear layer with softmax output
22
* No regularization was used except for early stopping
23
24
#### Recurrent Model
25
26
Model consists of:
27
* Two stacked bidirectional GRU layers (input is masked to the variable dimension of the heartbeat vector)
28
* Two fully-connected layers connected to the last output-pair of the downstream (bidirectional) GRU layer
29
* A linear layer with softmax output
30
* Dropout regularization was used for the GRU layers
31
32
Since the model operates on segmented heartbeat samples, we can use a bidirectional RNN because the whole segment is available for processing at one time. It is also a more \"fair\" comparison with the CNN.
33
34
#### Bayesian Model
35
36
This model used the same network architecture as the convolutional (CNN) model above. However, the weights were stochastic, and posterior distributions of weights were trained using the Flipout method [\(Wen, Vicol, Ba, Tran, \& Grosse, 2018\)](https://arxiv.org/abs/1803.04386).
37
38
### Training
39
The CNN and RNN models were trained for 8000 parameter updates with a mini-batch size of 200 using the Adam optimizer with exponential learning rate decay. See notebook for parameter values. The RNN model took about 10x longer to train (wall time) than the CNN model.
40
41
The BNN model was trained for 3.5M parameter updates with a mini-batch size of 125 using the Adam optimizer with a fixed learning rate. The KL-divergence loss was annealed according to the [TensorFlow Probability example scheme](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/cifar10_bnn.py). See notebook for parameter values.
42
43
## Results
44
All models exhibited sufficient capacity to learn the training distribution with high accuracy. The error rates for all models were highest for the classes with the fewest examples. Collecting more data for the S- and F-type arrhythmias would likely increase the overall accuracy of the trained models.
45
46
In contrast with [Kachuee, Fazeli, & Sarrafzadeh \(2018\)](https://arxiv.org/pdf/1805.00794.pdf), I chose to upsample the under-represented classes rather than augment data as we do not have a physiologically valid generative model for heartbeats. Kachuee _et al._ also used augmented data as part of their test set without justification and I did not. As a consequence, my test set is much smaller. That said, my results for the convolutional model appear to be consistent with theirs.
47
48
#### Convolutional Model
49
```
50
       class  precision    recall  f1-score   support
51
52
           0       0.88      0.98      0.92       100
53
           1       0.98      0.91      0.94       100
54
           2       0.91      0.97      0.94       100
55
           3       0.98      0.87      0.92       100
56
           4       1.00      0.99      0.99       100
57
58
   micro avg       0.94      0.94      0.94       500
59
   macro avg       0.95      0.94      0.94       500
60
weighted avg       0.95      0.94      0.94       500
61
 ```
62
Confusion Matrix
63
64
 ![alt text](https://github.com/dave-fernandes/ECGClassifier/blob/master/images/CM-CNN.png "Confusion matrix for CNN classifier.")
65
66
#### Recurrent Model
67
```
68
       class  precision    recall  f1-score   support
69
70
           0       0.84      0.97      0.90       100
71
           1       0.98      0.89      0.93       100
72
           2       0.91      0.92      0.92       100
73
           3       0.98      0.89      0.93       100
74
           4       0.97      0.99      0.98       100
75
76
   micro avg       0.93      0.93      0.93       500
77
   macro avg       0.94      0.93      0.93       500
78
weighted avg       0.94      0.93      0.93       500
79
 ```
80
Confusion Matrix
81
82
 ![alt text](https://github.com/dave-fernandes/ECGClassifier/blob/master/images/CM-RNN.png "Confusion matrix for RNN classifier.")
83
84
#### Bayesian Model
85
For the Bayesian model, I obtained a Monte Carlo estimate for the most probable class. This class was then evaluated as above based on precision, recall, and the confusion matrix.
86
87
```
88
       class  precision    recall  f1-score   support
89
90
           0       0.88      0.98      0.92       100
91
           1       0.97      0.91      0.94       100
92
           2       0.92      0.98      0.95       100
93
           3       0.99      0.88      0.93       100
94
           4       1.00      0.99      0.99       100
95
96
   micro avg       0.95      0.95      0.95       500
97
   macro avg       0.95      0.95      0.95       500
98
weighted avg       0.95      0.95      0.95       500
99
 ```
100
Confusion Matrix
101
102
 ![alt text](https://github.com/dave-fernandes/ECGClassifier/blob/master/images/CM-BNN.png "Confusion matrix for BNN classifier.")
103
104
## Discussion
105
#### CNN versus RNN
106
The CNN model has 53,957 parameters and the RNN model has 240,293. Moreover, the serial nature of the RNN causes it to be less parallelizable than the CNN. Given that the CNN is slightly more accurate than the RNN, it provides an all-around better solution.
107
108
#### Maximum Likelihood versus Bayesian Estimate
109
The Bayesian model has slightly better performance than the standard CNN model with its maximum likelihood estimate (MLE). However, due to the small test-set size, this difference in performance may not be statistically significant. Still, the KL-divergence term in the loss for the Bayesian model should have a regularizing effect and allow the BNN model to generalize better.
110
111
#### Probability Estimation
112
The Bayesian neural network lore states that Bayesian networks produce better probability estimates than their standard \(maximum likelihood\) NN counterparts. We can check this by comparing the accuracy of the softmax \"probability\" estimate in the standard CNN model with the accuracy of the Monte Carlo probability estimate from the Bayesian network.
113
114
Fraction of correct CNN classifications versus softmax \"probability\" estimate, binned by estimate value. Error bars are 68% binomial confidence limits. The one-to-one line is the expected value if the estimate were perfect.
115
116
 ![alt text](https://github.com/dave-fernandes/ECGClassifier/blob/master/images/CNN-probability.png "Probability estimates for CNN classifier.")
117
118
Fraction of correct Bayesian network classifications versus Monte Carlo probability estimate, binned by estimate value. Error bars are 68% binomial confidence limits. The one-to-one line is the expected value if the estimate were perfect.
119
120
 ![alt text](https://github.com/dave-fernandes/ECGClassifier/blob/master/images/BNN-probability.png "Probability estimates for Bayesian classifier.")
121
 
122
 It is clear from the plots that the standard \(maximum likelihood\) network is estimating probability at least as well as the Bayesian network.
123
124
## Files
125
* `PreprocessECG.ipynb` is a Jupyter notebook used to format and balance the data.
126
* `ClassifyECG.ipynb` is a Jupyter notebook containing the CNN and RNN classification models, as well as training and evaluation code.
127
* `BayesClassifierECG.ipynb` is a Jupyter notebook containing the Bayesian classification model, as well as training and evaluation code.
128
* `ECG.xcodeproj` is an Xcode 11 project file that builds the Swift source from the `ECG` subdirectory to train the CNN model.
129
130
## Implementation Notes
131
* Python implementation tested with Python 3.6.7, TensorFlow 1.13.1, and TensorFlow Probability 0.6.0
132
* Swift implementation tested in Xcode 11 with Swift for Tensorflow toolchain 0.4.0