# [torch_ecg](https://github.com/DeepPSP/torch_ecg/)

[![pytest](https://github.com/DeepPSP/torch_ecg/actions/workflows/run-pytest.yml/badge.svg?branch=dev)](https://github.com/DeepPSP/torch_ecg/actions/workflows/run-pytest.yml)
[![codeql](https://github.com/DeepPSP/torch_ecg/actions/workflows/codeql-analysis.yml/badge.svg)](https://github.com/DeepPSP/torch_ecg/actions/workflows/codeql-analysis.yml)
[![formatting](https://github.com/DeepPSP/torch_ecg/actions/workflows/check-formatting.yml/badge.svg)](https://github.com/DeepPSP/torch_ecg/actions/workflows/check-formatting.yml)
[![codecov](https://codecov.io/gh/DeepPSP/torch_ecg/branch/master/graph/badge.svg?token=9YOPZ8GREA)](https://codecov.io/gh/DeepPSP/torch_ecg)
[![PyPI](https://img.shields.io/pypi/v/torch-ecg?style=flat-square)](https://pypi.org/project/torch-ecg/)
[![DOI](https://img.shields.io/badge/DOI-10.1088%2F1361--6579%2Fac9451-informational?style=flat-square)](https://doi.org/10.1088/1361-6579/ac9451)
[![zenodo](https://zenodo.org/badge/298482237.svg)](https://zenodo.org/badge/latestdoi/298482237)
[![downloads](https://img.shields.io/pypi/dm/torch-ecg?style=flat-square)](https://pypistats.org/packages/torch-ecg)
[![license](https://img.shields.io/github/license/DeepPSP/torch_ecg?style=flat-square)](https://github.com/DeepPSP/torch_ecg/blob/master/LICENSE)

ECG Deep Learning Framework Implemented using PyTorch.

Documentation (under development):

- [GitHub Pages](https://deeppsp.github.io/torch_ecg/)  [![gh-page status](https://github.com/DeepPSP/torch_ecg/actions/workflows/docs-publish.yml/badge.svg?branch=doc)](https://github.com/DeepPSP/torch_ecg/actions/workflows/docs-publish.yml)
- [Read the Docs](http://torch-ecg.rtfd.io/)  [![RTD status](https://readthedocs.org/projects/torch-ecg/badge/?version=latest)](https://torch-ecg.readthedocs.io/en/latest/?badge=latest)
- [latest version](https://deep-psp.tech/torch-ecg-docs-dev/)

The system design is depicted as follows.

<!-- ![system_design](docs/source/_static/images/system_design.jpg) -->

<p align="middle">
  <img src="/docs/source/_static/images/system_design.jpg" width="80%" />
</p>

<!-- toc -->

- [Installation](#installation)
- [Main Modules](#main-modules)
  - [Augmenters](#augmenters)
  - [Preprocessors](#preprocessors)
  - [Databases](#databases)
  - [Implemented Neural Network Architectures](#implemented-neural-network-architectures)
    - [Quick Example](#quick-example)
    - [Custom Model](#custom-model)
  - [CNN Backbones](#cnn-backbones)
    - [Implemented](#implemented)
    - [Ongoing](#ongoing)
    - [TODO](#todo)
  - [Components](#components)
    - [Loggers](#loggers)
    - [Outputs](#outputs)
    - [Metrics](#metrics)
    - [Trainer](#trainer)
- [Other Useful Tools](#other-useful-tools)
  - [R peaks detection algorithms](#r-peaks-detection-algorithms)
- [Usage Examples](#usage-examples)
- [CAUTION](#caution)
- [Work in progress](#work-in-progress)
- [Citation](#citation)
- [Thanks](#thanks)
- [Change Log](CHANGELOG.rst)

<!-- tocstop -->

## Installation

`torch_ecg` requires Python 3.6+ and is available through pip:

```bash
python -m pip install torch-ecg
```

One can download the development version hosted on [GitHub](https://github.com/DeepPSP/torch_ecg/) via

```bash
git clone https://github.com/DeepPSP/torch_ecg.git
cd torch_ecg
python -m pip install .
```

or use pip directly via

```bash
python -m pip install git+https://github.com/DeepPSP/torch_ecg.git
```

## Main Modules

### [Augmenters](torch_ecg/augmenters)

<details>
<summary>Click to expand!</summary>

Augmenters are classes (subclasses of `torch` `Module`) that perform data augmentation in a uniform way and are managed by the [`AugmenterManager`](torch_ecg/augmenters/augmenter_manager.py) (also a subclass of `torch` `Module`). Augmenters and the manager share a common signature of the `forward` method:

```python
forward(self, sig: Tensor, label: Optional[Tensor] = None, *extra_tensors: Sequence[Tensor], **kwargs: Any) -> Tuple[Tensor, ...]:
```

The following augmenters are implemented:

1. baseline wander (adding sinusoidal and gaussian noises)
2. cutmix
3. mixup
4. random flip
5. random masking
6. random renormalize
7. stretch-or-compress (scaling)
8. label smooth (not actually for data augmentation, but has similar behavior)
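
As an illustration of the last item, label smoothing simply mixes the one-hot label with the uniform distribution over classes. A minimal sketch in plain Python (illustrative only, independent of the library; the smoothing factor `eps` is an assumed example value):

```python
# Minimal sketch of label smoothing (illustrative, not the library's implementation).
from typing import List


def smooth_label(one_hot: List[float], eps: float = 0.1) -> List[float]:
    """Mix a one-hot label with the uniform distribution over classes."""
    n_classes = len(one_hot)
    return [(1.0 - eps) * v + eps / n_classes for v in one_hot]


# the smoothed label still sums to 1
smoothed = smooth_label([0.0, 1.0, 0.0, 0.0], eps=0.1)
```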

Usage example (this example uses all augmenters except cutmix, each with default config):

```python
import torch
from torch_ecg.cfg import CFG
from torch_ecg.augmenters import AugmenterManager

config = CFG(
    random=False,
    fs=500,
    baseline_wander={},
    label_smooth={},
    mixup={},
    random_flip={},
    random_masking={},
    random_renormalize={},
    stretch_compress={},
)
am = AugmenterManager.from_config(config)
sig, label, mask = torch.rand(2, 12, 5000), torch.rand(2, 26), torch.rand(2, 5000, 1)
sig, label, mask = am(sig, label, mask)
```

Augmenters can be stochastic along the batch dimension and (or) the channel dimension (ref. the `get_indices` method of the [`Augmenter`](torch_ecg/augmenters/base.py) base class).

:point_right: [Back to TOC](#torch_ecg)

</details>

### [Preprocessors](torch_ecg/preprocessors)

<details>
<summary>Click to expand!</summary>

There are also [preprocessors](torch_ecg/_preprocessors) that act on `numpy` arrays. Similarly, preprocessors are monitored by a manager:

```python
import torch
from torch_ecg.cfg import CFG
from torch_ecg._preprocessors import PreprocManager

config = CFG(
    random=False,
    resample={"fs": 500},
    bandpass={},
    normalize={},
)
ppm = PreprocManager.from_config(config)
sig = torch.rand(12, 80000).numpy()
sig, fs = ppm(sig, 200)
```

The following preprocessors are implemented:

1. baseline removal (detrend)
2. normalize (z-score, min-max, naïve)
3. bandpass
4. resample
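
For intuition, the z-score mode of the normalize preprocessor shifts a signal to zero mean and scales it to unit standard deviation. A plain-Python sketch of the computation (illustrative only, not the library's implementation):

```python
# Plain-Python sketch of z-score normalization (illustrative only).
from math import sqrt
from typing import List


def z_score_normalize(sig: List[float]) -> List[float]:
    """Shift to zero mean and scale to unit standard deviation."""
    mean = sum(sig) / len(sig)
    std = sqrt(sum((v - mean) ** 2 for v in sig) / len(sig))
    return [(v - mean) / std for v in sig]


normalized = z_score_normalize([1.0, 2.0, 3.0, 4.0])
```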

For more examples, see the [README file](torch_ecg/preprocessors/README.md) of the `preprocessors` module.

:point_right: [Back to TOC](#torch_ecg)

</details>

### [Databases](torch_ecg/databases)

<details>
<summary>Click to expand!</summary>

This module includes classes that manipulate the I/O of the ECG signals and labels in an ECG database, and maintain its metadata (statistics, paths, plots, list of records, etc.). This module is migrated and improved from [DeepPSP/database_reader](https://github.com/DeepPSP/database_reader).

After migration, all classes should be tested again. The progression:

| Database      | Source                                                           | Tested             |
| ------------- | ---------------------------------------------------------------- | ------------------ |
| AFDB          | [PhysioNet](https://physionet.org/content/afdb/1.0.0/)           | :heavy_check_mark: |
| ApneaECG      | [PhysioNet](https://physionet.org/content/apnea-ecg/1.0.0/)      | :x:                |
| CinC2017      | [PhysioNet](https://physionet.org/content/challenge-2017/1.0.0/) | :x:                |
| CinC2018      | [PhysioNet](https://physionet.org/content/challenge-2018/1.0.0/) | :x:                |
| CinC2020      | [PhysioNet](https://physionet.org/content/challenge-2020/1.0.1/) | :heavy_check_mark: |
| CinC2021      | [PhysioNet](https://physionet.org/content/challenge-2021/1.0.2/) | :heavy_check_mark: |
| LTAFDB        | [PhysioNet](https://physionet.org/content/ltafdb/1.0.0/)         | :x:                |
| LUDB          | [PhysioNet](https://physionet.org/content/ludb/1.0.1/)           | :heavy_check_mark: |
| MITDB         | [PhysioNet](https://physionet.org/content/mitdb/1.0.0/)          | :heavy_check_mark: |
| SHHS          | [NSRR](https://sleepdata.org/datasets/shhs)                      | :x:                |
| CPSC2018      | [CPSC](http://2018.icbeb.org/Challenge.html)                     | :heavy_check_mark: |
| CPSC2019      | [CPSC](http://2019.icbeb.org/Challenge.html)                     | :heavy_check_mark: |
| CPSC2020      | [CPSC](http://2020.icbeb.org/CSPC2020)                           | :heavy_check_mark: |
| CPSC2021      | [CPSC](http://2021.icbeb.org/CPSC2021)                           | :heavy_check_mark: |
| SPH           | [Figshare](https://doi.org/10.6084/m9.figshare.c.5779802.v1)     | :heavy_check_mark: |

NOTE that these classes should not be confused with a `torch` `Dataset`, which is strongly related to the task (or the model). However, one can build `Dataset`s based on these classes, for example the [`Dataset`](benchmarks/train_hybrid_cpsc2021/dataset.py) for the 4th China Physiological Signal Challenge 2021 (CPSC2021).

One can use the built-in `Dataset`s in [`torch_ecg.databases.datasets`](torch_ecg/databases/datasets) as follows

```python
from copy import deepcopy

from torch_ecg.databases.datasets.cinc2021 import CINC2021Dataset, CINC2021TrainCfg

config = deepcopy(CINC2021TrainCfg)
config.db_dir = "some/path/to/db"
dataset = CINC2021Dataset(config, training=True, lazy=False)
```

:point_right: [Back to TOC](#torch_ecg)

</details>

### [Implemented Neural Network Architectures](torch_ecg/models)

<details>
<summary>Click to expand!</summary>

1. CRNN, both for classification and sequence tagging (segmentation)
2. U-Net
3. RR-LSTM

A typical signature of the instantiation (`__init__`) function of a model is as follows

```python
__init__(self, classes: Sequence[str], n_leads: int, config: Optional[CFG] = None, **kwargs: Any) -> None
```

If a `config` is not specified, then the default config will be used (stored in the [`model_configs`](torch_ecg/model_configs) module).

#### Quick Example

A quick example is as follows:

```python
import torch
from torch_ecg.utils.utils_nn import adjust_cnn_filter_lengths
from torch_ecg.model_configs import ECG_CRNN_CONFIG
from torch_ecg.models.ecg_crnn import ECG_CRNN

config = adjust_cnn_filter_lengths(ECG_CRNN_CONFIG, fs=400)
# change the default CNN backbone:
# bottleneck with global context attention variant of Nature Communications ResNet
config.cnn.name = "resnet_nature_comm_bottle_neck_gc"

classes = ["NSR", "AF", "PVC", "SPB"]
n_leads = 12
model = ECG_CRNN(classes, n_leads, config)

model(torch.rand(2, 12, 4000))  # batch size 2, 12 leads, signal length 4000
```

This creates a model for the classification of 4 classes, namely "NSR", "AF", "PVC", "SPB", on 12-lead ECGs. One can check the size of a model, in terms of the number of parameters, via

```python
model.module_size
```

or in terms of memory consumption via

```python
model.module_size_
```

#### Custom Model

One can adjust the configs to create a custom model. For example, the building blocks of the 4 stages of a `TResNet` backbone are `basic`, `basic`, `bottleneck`, `bottleneck`. If one wants to change the second block to be a `bottleneck` block with squeeze-and-excitation (`SE`) attention, then

```python
from copy import deepcopy

from torch_ecg.models.ecg_crnn import ECG_CRNN
from torch_ecg.model_configs import (
    ECG_CRNN_CONFIG,
    tresnetF, resnet_bottle_neck_se,
)

my_resnet = deepcopy(tresnetF)
my_resnet.building_block[1] = "bottleneck"
my_resnet.block[1] = resnet_bottle_neck_se
```

The convolutions in a `TResNet` are anti-aliasing convolutions. If one further wants to change the convolutions to normal convolutions, then

```python
for b in my_resnet.block:
    b.conv_type = None
```

or change them to separable convolutions via

```python
for b in my_resnet.block:
    b.conv_type = "separable"
```

Finally, replace the default CNN backbone via

```python
my_model_config = deepcopy(ECG_CRNN_CONFIG)
my_model_config.cnn.name = "my_resnet"
my_model_config.cnn.my_resnet = my_resnet

model = ECG_CRNN(["NSR", "AF", "PVC", "SPB"], 12, my_model_config)
```

:point_right: [Back to TOC](#torch_ecg)

</details>

### [CNN Backbones](torch_ecg/models/cnn)

<details>
<summary>Click to expand!</summary>

#### Implemented

1. VGG
2. ResNet (including vanilla ResNet, ResNet-B, ResNet-C, ResNet-D, ResNeXT, TResNet, [Stanford ResNet](https://github.com/awni/ecg), [Nature Communications ResNet](https://github.com/antonior92/automatic-ecg-diagnosis), etc.)
3. MultiScopicNet (CPSC2019 SOTA)
4. DenseNet (CPSC2020 SOTA)
5. Xception

In general, variants of ResNet are the most commonly used architectures, as can be inferred from [CinC2020](https://cinc.org/archives/2020/) and [CinC2021](https://cinc.org/archives/2021/).

#### Ongoing

1. MobileNet
2. DarkNet
3. EfficientNet

#### TODO

1. HarDNet
2. HO-ResNet
3. U-Net++
4. U-Squared Net
5. etc.

More details and a list of references can be found in the [README file](torch_ecg/models/cnn/README.md) of this module.

:point_right: [Back to TOC](#torch_ecg)

</details>

### [Components](torch_ecg/components/)

<details>
<summary>Click to expand!</summary>

This module consists of frequently used components such as loggers, trainers, etc.

#### [Loggers](torch_ecg/components/loggers.py)

The following loggers

1. CSV logger
2. text logger
3. tensorboard logger

are implemented and manipulated uniformly by a manager.

#### [Outputs](torch_ecg/components/outputs.py)

The `Output` classes implemented in this module serve as containers for ECG downstream task model outputs, including

- `ClassificationOutput`
- `MultiLabelClassificationOutput`
- `SequenceTaggingOutput`
- `WaveDelineationOutput`
- `RPeaksDetectionOutput`

Each has some required fields (keys) and can hold an arbitrary number of custom fields. These classes are useful for the computation of metrics.
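
The pattern behind such containers (a set of required fields plus arbitrary extras) can be sketched generically as follows. This is an illustration of the idea only, not the library's actual classes, and the field names used here are assumed for the example:

```python
# Generic sketch of an output container with required keys (illustrative only;
# not the actual torch_ecg implementation, and field names are assumed).
class SketchOutput(dict):
    """A dict that insists on a set of required fields but accepts any extras."""

    required_fields: tuple = ()

    def __init__(self, **kwargs):
        missing = [f for f in self.required_fields if f not in kwargs]
        if missing:
            raise ValueError(f"missing required fields: {missing}")
        super().__init__(**kwargs)


class SketchClassificationOutput(SketchOutput):
    required_fields = ("classes", "prob", "pred")  # assumed field names


out = SketchClassificationOutput(
    classes=["NSR", "AF"], prob=[0.9, 0.1], pred=[0], extra_note="anything"
)
```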

#### [Metrics](torch_ecg/components/metrics.py)

This module has the following pre-defined (built-in) `Metrics` classes:

- `ClassificationMetrics`
- `RPeaksDetectionMetrics`
- `WaveDelineationMetrics`

These metrics are computed according to either [Wikipedia](https://en.wikipedia.org/wiki/Precision_and_recall) or some published literature.
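
For reference, the basic quantities behind such metrics (precision, recall, F1) follow directly from the confusion counts, as in the Wikipedia article linked above. A plain-Python sketch:

```python
# Precision, recall and F1 from confusion counts (plain Python, illustrative).
from typing import Tuple


def precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Compute precision, recall and F1 from true/false positive and false negative counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


p, r, f = precision_recall_f1(tp=8, fp=2, fn=2)
```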

#### [Trainer](torch_ecg/components/trainer.py)

An abstract base class `BaseTrainer` is implemented, in which some common steps in building a training pipeline (workflow) are implemented. A few task-specific methods are assigned as `abstractmethod`s, for example the method

```python
evaluate(self, data_loader: DataLoader) -> Dict[str, float]
```

for evaluation on the validation set during training, and perhaps further for model selection and early stopping.
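
The underlying pattern, a base class that owns the common loop and defers task-specific hooks to subclasses, can be illustrated generically without depending on `torch_ecg`; the class and method bodies below are illustrative, not `BaseTrainer` itself:

```python
# Generic sketch of the abstract-trainer pattern (illustrative; not BaseTrainer itself).
from abc import ABC, abstractmethod
from typing import Dict


class SketchTrainer(ABC):
    """Common training steps live in the base class; evaluation is task-specific."""

    def train_one_epoch(self, data) -> Dict[str, float]:
        # ... common forward/backward/optimizer steps would go here ...
        return self.evaluate(data)

    @abstractmethod
    def evaluate(self, data_loader) -> Dict[str, float]:
        """Task-specific: compute validation metrics."""


class MyTrainer(SketchTrainer):
    def evaluate(self, data_loader) -> Dict[str, float]:
        return {"accuracy": 1.0}  # placeholder metrics


metrics = MyTrainer().train_one_epoch(data=None)
```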

:point_right: [Back to TOC](#torch_ecg)

</details>

:point_right: [Back to TOC](#torch_ecg)

## Other Useful Tools

<details>
<summary>Click to expand!</summary>

### [R peaks detection algorithms](torch_ecg/utils/rpeaks.py)

This is a collection of traditional (non deep learning) algorithms for R peaks detection, collected from [WFDB](https://github.com/MIT-LCP/wfdb-python) and [BioSPPy](https://github.com/PIA-Group/BioSPPy).
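
As a toy illustration of the idea behind such algorithms (find local maxima above a threshold, then enforce a refractory period between beats), here is a plain-Python sketch; real detectors such as those collected in this module add band-pass filtering, adaptive thresholds, and more:

```python
# Toy threshold-based peak detector (illustrative only; real R peak detectors
# add band-pass filtering, adaptive thresholds, etc.).
from typing import List


def detect_peaks(sig: List[float], threshold: float, refractory: int) -> List[int]:
    """Indices of local maxima above `threshold`, at least `refractory` samples apart."""
    peaks = []
    for i in range(1, len(sig) - 1):
        if sig[i] > threshold and sig[i] >= sig[i - 1] and sig[i] > sig[i + 1]:
            if not peaks or i - peaks[-1] >= refractory:
                peaks.append(i)
    return peaks


toy_sig = [0, 0, 5, 0, 0, 0, 6, 0, 0]
peaks = detect_peaks(toy_sig, threshold=1, refractory=3)  # -> [2, 6]
```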

:point_right: [Back to TOC](#torch_ecg)

</details>
409
405
410
## Usage Examples

<details>
<summary>Click to expand!</summary>

See case studies in the [benchmarks folder](benchmarks/).

A large part of the case studies are migrated from other DeepPSP repositories. Some are implemented in the old fashion, inconsistent with the new system architecture of `torch_ecg`, and hence need updating and testing.

| Benchmark                                      | Architecture              | Source                                                  | Finished           | Updated            | Tested             |
| ---------------------------------------------- | ------------------------- | ------------------------------------------------------- | ------------------ | ------------------ | ------------------ |
| [CinC2020](benchmarks/train_crnn_cinc2020/)    | CRNN                      | [DeepPSP/cinc2020](https://github.com/DeepPSP/cinc2020) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [CinC2021](benchmarks/train_crnn_cinc2021/)    | CRNN                      | [DeepPSP/cinc2021](https://github.com/DeepPSP/cinc2021) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [CinC2022](benchmarks/train_mtl_cinc2022/)[^1] | Multi Task Learning (MTL) | [DeepPSP/cinc2022](https://github.com/DeepPSP/cinc2022) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [CPSC2019](benchmarks/train_multi_cpsc2019/)   | SequenceTagging/U-Net     | NA                                                      | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [CPSC2020](benchmarks/train_hybrid_cpsc2020/)  | CRNN/SequenceTagging      | [DeepPSP/cpsc2020](https://github.com/DeepPSP/cpsc2020) | :heavy_check_mark: | :x:                | :x:                |
| [CPSC2021](benchmarks/train_hybrid_cpsc2021/)  | CRNN/SequenceTagging/LSTM | [DeepPSP/cpsc2021](https://github.com/DeepPSP/cpsc2021) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [LUDB](benchmarks/train_unet_ludb/)            | U-Net                     | NA                                                      | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

[^1]: Although `CinC2022` dealt with acoustic cardiac signals (phonocardiogram, PCG), the tasks and signals can be treated similarly.

Taking [CPSC2021](benchmarks/train_hybrid_cpsc2021) for example, the steps are

1. Write a [`Dataset`](benchmarks/train_hybrid_cpsc2021/dataset.py) to fit the training data for the model(s) and the training workflow, or directly use the built-in `Dataset`s in [`torch_ecg.databases.datasets`](torch_ecg/databases/datasets). In this example, 3 tasks are considered, 2 of which use a [`MaskedBCEWithLogitsLoss`](torch_ecg/models/loss.py) function, hence the `Dataset` produces an extra tensor for these 2 tasks

    ```python
    def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
        if self.lazy:
            if self.task in ["qrs_detection"]:
                return self.fdr[index][:2]
            else:
                return self.fdr[index]
        else:
            if self.task in ["qrs_detection"]:
                return self._all_data[index], self._all_labels[index]
            else:
                return self._all_data[index], self._all_labels[index], self._all_masks[index]
    ```

2. Inherit a [base model](torch_ecg/models/ecg_seq_lab_net.py) to create [task-specific models](benchmarks/train_hybrid_cpsc2021/model.py), along with [tailored model configs](benchmarks/train_hybrid_cpsc2021/cfg.py)
3. Inherit the [`BaseTrainer`](torch_ecg/components/trainer.py) to build the [training pipeline](benchmarks/train_hybrid_cpsc2021/trainer.py), with the `abstractmethod`s (`_setup_dataloaders`, `run_one_step`, `evaluate`, `batch_dim`, etc.) implemented.

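The branching in step 1 can be mimicked with a minimal stand-in (eager mode only, plain lists in place of `np.ndarray`; `ToyDataset` and all names besides the `"qrs_detection"` task string are illustrative):

```python
from typing import Tuple


class ToyDataset:
    """Minimal stand-in for step 1: tasks trained with a masked loss
    get an extra mask element from __getitem__."""

    def __init__(self, task: str):
        self.task = task
        # Tiny in-memory stand-ins for the cached signals, labels, and masks.
        self._all_data = [[0.1, 0.2], [0.3, 0.4]]
        self._all_labels = [0, 1]
        self._all_masks = [[1, 1], [1, 0]]

    def __getitem__(self, index: int) -> Tuple:
        if self.task in ["qrs_detection"]:
            return self._all_data[index], self._all_labels[index]
        else:
            # Masked-loss tasks additionally yield the mask tensor.
            return self._all_data[index], self._all_labels[index], self._all_masks[index]


print(len(ToyDataset("qrs_detection")[0]))  # 2 -- (data, label)
print(len(ToyDataset("other_task")[0]))     # 3 -- (data, label, mask)
```

A downstream collate function can then dispatch on the tuple length to decide whether a mask is present.
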
:point_right: [Back to TOC](#torch_ecg)

</details>

## CAUTION

Most of the time, but not always, I will manually rerun the notebooks in [benchmarks](benchmarks/) after updates. If someone finds a bug, please raise an issue. The test workflow is to be enhanced and automated; see [this project](https://github.com/DeepPSP/torch_ecg/projects/8).

:point_right: [Back to TOC](#torch_ecg)

## Work in progress

See the [projects page](https://github.com/DeepPSP/torch_ecg/projects).

:point_right: [Back to TOC](#torch_ecg)

## Citation

```latex
@misc{torch_ecg,
      title = {{torch\_ecg: An ECG Deep Learning Framework Implemented using PyTorch}},
     author = {WEN, Hao and KANG, Jingsu},
        doi = {10.5281/ZENODO.6435048},
        url = {https://zenodo.org/record/6435048},
  publisher = {Zenodo},
       year = {2022},
  copyright = {{MIT License}}
}
@article{torch_ecg_paper,
      title = {{A Novel Deep Learning Package for Electrocardiography Research}},
     author = {Hao Wen and Jingsu Kang},
    journal = {{Physiological Measurement}},
        doi = {10.1088/1361-6579/ac9451},
       year = {2022},
      month = {11},
  publisher = {{IOP Publishing}},
     volume = {43},
     number = {11},
      pages = {115006}
}
```

:point_right: [Back to TOC](#torch_ecg)

## Thanks

Much is learned, especially the modular design, from the adversarial NLP library [`TextAttack`](https://github.com/QData/TextAttack) and from Hugging Face [`transformers`](https://github.com/huggingface/transformers).

:point_right: [Back to TOC](#torch_ecg)