Switch to side-by-side view

--- a
+++ b/code/Finetuning-Example.ipynb
@@ -0,0 +1,375 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Load your data\n",
+    "\n",
+    "Before finetuning a pretrained model of the experiments we provide in our repository (or precomputed and provided [here](https://datacloud.hhi.fraunhofer.de/nextcloud/s/NCjYws3mamLrkKq)), first load your custom 100 Hz sampled 12-lead ECG signal data `X` of shape `[N,L,12]` in Millivolts (mV) and multi-hot encoded labels `y` of shape `[N,C]` as numpy arrays, where `C` is the number of classes and `N` the number of total samples in this dataset. Although PTB-XL comes with fixed `L=1000` (i,e. 10 seconds), it is not required to be fixed, **BUT** the shortest sample must be longer than `input_size` of the specific model (e.g. 2.5 seconds for our fastai-models).\n",
+    "\n",
+    "For proper tinetuning split your data into four numpy arrays: `X_train`,`y_train`,`X_val` and `y_val`\n",
+    "\n",
+    "### Example: finetune model trained on all (71) on superdiagnostic (5)\n",
+    "Below we provide an example for loading [PTB-XL](https://physionet.org/content/ptb-xl/1.0.1/) aggregated at the `superdiagnostic` level, where we use the provided folds for train-validation-split:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "((19267, 1000, 12), (19267, 5), (2163, 1000, 12), (2163, 5))"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from utils import utils\n",
+    "\n",
+    "sampling_frequency=100\n",
+    "datafolder='../data/ptbxl/'\n",
+    "task='superdiagnostic'\n",
+    "outputfolder='../output/'\n",
+    "\n",
+    "# Load PTB-XL data\n",
+    "data, raw_labels = utils.load_dataset(datafolder, sampling_frequency)\n",
+    "# Preprocess label data\n",
+    "labels = utils.compute_label_aggregations(raw_labels, datafolder, task)\n",
+    "# Select relevant data and convert to one-hot\n",
+    "data, labels, Y, _ = utils.select_data(data, labels, task, min_samples=0, outputfolder=outputfolder)\n",
+    "\n",
+    "# 1-9 for training \n",
+    "X_train = data[labels.strat_fold < 10]\n",
+    "y_train = Y[labels.strat_fold < 10]\n",
+    "# 10 for validation\n",
+    "X_val = data[labels.strat_fold == 10]\n",
+    "y_val = Y[labels.strat_fold == 10]\n",
+    "\n",
+    "num_classes = 5         # <=== number of classes in the finetuning dataset\n",
+    "input_shape = [1000,12] # <=== shape of samples, [None, 12] in case of different lengths\n",
+    "\n",
+    "X_train.shape, y_train.shape, X_val.shape, y_val.shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Train or download models\n",
+    "There are two possibilities:\n",
+    "   1. Run the experiments as described in README. Afterwards you find trained in models in `output/expX/models/`\n",
+    "   2. Download the precomputed `output`-folder with all experiments and models from [here]((https://datacloud.hhi.fraunhofer.de/nextcloud/s/NCjYws3mamLrkKq))\n",
+    "\n",
+    "# Load pretrained model\n",
+    "\n",
+    "For loading a pretrained model:\n",
+    "   1. specify `modelname` which can be seen in `code/configs/` (e.g. `modelname='fastai_xresnet1d101'`)\n",
+    "   2. provide `experiment` to build the path `pretrainedfolder` (here: `exp0` refers to the experiment with `all` 71 SCP-statements)\n",
+    "   \n",
+    "This returns the pretrained model where the classification is replaced by a random initialized head with the same number of outputs as the number of classes."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from models.fastai_model import fastai_model\n",
+    "\n",
+    "experiment = 'exp0'\n",
+    "modelname = 'fastai_xresnet1d101'\n",
+    "pretrainedfolder = '../output/'+experiment+'/models/'+modelname+'/'\n",
+    "mpath='../output/' # <=== path where the finetuned model will be stored\n",
+    "n_classes_pretrained = 71 # <=== because we load the model from exp0, this should be fixed because this depends the experiment\n",
+    "\n",
+    "model = fastai_model(\n",
+    "    modelname, \n",
+    "    num_classes, \n",
+    "    sampling_frequency, \n",
+    "    mpath, \n",
+    "    input_shape=input_shape, \n",
+    "    pretrainedfolder=pretrainedfolder,\n",
+    "    n_classes_pretrained=n_classes_pretrained, \n",
+    "    pretrained=True,\n",
+    "    epochs_finetuning=2,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Preprocess data with pretrained Standardizer\n",
+    "\n",
+    "Since we standardize inputs to zero mean and unit variance, your custom data needs to be standardized with the respective mean and variance. This is also provided in the respective experiment folder `output/expX/data/standard_scaler.pkl`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/wagner/anaconda3/envs/ecg_python37/lib/python3.7/site-packages/sklearn/utils/deprecation.py:143: FutureWarning: The sklearn.preprocessing.data module is  deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.preprocessing. Anything that cannot be imported from sklearn.preprocessing is now part of the private API.\n",
+      "  warnings.warn(message, FutureWarning)\n"
+     ]
+    },
+    {
+     "ename": "RuntimeError",
+     "evalue": "The reset parameter is False but there is no n_features_in_ attribute. Is this estimator fitted?",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-4-5fe1be60124f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mstandard_scaler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'../output/'\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mexperiment\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'/data/standard_scaler.pkl'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"rb\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mX_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_standardizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstandard_scaler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0mX_val\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_standardizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstandard_scaler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/github/ecg_ptbxl_benchmarking/code/utils/utils.py\u001b[0m in \u001b[0;36mapply_standardizer\u001b[0;34m(X, ss)\u001b[0m\n\u001b[1;32m    329\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    330\u001b[0m         \u001b[0mx_shape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 331\u001b[0;31m         \u001b[0mX_tmp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnewaxis\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    332\u001b[0m     \u001b[0mX_tmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_tmp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    333\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mX_tmp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/ecg_python37/lib/python3.7/site-packages/sklearn/preprocessing/_data.py\u001b[0m in \u001b[0;36mtransform\u001b[0;34m(self, X, copy)\u001b[0m\n\u001b[1;32m    792\u001b[0m                                 \u001b[0maccept_sparse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'csr'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    793\u001b[0m                                 \u001b[0mestimator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mFLOAT_DTYPES\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 794\u001b[0;31m                                 force_all_finite='allow-nan')\n\u001b[0m\u001b[1;32m    795\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    796\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0msparse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0missparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/ecg_python37/lib/python3.7/site-packages/sklearn/base.py\u001b[0m in \u001b[0;36m_validate_data\u001b[0;34m(self, X, y, reset, validate_separately, **check_params)\u001b[0m\n\u001b[1;32m    434\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    435\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mcheck_params\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'ensure_2d'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 436\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_n_features\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    437\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    438\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/ecg_python37/lib/python3.7/site-packages/sklearn/base.py\u001b[0m in \u001b[0;36m_check_n_features\u001b[0;34m(self, X, reset)\u001b[0m\n\u001b[1;32m    371\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'n_features_in_'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    372\u001b[0m                 raise RuntimeError(\n\u001b[0;32m--> 373\u001b[0;31m                     \u001b[0;34m\"The reset parameter is False but there is no \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    374\u001b[0m                     \u001b[0;34m\"n_features_in_ attribute. Is this estimator fitted?\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    375\u001b[0m                 )\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: The reset parameter is False but there is no n_features_in_ attribute. Is this estimator fitted?"
+     ]
+    }
+   ],
+   "source": [
+    "import pickle\n",
+    "\n",
+    "standard_scaler = pickle.load(open('../output/'+experiment+'/data/standard_scaler.pkl', \"rb\"))\n",
+    "\n",
+    "X_train = utils.apply_standardizer(X_train, standard_scaler)\n",
+    "X_val = utils.apply_standardizer(X_val, standard_scaler)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Finetune model\n",
+    "\n",
+    "Calling `model.fit` of a model with `pretrained=True` will perform finetuning as proposed in our work i.e. **gradual unfreezing and discriminative learning rates**. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Finetuning...\n",
+      "model: fastai_xresnet1d101\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: left;\">\n",
+       "      <th>epoch</th>\n",
+       "      <th>train_loss</th>\n",
+       "      <th>valid_loss</th>\n",
+       "      <th>time</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <td>0</td>\n",
+       "      <td>0.271710</td>\n",
+       "      <td>0.271859</td>\n",
+       "      <td>00:27</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>1</td>\n",
+       "      <td>0.237324</td>\n",
+       "      <td>0.268521</td>\n",
+       "      <td>00:24</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: left;\">\n",
+       "      <th>epoch</th>\n",
+       "      <th>train_loss</th>\n",
+       "      <th>valid_loss</th>\n",
+       "      <th>time</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <td>0</td>\n",
+       "      <td>0.230869</td>\n",
+       "      <td>0.270421</td>\n",
+       "      <td>00:33</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>1</td>\n",
+       "      <td>0.230395</td>\n",
+       "      <td>0.268627</td>\n",
+       "      <td>00:33</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "model.fit(X_train, y_train, X_val, y_val)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Evaluate model on validation data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "aggregating predictions...\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>macro_auc</th>\n",
+       "      <th>Fmax</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>0.931458</td>\n",
+       "      <td>0.827961</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "   macro_auc      Fmax\n",
+       "0   0.931458  0.827961"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "y_val_pred = model.predict(X_val)\n",
+    "utils.evaluate_experiment(y_val, y_val_pred)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.7.4 64-bit ('ecg_python37': conda)",
+   "language": "python",
+   "name": "python37464bitecgpython37condacca13046f89242bdbe235da55b7380ab"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}