{
"cells": [
{
"cell_type": "markdown",
"id": "43d6de63",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-16T20:11:29.105213Z",
"start_time": "2022-05-16T20:11:27.031042Z"
}
},
"source": [
"## Binary classification example\n",
"\n",
"In this example, we show how the ecgxai package can be used to build a binary classification system for atrial fibrillation (AF). We train the model on the PTB-XL dataset using a 'double residual' convolutional ResNet architecture."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "fbd9c54b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-17T11:44:20.919791Z",
"start_time": "2022-05-17T11:44:19.992428Z"
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"# We first import the required data utilities from the ecgxai package\n",
"from ecgxai.utils.dataset import PTBXLDataset\n",
"from ecgxai.utils.transforms import ApplyGain, ToTensor\n",
"\n",
"# We also import some utilities from PyTorch and torchvision for additional functionality\n",
"from torch.utils.data import DataLoader\n",
"from torchvision.transforms import Compose"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8d8e45a1",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-17T11:44:22.928971Z",
"start_time": "2022-05-17T11:44:20.921179Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 21837/21837 [00:00<00:00, 186045.27it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using downloaded and verified file: /workspace/misc/PTB_XL/PTB_XL.tar.gz\n",
"Trainset:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"-- Dataset distribution -- \n",
"Full size: 19653\n",
"[\"AF\"] - Num entries: 1356 (6.9%)\n",
"\n",
" Testset: \n",
"-- Dataset distribution -- \n",
"Full size: 2184\n",
"[\"AF\"] - Num entries: 158 (7.23%)\n"
]
}
],
"source": [
"# The PTBXLDataset class automatically downloads and extracts the 12-lead PTB-XL dataset\n",
"dataset = PTBXLDataset(\n",
"    # The path parameter defines where the data is stored after downloading and where it is looked up in the future\n",
"    path=\"/workspace/misc/PTB_XL\",\n",
"    # We chain the ApplyGain and ToTensor transformations (using the Compose class)\n",
"    # to bring the data into the desired format. In the case of PTB-XL, all input data\n",
"    # is divided by 1000 in the ApplyGain transform to stabilize training (see the 'Checking the data' section below for an example).\n",
"    transform=Compose([ApplyGain(), ToTensor()]),\n",
"    # The use_numpy parameter is set to True to speed up data loading\n",
"    use_numpy=True,\n",
"    # As we are classifying atrial fibrillation in this example, we tell the dataset to use the\n",
"    # 'AF' label.\n",
"    labels='AF'\n",
")\n",
"\n",
"# We then randomly split the dataset into a train and test set using a 90%-10% split (ratio is the fraction held out for testing)\n",
"trainset, testset = dataset.train_test_split(ratio=0.1, shuffle=True)\n",
"\n",
"print(\"Trainset:\")\n",
"trainset.print_stats()\n",
"\n",
"print(\"\\n Testset: \")\n",
"testset.print_stats()"
]
},
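{
"cell_type": "markdown",
"id": "5a0c7e11",
"metadata": {},
"source": [
"The gain step above rescales the signal before it reaches the network. The sketch below shows roughly how a dictionary-based transform can be chained. It defines a hypothetical `DivideWaveform` transform, which assumes each sample is a dict holding a NumPy array under `'waveform'`, and divides that array by the factor of 1000 mentioned in the comments above. This is not the ecgxai `ApplyGain` implementation. The small `compose` helper is only a stand-in for torchvision's `Compose`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"class DivideWaveform:\n",
"    # Hypothetical transform: divide sample['waveform'] by a constant factor\n",
"    def __init__(self, factor=1000.0):\n",
"        self.factor = factor\n",
"\n",
"    def __call__(self, sample):\n",
"        out = dict(sample)  # shallow copy, so the input dict is left unchanged\n",
"        out['waveform'] = np.asarray(sample['waveform'], dtype=np.float32) / self.factor\n",
"        return out\n",
"\n",
"\n",
"def compose(*transforms):\n",
"    # Minimal stand-in for torchvision.transforms.Compose: apply transforms in order\n",
"    def apply(sample):\n",
"        for t in transforms:\n",
"            sample = t(sample)\n",
"        return sample\n",
"    return apply\n",
"\n",
"\n",
"# A fake 12-lead record with 5000 integer-valued samples per lead\n",
"sample = {'waveform': np.full((12, 5000), 1500), 'label': 1.0}\n",
"scaled = compose(DivideWaveform(1000.0))(sample)\n",
"print(scaled['waveform'].shape, scaled['waveform'][0, 0])  # (12, 5000) 1.5\n",
"```"
]
},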
{
"cell_type": "code",
"execution_count": 3,
"id": "975b3b21",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-17T11:44:22.950649Z",
"start_time": "2022-05-17T11:44:22.931519Z"
}
},
"outputs": [],
"source": [
"# The train and test sets are then wrapped in PyTorch DataLoaders, which we use to train and evaluate the model\n",
"train_loader = DataLoader(\n",
" trainset,\n",
" batch_size=64,\n",
" num_workers=8,\n",
" shuffle=True\n",
")\n",
"\n",
"test_loader = DataLoader(\n",
" testset,\n",
" batch_size=64,\n",
" num_workers=8\n",
")"
]
},
{
"cell_type": "markdown",
"id": "12d81e5c",
"metadata": {},
"source": [
"### Checking the data\n",
"\n",
"The cell below plots the first lead of a sample from the training set. We retrieve the sample with the __\\_\\_getitem\\_\\_()__ method, which returns a dictionary. The ECG signal (after the transforms defined above) is stored under the 'waveform' key."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3b14d826",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-17T11:44:23.156470Z",
"start_time": "2022-05-17T11:44:22.951914Z"
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['waveform', 'samplebase', 'gain', 'id', 'truebaseline_0', 'truebaseline_1', 'truebaseline_2', 'truebaseline_3', 'truebaseline_4', 'truebaseline_5', 'truebaseline_6', 'truebaseline_7', 'truebaseline_8', 'truebaseline_9', 'truebaseline_10', 'truebaseline_11', 'label'])\n"
]
},
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f4a76bf1580>]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"example = trainset.__getitem__(0)\n",
"\n",
"# We print the keys of the dictionary to see what is available in the standard dataset configuration. The 'truebaseline'\n",
"# keys are used for baseline correction in some datasets.\n",
"print(example.keys())\n",
"\n",
"plt.figure(figsize=(15, 4))\n",
"plt.plot(example['waveform'][0])"
]
},
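{
"cell_type": "markdown",
"id": "5a0c7e12",
"metadata": {},
"source": [
"The plot above puts the sample index on the x-axis. PTB-XL records are 10 s long, so 5000 samples per lead correspond to a 500 Hz sampling rate. Assuming that rate, the sketch below builds a time axis in seconds for one lead. A random array stands in for `example['waveform']`; a CPU tensor can be converted with `np.asarray`. The helper `lead_with_time_axis` is hypothetical, and its outputs could be plotted with `plt.plot(t, lead)`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"SAMPLING_RATE_HZ = 500  # assumption: 5000 samples over a 10 s PTB-XL record\n",
"\n",
"\n",
"def lead_with_time_axis(waveform, lead_idx=0, fs=SAMPLING_RATE_HZ):\n",
"    # Return (t, signal) for one lead of a (leads, samples) waveform array\n",
"    signal = np.asarray(waveform)[lead_idx]\n",
"    t = np.arange(signal.shape[-1]) / fs  # time in seconds of each sample\n",
"    return t, signal\n",
"\n",
"\n",
"# Stand-in for example['waveform']: 12 leads x 5000 samples\n",
"waveform = np.random.default_rng(0).normal(size=(12, 5000))\n",
"t, lead = lead_with_time_axis(waveform)\n",
"print(lead.shape, t[-1])  # (5000,) 9.998\n",
"```"
]
},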
{
"cell_type": "markdown",
"id": "87e2c812",
"metadata": {},
"source": [
"### Initializing the model\n",
"\n",
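"We import the CNNDoubleResidual architecture from the ecgxai package and pass it the hyperparameters required for our dataset (12 channels/leads, 5000 samples per lead). With sub_sample_every=1, subsampling is applied in every layer. According to the torchinfo summary below, every block except the first halves the temporal dimension (5000 → 2500 → ... → 2). This keeps the computational cost low and limits the size of the flattened feature vector, and therefore the number of parameters in the final linear layer. The number of channels is doubled every 4 layers (double_channel_every=4).\n",
"\n",
"We also define additional Sequential, Linear and Reshape modules to create a full pipeline that maps each ECG to a single output (a logit for the AF class)."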
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d81d9c02",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-17T11:44:23.647229Z",
"start_time": "2022-05-17T11:44:23.157587Z"
}
},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"from ecgxai.network.doubleresidual.modules import CNNDoubleResidual, Reshape\n",
"\n",
"# Initialize the double residual convolutional ResNet\n",
"cnn = CNNDoubleResidual(\n",
" num_layers=12,\n",
" in_sample_dim=5000,\n",
" in_channels=12,\n",
" kernel_size=7,\n",
" dropout_rate=0.1,\n",
" sub_sample_every=1,\n",
" double_channel_every=4,\n",
" act_func=nn.ReLU(),\n",
" batchnorm=True\n",
")\n",
"\n",
"# Use calculate_output_dim() to get the size of the flattened CNN output (plus its channel and sample dimensions)\n",
"cnn_output_dim, cnn_output_channels, cnn_output_samples = cnn.calculate_output_dim()\n",
"\n",
"# We pass this calculated size to a linear layer that maps the features to a single output (logit) per ECG. The final Reshape\n",
"# removes the trailing singleton dimension of the output tensor (B x 1 is reshaped to B)\n",
"class_model = nn.Sequential(\n",
" cnn,\n",
" nn.Linear(cnn_output_dim, 1),\n",
" Reshape(-1)\n",
")"
]
},
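{
"cell_type": "markdown",
"id": "5a0c7e13",
"metadata": {},
"source": [
"To see where `cnn_output_dim` comes from, the sketch below reproduces the per-block output shapes reported in the torchinfo summary further down. The rules are read off from that printout and are assumptions, not the library's code:\n",
"\n",
"- The first block keeps the input length.\n",
"- Every later block halves the length, rounding down.\n",
"- Every fourth block doubles the number of channels.\n",
"\n",
"The helper `trace_shapes` is hypothetical and is not `calculate_output_dim()`.\n",
"\n",
"```python\n",
"def trace_shapes(num_layers=12, in_samples=5000, in_channels=12, double_channel_every=4):\n",
"    # Return the (channels, samples) output shape of each block\n",
"    shapes = []\n",
"    channels, samples = in_channels, in_samples\n",
"    for block in range(1, num_layers + 1):\n",
"        if block > 1:\n",
"            samples //= 2  # blocks after the first halve the temporal dimension\n",
"        if block % double_channel_every == 0:\n",
"            channels *= 2  # every 4th block doubles the number of channels\n",
"        shapes.append((channels, samples))\n",
"    return shapes\n",
"\n",
"\n",
"shapes = trace_shapes()\n",
"for i, (c, s) in enumerate(shapes, start=1):\n",
"    print(f'block {i:2d}: {c:3d} channels x {s:4d} samples')\n",
"\n",
"final_channels, final_samples = shapes[-1]\n",
"print('flattened size:', final_channels * final_samples)  # flattened size: 192\n",
"```\n",
"\n",
"The last block outputs 96 channels x 2 samples. Flattening gives the 192 features that the Linear layer maps to a single output."
]
},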
{
"cell_type": "markdown",
"id": "de8f815b",
"metadata": {},
"source": [
"### Printing the model architecture using torchinfo\n",
"\n",
"Using the __torchinfo__ package, we can easily show how the shape of the data changes through the layers of the chosen architecture, how many parameters each layer has, and approximately how much compute (multiply-adds) the forward pass requires."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6ab318d0",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-17T11:44:27.797037Z",
"start_time": "2022-05-17T11:44:23.648456Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:652: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
]
},
{
"data": {
"text/plain": [
"==============================================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==============================================================================================================\n",
"Sequential [16] --\n",
"├─CNNDoubleResidual: 1-1 [16, 192] --\n",
"│ └─Sequential: 2-1 [16, 192] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-1 [16, 12, 5000] --\n",
"│ │ │ └─MaxPool1d: 4-1 [16, 12, 5000] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-1 [16, 12, 5000] 24\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-2 [16, 12, 5000] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-3 [16, 12, 5006] --\n",
"│ │ │ │ └─Conv1d: 5-4 [16, 12, 5000] 1,020\n",
"│ │ │ │ └─BatchNorm1d: 5-5 [16, 12, 5000] 24\n",
"│ │ │ │ └─ReLU: 5-6 [16, 12, 5000] --\n",
"│ │ │ │ └─Dropout: 5-7 [16, 12, 5000] --\n",
"│ │ │ │ └─ConstantPad1d: 5-8 [16, 12, 5006] --\n",
"│ │ │ │ └─Conv1d: 5-9 [16, 12, 5000] 1,020\n",
"│ │ │ │ └─BatchNorm1d: 5-10 [16, 12, 5000] 24\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-2 [16, 12, 2500] --\n",
"│ │ │ └─MaxPool1d: 4-2 [16, 12, 2500] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-11 [16, 12, 5000] 24\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-12 [16, 12, 5000] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-13 [16, 12, 5006] --\n",
"│ │ │ │ └─Conv1d: 5-14 [16, 12, 2500] 1,020\n",
"│ │ │ │ └─BatchNorm1d: 5-15 [16, 12, 2500] 24\n",
"│ │ │ │ └─ReLU: 5-16 [16, 12, 2500] --\n",
"│ │ │ │ └─Dropout: 5-17 [16, 12, 2500] --\n",
"│ │ │ │ └─ConstantPad1d: 5-18 [16, 12, 2506] --\n",
"│ │ │ │ └─Conv1d: 5-19 [16, 12, 2500] 1,020\n",
"│ │ │ │ └─BatchNorm1d: 5-20 [16, 12, 2500] 24\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-3 [16, 12, 1250] --\n",
"│ │ │ └─MaxPool1d: 4-3 [16, 12, 1250] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-21 [16, 12, 2500] 24\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-22 [16, 12, 2500] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-23 [16, 12, 2506] --\n",
"│ │ │ │ └─Conv1d: 5-24 [16, 12, 1250] 1,020\n",
"│ │ │ │ └─BatchNorm1d: 5-25 [16, 12, 1250] 24\n",
"│ │ │ │ └─ReLU: 5-26 [16, 12, 1250] --\n",
"│ │ │ │ └─Dropout: 5-27 [16, 12, 1250] --\n",
"│ │ │ │ └─ConstantPad1d: 5-28 [16, 12, 1256] --\n",
"│ │ │ │ └─Conv1d: 5-29 [16, 12, 1250] 1,020\n",
"│ │ │ │ └─BatchNorm1d: 5-30 [16, 12, 1250] 24\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-4 [16, 24, 625] --\n",
"│ │ │ └─MaxPool1d: 4-4 [16, 24, 625] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-31 [16, 12, 1250] 24\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-32 [16, 12, 1250] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-33 [16, 12, 1256] --\n",
"│ │ │ │ └─Conv1d: 5-34 [16, 24, 625] 2,040\n",
"│ │ │ │ └─BatchNorm1d: 5-35 [16, 24, 625] 48\n",
"│ │ │ │ └─ReLU: 5-36 [16, 24, 625] --\n",
"│ │ │ │ └─Dropout: 5-37 [16, 24, 625] --\n",
"│ │ │ │ └─ConstantPad1d: 5-38 [16, 24, 631] --\n",
"│ │ │ │ └─Conv1d: 5-39 [16, 24, 625] 4,056\n",
"│ │ │ │ └─BatchNorm1d: 5-40 [16, 24, 625] 48\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-5 [16, 24, 312] --\n",
"│ │ │ └─MaxPool1d: 4-5 [16, 24, 312] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-41 [16, 24, 625] 48\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-42 [16, 24, 625] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-43 [16, 24, 631] --\n",
"│ │ │ │ └─Conv1d: 5-44 [16, 24, 313] 4,056\n",
"│ │ │ │ └─BatchNorm1d: 5-45 [16, 24, 313] 48\n",
"│ │ │ │ └─ReLU: 5-46 [16, 24, 313] --\n",
"│ │ │ │ └─Dropout: 5-47 [16, 24, 313] --\n",
"│ │ │ │ └─ConstantPad1d: 5-48 [16, 24, 318] --\n",
"│ │ │ │ └─Conv1d: 5-49 [16, 24, 312] 4,056\n",
"│ │ │ │ └─BatchNorm1d: 5-50 [16, 24, 312] 48\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-6 [16, 24, 156] --\n",
"│ │ │ └─MaxPool1d: 4-6 [16, 24, 156] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-51 [16, 24, 312] 48\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-52 [16, 24, 312] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-53 [16, 24, 318] --\n",
"│ │ │ │ └─Conv1d: 5-54 [16, 24, 156] 4,056\n",
"│ │ │ │ └─BatchNorm1d: 5-55 [16, 24, 156] 48\n",
"│ │ │ │ └─ReLU: 5-56 [16, 24, 156] --\n",
"│ │ │ │ └─Dropout: 5-57 [16, 24, 156] --\n",
"│ │ │ │ └─ConstantPad1d: 5-58 [16, 24, 162] --\n",
"│ │ │ │ └─Conv1d: 5-59 [16, 24, 156] 4,056\n",
"│ │ │ │ └─BatchNorm1d: 5-60 [16, 24, 156] 48\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-7 [16, 24, 78] --\n",
"│ │ │ └─MaxPool1d: 4-7 [16, 24, 78] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-61 [16, 24, 156] 48\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-62 [16, 24, 156] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-63 [16, 24, 162] --\n",
"│ │ │ │ └─Conv1d: 5-64 [16, 24, 78] 4,056\n",
"│ │ │ │ └─BatchNorm1d: 5-65 [16, 24, 78] 48\n",
"│ │ │ │ └─ReLU: 5-66 [16, 24, 78] --\n",
"│ │ │ │ └─Dropout: 5-67 [16, 24, 78] --\n",
"│ │ │ │ └─ConstantPad1d: 5-68 [16, 24, 84] --\n",
"│ │ │ │ └─Conv1d: 5-69 [16, 24, 78] 4,056\n",
"│ │ │ │ └─BatchNorm1d: 5-70 [16, 24, 78] 48\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-8 [16, 48, 39] --\n",
"│ │ │ └─MaxPool1d: 4-8 [16, 48, 39] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-71 [16, 24, 78] 48\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-72 [16, 24, 78] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-73 [16, 24, 84] --\n",
"│ │ │ │ └─Conv1d: 5-74 [16, 48, 39] 8,112\n",
"│ │ │ │ └─BatchNorm1d: 5-75 [16, 48, 39] 96\n",
"│ │ │ │ └─ReLU: 5-76 [16, 48, 39] --\n",
"│ │ │ │ └─Dropout: 5-77 [16, 48, 39] --\n",
"│ │ │ │ └─ConstantPad1d: 5-78 [16, 48, 45] --\n",
"│ │ │ │ └─Conv1d: 5-79 [16, 48, 39] 16,176\n",
"│ │ │ │ └─BatchNorm1d: 5-80 [16, 48, 39] 96\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-9 [16, 48, 19] --\n",
"│ │ │ └─MaxPool1d: 4-9 [16, 48, 19] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-81 [16, 48, 39] 96\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-82 [16, 48, 39] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-83 [16, 48, 45] --\n",
"│ │ │ │ └─Conv1d: 5-84 [16, 48, 20] 16,176\n",
"│ │ │ │ └─BatchNorm1d: 5-85 [16, 48, 20] 96\n",
"│ │ │ │ └─ReLU: 5-86 [16, 48, 20] --\n",
"│ │ │ │ └─Dropout: 5-87 [16, 48, 20] --\n",
"│ │ │ │ └─ConstantPad1d: 5-88 [16, 48, 25] --\n",
"│ │ │ │ └─Conv1d: 5-89 [16, 48, 19] 16,176\n",
"│ │ │ │ └─BatchNorm1d: 5-90 [16, 48, 19] 96\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-10 [16, 48, 9] --\n",
"│ │ │ └─MaxPool1d: 4-10 [16, 48, 9] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-91 [16, 48, 19] 96\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-92 [16, 48, 19] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-93 [16, 48, 25] --\n",
"│ │ │ │ └─Conv1d: 5-94 [16, 48, 10] 16,176\n",
"│ │ │ │ └─BatchNorm1d: 5-95 [16, 48, 10] 96\n",
"│ │ │ │ └─ReLU: 5-96 [16, 48, 10] --\n",
"│ │ │ │ └─Dropout: 5-97 [16, 48, 10] --\n",
"│ │ │ │ └─ConstantPad1d: 5-98 [16, 48, 15] --\n",
"│ │ │ │ └─Conv1d: 5-99 [16, 48, 9] 16,176\n",
"│ │ │ │ └─BatchNorm1d: 5-100 [16, 48, 9] 96\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-11 [16, 48, 4] --\n",
"│ │ │ └─MaxPool1d: 4-11 [16, 48, 4] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-101 [16, 48, 9] 96\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ReLU: 5-102 [16, 48, 9] --\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3 -- --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─ConstantPad1d: 5-103 [16, 48, 15] --\n",
"│ │ │ │ └─Conv1d: 5-104 [16, 48, 5] 16,176\n",
"│ │ │ │ └─BatchNorm1d: 5-105 [16, 48, 5] 96\n",
"│ │ │ │ └─ReLU: 5-106 [16, 48, 5] --\n",
"│ │ │ │ └─Dropout: 5-107 [16, 48, 5] --\n",
"│ │ │ │ └─ConstantPad1d: 5-108 [16, 48, 10] --\n",
"│ │ │ │ └─Conv1d: 5-109 [16, 48, 4] 16,176\n",
"│ │ │ │ └─BatchNorm1d: 5-110 [16, 48, 4] 96\n",
"│ │ └─ResidualMaxPoolDoubleConvBlockForward: 3-12 [16, 96, 2] --\n",
"│ │ │ └─MaxPool1d: 4-12 [16, 96, 2] --\n",
"│ │ │ └─Sequential: 4 -- --\n",
"│ │ │ │ └─BatchNorm1d: 5-111 [16, 48, 4] 96\n",
"│ │ │ │ └─ReLU: 5-112 [16, 48, 4] --\n",
"│ │ │ │ └─ConstantPad1d: 5-113 [16, 48, 10] --\n",
"│ │ │ │ └─Conv1d: 5-114 [16, 96, 2] 32,352\n",
"│ │ │ │ └─BatchNorm1d: 5-115 [16, 96, 2] 192\n",
"│ │ │ │ └─ReLU: 5-116 [16, 96, 2] --\n",
"│ │ │ │ └─Dropout: 5-117 [16, 96, 2] --\n",
"│ │ │ │ └─ConstantPad1d: 5-118 [16, 96, 8] --\n",
"│ │ │ │ └─Conv1d: 5-119 [16, 96, 2] 64,608\n",
"│ │ │ │ └─BatchNorm1d: 5-120 [16, 96, 2] 192\n",
"│ │ └─Flatten: 3-13 [16, 192] --\n",
"├─Linear: 1-2 [16, 1] 193\n",
"├─Reshape: 1-3 [16] --\n",
"==============================================================================================================\n",
"Total params: 257,401\n",
"Trainable params: 257,401\n",
"Non-trainable params: 0\n",
"Total mult-adds (M): 453.13\n",
"==============================================================================================================\n",
"Input size (MB): 3.84\n",
"Forward/backward pass size (MB): 95.19\n",
"Params size (MB): 1.03\n",
"Estimated Total Size (MB): 100.06\n",
"=============================================================================================================="
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchinfo import summary\n",
"\n",
"batch_size = 16\n",
"summary(class_model, input_size=(batch_size, 12, 5000), depth=50)\n"
]
},
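{
"cell_type": "markdown",
"id": "5a0c7e14",
"metadata": {},
"source": [
"The reported total of 257,401 parameters can be checked by hand using three rules:\n",
"\n",
"- A `Conv1d` layer with bias has `in_channels * out_channels * kernel_size + out_channels` parameters.\n",
"- A `BatchNorm1d` layer has `2 * channels` parameters (a learnable scale and shift per channel).\n",
"- A `Linear` layer has `in_features * out_features + out_features` parameters.\n",
"\n",
"The sketch below assumes the block layout visible in the summary: a batch norm on the block input, then two convolutions with kernel size 7, each followed by a batch norm, and no parameters on the residual path. It also assumes channel doubling in blocks 4, 8 and 12, and a final length of 2 samples. The helper names are hypothetical.\n",
"\n",
"```python\n",
"def conv1d_params(c_in, c_out, k):\n",
"    return c_in * c_out * k + c_out  # weights + bias\n",
"\n",
"\n",
"def batchnorm1d_params(c):\n",
"    return 2 * c  # learnable scale and shift per channel\n",
"\n",
"\n",
"def block_params(c_in, c_out, k=7):\n",
"    return (batchnorm1d_params(c_in)          # batch norm on the block input\n",
"            + conv1d_params(c_in, c_out, k)   # first convolution\n",
"            + batchnorm1d_params(c_out)\n",
"            + conv1d_params(c_out, c_out, k)  # second convolution\n",
"            + batchnorm1d_params(c_out))\n",
"\n",
"\n",
"def count_params(num_layers=12, in_channels=12, double_channel_every=4, k=7,\n",
"                 final_samples=2, out_features=1):\n",
"    total, channels = 0, in_channels\n",
"    for block in range(1, num_layers + 1):\n",
"        c_out = channels * 2 if block % double_channel_every == 0 else channels\n",
"        total += block_params(channels, c_out, k)\n",
"        channels = c_out\n",
"    linear_in = channels * final_samples  # flattened CNN output\n",
"    total += linear_in * out_features + out_features  # final Linear layer\n",
"    return total\n",
"\n",
"\n",
"print(count_params())  # 257401\n",
"```"
]
},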
{
"cell_type": "markdown",
"id": "e37ec85b",
"metadata": {},
"source": [
"### Defining metrics\n",
"\n",
"As this is a binary classification problem, we measure the performance of the model with common metrics: AUROC, precision, recall, accuracy and F1 score. We use the implementations of these metrics provided by the __TorchMetrics__ package. Each metric is wrapped in a __TorchMetricWrapper (TMW)__ from the ecgxai package.\n",
"\n",
"A TMW is initialized with three things: a TorchMetrics instance, the names of the system outputs that should be passed to the metric, and an 'int_args' parameter. The available output names can be found in the class definition of the classification system. Here we use 'y_prob' (the predicted probability of the positive class) and 'label' (the true label of each sample). The int_args parameter casts the 'label' output from float to integer, as TorchMetrics requires integer targets."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "64fd1e91",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-17T11:50:31.691005Z",
"start_time": "2022-05-17T11:50:31.662974Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs)\n"
]
}
],
"source": [
"from ecgxai.utils.metrics import TMW\n",
"from torchmetrics import MetricCollection, AUROC, F1Score, Accuracy, Precision, Recall\n",
"\n",
"\n",
"metrics = MetricCollection(\n",
" {\n",
" 'AUROC': TMW(AUROC(), ['y_prob', 'label'], int_args=['label']),\n",
" 'Precision': TMW(Precision(), ['y_prob', 'label'], int_args=['label']),\n",
" 'Recall': TMW(Recall(), ['y_prob', 'label'], int_args=['label']),\n",
" 'Accuracy': TMW(Accuracy(), ['y_prob', 'label'], int_args=['label']),\n",
" 'F1': TMW(F1Score(), ['y_prob', 'label'], int_args=['label'])\n",
" } \n",
")\n"
]
},
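{
"cell_type": "markdown",
"id": "5a0c7e15",
"metadata": {},
"source": [
"Conceptually, each TMW does three things. It selects entries from the system's output dictionary by name, optionally casts some of them to integers, and forwards them positionally to the wrapped metric. The sketch below illustrates that pattern with a hypothetical `KeyedMetric` class and a plain-Python accuracy function that uses a 0.5 threshold. It is not the ecgxai source. The real TMW wraps TorchMetrics objects that accumulate state across batches.\n",
"\n",
"```python\n",
"class KeyedMetric:\n",
"    # Hypothetical wrapper: pick named entries from an output dict and pass them to a metric\n",
"    def __init__(self, metric, keys, int_args=()):\n",
"        self.metric = metric\n",
"        self.keys = list(keys)\n",
"        self.int_args = set(int_args)\n",
"\n",
"    def __call__(self, outputs):\n",
"        args = []\n",
"        for key in self.keys:\n",
"            value = outputs[key]\n",
"            if key in self.int_args:\n",
"                value = [int(v) for v in value]  # e.g. float labels -> int labels\n",
"            args.append(value)\n",
"        return self.metric(*args)\n",
"\n",
"\n",
"def binary_accuracy(probs, labels, threshold=0.5):\n",
"    # Threshold the probabilities and count predictions that match the labels\n",
"    preds = [int(p >= threshold) for p in probs]\n",
"    return sum(p == y for p, y in zip(preds, labels)) / len(labels)\n",
"\n",
"\n",
"acc = KeyedMetric(binary_accuracy, ['y_prob', 'label'], int_args=['label'])\n",
"outputs = {'y_prob': [0.9, 0.2, 0.7, 0.4], 'label': [1.0, 0.0, 0.0, 0.0]}  # toy batch\n",
"print(acc(outputs))  # 0.75\n",
"```"
]
},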
{
"cell_type": "markdown",
"id": "7480227e",
"metadata": {},
"source": [
"### Defining a loss function\n",
"\n",
"As this is a binary classification problem, we choose the commonly used BCEWithLogitsLoss from PyTorch to train our model. The loss is wrapped using the ecgxai TorchWrapper (TW) class, which works the same way as the TMW class used for the metrics. In this case, however, we use the 'y_hat' output (the raw model output before the sigmoid) instead of 'y_prob', because BCEWithLogitsLoss applies the sigmoid internally."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f357bf4e",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-17T11:50:35.600002Z",
"start_time": "2022-05-17T11:50:35.534565Z"
}
},
"outputs": [],
"source": [
"from ecgxai.utils.loss import TW\n",
"\n",
"loss = TW(nn.BCEWithLogitsLoss(), ['y_hat', 'label'])"
]
},
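{
"cell_type": "markdown",
"id": "5a0c7e16",
"metadata": {},
"source": [
"The loss is computed from `y_hat` rather than `y_prob` for a reason: BCEWithLogitsLoss combines the sigmoid and the binary cross-entropy into one numerically stable expression. For a logit `x` and label `y`, this is `max(x, 0) - x*y + log(1 + exp(-|x|))`. The plain-Python sketch below does not need PyTorch. It checks that this stable form matches the naive form, sigmoid followed by `-(y*log(p) + (1 - y)*log(1 - p))`, for moderate logits. It also shows that the stable form stays finite for a large logit, where the naive form would need `log(0)`.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def bce_with_logits(x, y):\n",
"    # Numerically stable binary cross-entropy on a raw logit x with target y\n",
"    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))\n",
"\n",
"\n",
"def naive_bce(x, y):\n",
"    # Sigmoid followed by binary cross-entropy; breaks down for large |x|\n",
"    p = 1.0 / (1.0 + math.exp(-x))\n",
"    return -(y * math.log(p) + (1 - y) * math.log(1 - p))\n",
"\n",
"\n",
"for x, y in [(2.0, 1.0), (-1.5, 0.0), (0.3, 0.0)]:\n",
"    print(x, y, round(bce_with_logits(x, y), 6), round(naive_bce(x, y), 6))\n",
"\n",
"print(bce_with_logits(40.0, 0.0))  # 40.0\n",
"try:\n",
"    naive_bce(40.0, 0.0)  # 1 - sigmoid(40) rounds to 0.0, so log(0) is attempted\n",
"except ValueError as err:\n",
"    print('naive form fails:', err)\n",
"```"
]
},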
{
"cell_type": "markdown",
"id": "34c391c6",
"metadata": {},
"source": [
"### Defining a classification system\n",
"\n",
"Now we can define our binary classification system (mode='binary'). The system automatically computes the loss and metrics, and it offers a variety of automatic logging operations for training, validation and testing. Here we are only interested in the test metrics, so we pass the defined metrics only as test_metrics. We also pass our model, the defined loss and a learning rate (lr) of 0.001."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3b54ab7b",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-17T11:50:37.352599Z",
"start_time": "2022-05-17T11:50:37.325621Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['y_hat', 'label']\n"
]
}
],
"source": [
"from ecgxai.systems.classification_system import ClassificationSystem\n",
"\n",
"system = ClassificationSystem(\n",
" lr=0.001,\n",
" model=class_model,\n",
" test_metrics=metrics,\n",
" loss=loss,\n",
" mode='binary'\n",
")"
]
},
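{
"cell_type": "markdown",
"id": "5a0c7e17",
"metadata": {},
"source": [
"To make the output names concrete, the sketch below shows in plain Python how a binary step could assemble the dictionary that the loss and metric wrappers read from:\n",
"\n",
"- The model's raw output becomes `y_hat`.\n",
"- Its sigmoid becomes `y_prob`.\n",
"- The batch targets are passed through as `label`.\n",
"\n",
"The function `binary_step_outputs` and this exact dictionary layout are assumptions for illustration. The actual ClassificationSystem may organize this differently.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def binary_step_outputs(logits, labels):\n",
"    # Hypothetical: build the output dict for one batch in binary mode\n",
"    y_prob = [1.0 / (1.0 + math.exp(-z)) for z in logits]  # sigmoid of each logit\n",
"    return {'y_hat': list(logits), 'y_prob': y_prob, 'label': list(labels)}\n",
"\n",
"\n",
"out = binary_step_outputs([0.0, 2.0, -3.0], [0.0, 1.0, 0.0])\n",
"print([round(p, 3) for p in out['y_prob']])  # [0.5, 0.881, 0.047]\n",
"```"
]
},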
{
"cell_type": "markdown",
"id": "0277de43",
"metadata": {},
"source": [
"### Training the model\n",
"\n",
"To train the model we use __PyTorch Lightning__. We tell the PyTorch Lightning Trainer to train the model for 5 epochs and to save a checkpoint of the model (ModelCheckpoint callback). If CUDA is available, the Trainer runs training on a single GPU to speed it up."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "446a400c",
"metadata": {
"ExecuteTime": {
"end_time": "2022-05-17T11:52:04.534202Z",
"start_time": "2022-05-17T11:51:23.898677Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True, used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:122: UserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n",
" rank_zero_warn(\"You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\")\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
"\n",
" | Name | Type | Params\n",
"--------------------------------------------------\n",
"0 | test_metrics | MetricCollection | 0 \n",
"1 | loss | TW | 0 \n",
"2 | model | Sequential | 257 K \n",
"--------------------------------------------------\n",
"257 K Trainable params\n",
"0 Non-trainable params\n",
"257 K Total params\n",
"1.030 Total estimated model params size (MB)\n",
"/opt/conda/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:631: UserWarning: Checkpoint directory /workspace/ecgxai/examples/classification/checkpoints exists and is not empty.\n",
" rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "de2639e0e68d444ba6ca99638f3e0cef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pytorch_lightning as pl\n",
"import torch\n",
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
"\n",
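"trainer = pl.Trainer(\n",
"    max_epochs=5,\n",
"    # Train on a single GPU if CUDA is available, otherwise on the CPU\n",
"    gpus=1 if torch.cuda.is_available() else None,\n",
"    logger=None,\n",
"    callbacks=[\n",
"        # Save model checkpoints; save_last=True additionally writes a 'last.ckpt' file\n",
"        ModelCheckpoint(save_last=True),\n",
"    ],\n",
")\n",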
"\n",
"trainer.fit(system, train_loader)\n"
]
},
{
"cell_type": "markdown",
"id": "4988da29",
"metadata": {},
"source": [
"### Testing the model\n",
"\n",
"Finally, we evaluate the trained model on the held-out test set, using the metrics defined above."
]
},
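{
"cell_type": "markdown",
"id": "5a0c7e18",
"metadata": {},
"source": [
"Keep the class imbalance in mind when reading the test results. Only 158 of the 2184 test ECGs (7.23%) are labelled AF. A model that always predicts 'no AF' would therefore already reach about 92.8% accuracy, with zero recall. The short calculation below uses only the counts printed by `testset.print_stats()` and the standard definitions of accuracy and recall.\n",
"\n",
"```python\n",
"n_total, n_af = 2184, 158  # test-set counts from testset.print_stats()\n",
"\n",
"# Trivial classifier that predicts 'no AF' for every ECG\n",
"tp, fp = 0, 0\n",
"fn, tn = n_af, n_total - n_af\n",
"\n",
"accuracy = (tp + tn) / n_total\n",
"recall = tp / (tp + fn)\n",
"print(f'accuracy = {accuracy:.4f}, recall = {recall:.1f}')  # accuracy = 0.9277, recall = 0.0\n",
"```\n",
"\n",
"For this reason, recall, precision, F1 and AUROC say more about AF detection than accuracy alone."
]
},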
{
"cell_type": "code",
"execution_count": 12,
"id": "1adb4979",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0bec306c36514f07952a60d22b65c1cf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Testing: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:59: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 64. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
" warning_cache.warn(\n",
"/opt/conda/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: No positive samples in targets, true positive value should be meaningless. Returning zero tensor in true positive score\n",
" warnings.warn(*args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"DATALOADER:0 TEST RESULTS\n",
"{'test_AUROC': 0.9975426197052002,\n",
" 'test_Accuracy': 0.9882909655570984,\n",
" 'test_F1': 0.9247804880142212,\n",
" 'test_Precision': 0.9445728063583374,\n",
" 'test_Recall': 0.9064620137214661,\n",
" 'test_loss': 0.03978169709444046}\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:59: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 8. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
" warning_cache.warn(\n"
]
},
{
"data": {
"text/plain": [
"[{'test_AUROC': 0.9975426197052002,\n",
" 'test_Accuracy': 0.9882909655570984,\n",
" 'test_F1': 0.9247804880142212,\n",
" 'test_Precision': 0.9445728063583374,\n",
" 'test_Recall': 0.9064620137214661,\n",
" 'test_loss': 0.03978169709444046}]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.test(model=system, dataloaders=test_loader)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
"autoclose": false,
"autocomplete": true,
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 1,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"labels_anchors": false,
"latex_user_defs": false,
"report_style_numbering": false,
"user_envs_cfg": false
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}