{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MRNet Baseline Models\n",
    "This notebook seeks to replicate and extend basic CNN models on the MRNet data.\n",
    "\n",
    "Before running it, run the `save_middle_slices_as_images.ipynb` notebook. That notebook generates directories containing just the center slice of each scan from the three planes (directory `mid1`), as well as directories of RGB images built from three slices centered on the middle of each scan, skipping d={0,1,2} slices between the slices taken (e.g. directory `mid3d2`).\n",
    "\n",
    "For each model architecture (e.g. AlexNet or ResNet50), there are thus nine models to fit, one for each combination of outcome and plane, so each architecture has a corresponding 3x3 model performance grid. The competition measure is the average AUC across the three outcomes (abnormal, meniscus tear, and ACL tear). A simple first approach is to keep, for each outcome, the predictions of the best model trained on scans from a single plane.\n",
    "\n",
    "Outcomes predicted:\n",
    "- Abnormal/Normal\n",
    "- Meniscus tear/Not\n",
    "- ACL tear/Not\n",
    "\n",
    "Input images from planes:\n",
    "- Axial\n",
    "- Coronal\n",
    "- Sagittal"
   ]
  },
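  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the directory naming above, the following cell builds a three-channel image from a scan volume by taking the center slice and one slice on either side, skipping `d` slices in between. This is a minimal sketch and not the code in `save_middle_slices_as_images.ipynb`: the helper name `middle_slices_rgb`, the `(n_slices, H, W)` array layout, and the index clamping are all assumptions made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def middle_slices_rgb(volume, d=0):\n",
    "    # Hypothetical helper (not from the original notebooks): assumes `volume`\n",
    "    # has shape (n_slices, H, W) and stacks three slices into an (H, W, 3) array.\n",
    "    n_slices = volume.shape[0]\n",
    "    mid = n_slices // 2\n",
    "    step = d + 1  # skipping d slices means a stride of d + 1\n",
    "    # clamp indices so that short scans still yield three channels\n",
    "    idx = [min(max(mid + k * step, 0), n_slices - 1) for k in (-1, 0, 1)]\n",
    "    return np.stack([volume[i] for i in idx], axis=-1)\n",
    "\n",
    "# toy volume: 9 slices of 4x4 pixels, with slice i filled with the value i\n",
    "toy = np.stack([np.full((4, 4), i) for i in range(9)])\n",
    "rgb = middle_slices_rgb(toy, d=2)\n",
    "print(rgb.shape)  # (4, 4, 3)\n",
    "print(rgb[0, 0])  # slice indices used for the three channels: [1 4 7]"
   ]
  },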
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import pickle\n",
    "from time import strftime, gmtime \n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from fastai.vision import *\n",
    "import torch\n",
    "\n",
    "# for AUC ROC score calculations\n",
    "from sklearn import metrics\n",
    "\n",
    "from operator import itemgetter \n",
    "\n",
    "#from mrnet_orig import *\n",
    "from mrnet_itemlist import *\n",
    "\n",
    "#from ipywidgets import interact, Dropdown, IntSlider\n",
    "\n",
    "%matplotlib notebook\n",
    "plt.style.use('grayscale')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = Path('../data/mid3d0') "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare the labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_abnl = pd.read_csv(data_path/'train-abnormal.csv', header=None,\n",
    "                       names=['Case', 'Abnormal'], \n",
    "                       dtype={'Case': str, 'Abnormal': np.int64})\n",
    "train_abnl['axial'] = 'train/axial/' + train_abnl.Case + '.png'\n",
    "train_abnl['coronal'] = 'train/coronal/' + train_abnl.Case + '.png'\n",
    "train_abnl['sagittal'] = 'train/sagittal/' + train_abnl.Case + '.png'\n",
    "\n",
    "valid_abnl = pd.read_csv(data_path/'valid-abnormal.csv', header=None,\n",
    "                       names=['Case', 'Abnormal'], \n",
    "                       dtype={'Case': str, 'Abnormal': np.int64})\n",
    "valid_abnl['axial'] = 'valid/axial/' + valid_abnl.Case + '.png'\n",
    "valid_abnl['coronal'] = 'valid/coronal/' + valid_abnl.Case + '.png'\n",
    "valid_abnl['sagittal'] = 'valid/sagittal/' + valid_abnl.Case + '.png'\n",
    "# pd.concat replaces DataFrame.append, which was removed in pandas 2.0\n",
    "abnl = pd.concat([train_abnl, valid_abnl], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_meni = pd.read_csv(data_path/'train-meniscus.csv', header=None,\n",
    "                       names=['Case', 'Meniscus'], \n",
    "                       dtype={'Case': str, 'Meniscus': np.int64})\n",
    "\n",
    "valid_meni = pd.read_csv(data_path/'valid-meniscus.csv', header=None,\n",
    "                       names=['Case', 'Meniscus'], \n",
    "                       dtype={'Case': str, 'Meniscus': np.int64})\n",
    "meni = pd.concat([train_meni, valid_meni], ignore_index=True)\n",
    "abnl = pd.merge(abnl, meni, on='Case')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_acl = pd.read_csv(data_path/'train-acl.csv', header=None,\n",
    "                       names=['Case', 'ACL'], \n",
    "                       dtype={'Case': str, 'ACL': np.int64})\n",
    "\n",
    "valid_acl = pd.read_csv(data_path/'valid-acl.csv', header=None,\n",
    "                       names=['Case', 'ACL'], \n",
    "                       dtype={'Case': str, 'ACL': np.int64})\n",
    "acl = pd.concat([train_acl, valid_acl], ignore_index=True)\n",
    "abnl = pd.merge(abnl, acl, on='Case')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Case</th>\n",
       "      <th>Abnormal</th>\n",
       "      <th>axial</th>\n",
       "      <th>coronal</th>\n",
       "      <th>sagittal</th>\n",
       "      <th>Meniscus</th>\n",
       "      <th>ACL</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0000</td>\n",
       "      <td>1</td>\n",
       "      <td>train/axial/0000.png</td>\n",
       "      <td>train/coronal/0000.png</td>\n",
       "      <td>train/sagittal/0000.png</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0001</td>\n",
       "      <td>1</td>\n",
       "      <td>train/axial/0001.png</td>\n",
       "      <td>train/coronal/0001.png</td>\n",
       "      <td>train/sagittal/0001.png</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0002</td>\n",
       "      <td>1</td>\n",
       "      <td>train/axial/0002.png</td>\n",
       "      <td>train/coronal/0002.png</td>\n",
       "      <td>train/sagittal/0002.png</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0003</td>\n",
       "      <td>1</td>\n",
       "      <td>train/axial/0003.png</td>\n",
       "      <td>train/coronal/0003.png</td>\n",
       "      <td>train/sagittal/0003.png</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0004</td>\n",
       "      <td>1</td>\n",
       "      <td>train/axial/0004.png</td>\n",
       "      <td>train/coronal/0004.png</td>\n",
       "      <td>train/sagittal/0004.png</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Case  Abnormal                 axial                 coronal  \\\n",
       "0  0000         1  train/axial/0000.png  train/coronal/0000.png   \n",
       "1  0001         1  train/axial/0001.png  train/coronal/0001.png   \n",
       "2  0002         1  train/axial/0002.png  train/coronal/0002.png   \n",
       "3  0003         1  train/axial/0003.png  train/coronal/0003.png   \n",
       "4  0004         1  train/axial/0004.png  train/coronal/0004.png   \n",
       "\n",
       "                  sagittal  Meniscus  ACL  \n",
       "0  train/sagittal/0000.png         0    0  \n",
       "1  train/sagittal/0001.png         1    1  \n",
       "2  train/sagittal/0002.png         0    0  \n",
       "3  train/sagittal/0003.png         1    0  \n",
       "4  train/sagittal/0004.png         0    0  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "abnl.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1250, 7)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "abnl.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Restrict attention to Cases for which you have data\n",
    "For local development, you might be working just with a subset of the MRNet data. If so, get the list of Cases you have data for and subset the dataframe accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cases = [e[:-4] for e in os.listdir(data_path/'train/axial') if e[-4:] == '.png']\n",
    "valid_cases = [e[:-4] for e in os.listdir(data_path/'valid/axial') if e[-4:] == '.png']\n",
    "cases_w_data = sorted(train_cases + valid_cases)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1250"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(cases_w_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = abnl.loc[abnl.Case.isin(cases_w_data),:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1250, 7)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if necessary, create ./models by running the following\n",
    "#!mkdir models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "models_path = Path('./models')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cycling through planes and outcomes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Using AlexNet to predict ACL using axial middle slice(s) [data in ../data/mid3d0] ----\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='4' class='' max='10', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      40.00% [4/10 00:51<01:17]\n",
       "    </div>\n",
       "    \n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.970498</td>\n",
       "      <td>0.689940</td>\n",
       "      <td>0.548000</td>\n",
       "      <td>00:12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.818568</td>\n",
       "      <td>0.504058</td>\n",
       "      <td>0.784000</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.663493</td>\n",
       "      <td>0.458726</td>\n",
       "      <td>0.784000</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.564507</td>\n",
       "      <td>0.486030</td>\n",
       "      <td>0.792000</td>\n",
       "      <td>00:12</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>\n",
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='0' class='' max='15', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      0.00% [0/15 00:00<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "accresults = []\n",
    "aucresults = []\n",
    "\n",
    "for outcome in ('Abnormal','Meniscus','ACL'):\n",
    "    for plane in ('axial','coronal','sagittal'):\n",
    "        print('---- Using AlexNet to predict {} using {} middle slice(s) [data in {}] ----'.format(outcome, plane, data_path))\n",
    "        # from_df holds out a random 20% of df for validation (fastai default valid_pct=0.2),\n",
    "        # so the images under valid/ are not kept as a separate validation set here\n",
    "        data  = ImageDataBunch.from_df(path=data_path, df=df, \n",
    "                                       fn_col=plane, label_col=outcome, bs=64)\n",
    "        learn = cnn_learner(data, models.alexnet, metrics=accuracy)\n",
    "        learn.fit_one_cycle(10)\n",
    "        # collect accuracy metrics\n",
    "        acc     = [float(e[0]) for e in learn.recorder.metrics]\n",
    "        accresults.append([outcome,plane,acc])\n",
    "        # collect AUC ROC scores\n",
    "        yhat, y = learn.get_preds(ds_type=DatasetType.Valid)\n",
    "        # TODO: save y and predicted y\n",
    "        \n",
    "        auc     = metrics.roc_auc_score(to_np(y), to_np(yhat)[:,1])\n",
    "        aucresults.append([outcome,plane,auc])    \n",
    "        \n",
    "# save accuracy metrics\n",
    "with open(models_path/('accresults_' + data_path.stem + '_AlexNet_' +  strftime('%Y%m%d_%H%M', gmtime())), 'wb') as f:\n",
    "    pickle.dump(accresults, f)\n",
    "\n",
    "# save AUC scores\n",
    "with open(models_path/('aucresults_' + data_path.stem + '_AlexNet_' +  strftime('%Y%m%d_%H%M', gmtime())), 'wb') as f:\n",
    "    pickle.dump(aucresults, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['Meniscus', 'axial', 0.7259965468057953], ['Meniscus', 'coronal', 0.6765074962023645], ['Meniscus', 'sagittal', 0.6744101103075463], ['Abnormal', 'axial', 0.8662922455414326], ['Abnormal', 'sagittal', 0.8261897723913686], ['Abnormal', 'coronal', 0.7856420626895854], ['ACL', 'axial', 0.756035077347522], ['ACL', 'sagittal', 0.738572295949345], ['ACL', 'coronal', 0.7013403263403263]]\n",
      "[['Meniscus', 'axial', 0.7259965468057953], ['Abnormal', 'axial', 0.8662922455414326], ['ACL', 'axial', 0.756035077347522]]\n",
      "Average AUC across tasks: 0.78\n"
     ]
    }
   ],
   "source": [
    "sorted_auc = sorted(aucresults, key=itemgetter(0,2), reverse=True)\n",
    "print(sorted_auc)\n",
    "high_auc_per_task = sorted_auc[0::3]\n",
    "print(high_auc_per_task)\n",
    "ave_auc_across_tasks = np.mean([e[2] for e in high_auc_per_task])\n",
    "print('Average AUC across tasks: {}'.format(np.round(ave_auc_across_tasks,2)))"
   ]
  },
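  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The stride-based selection in the cell above relies on every outcome having exactly three rows in the sorted list. A sketch of an equivalent selection with pandas `groupby`, which does not depend on that layout, is given in the next cell. The AUC values in `auc_rows` and the column names are made up for illustration and are not results from this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# made-up AUC values for illustration only (not results from this notebook)\n",
    "auc_rows = [\n",
    "    ['Abnormal', 'axial', 0.80], ['Abnormal', 'coronal', 0.70], ['Abnormal', 'sagittal', 0.75],\n",
    "    ['Meniscus', 'axial', 0.60], ['Meniscus', 'coronal', 0.65], ['Meniscus', 'sagittal', 0.55],\n",
    "    ['ACL', 'axial', 0.70], ['ACL', 'coronal', 0.90], ['ACL', 'sagittal', 0.85],\n",
    "]\n",
    "auc_df = pd.DataFrame(auc_rows, columns=['outcome', 'plane', 'auc'])\n",
    "\n",
    "# 3x3 performance grid: one row per outcome, one column per plane\n",
    "grid = auc_df.pivot(index='outcome', columns='plane', values='auc')\n",
    "print(grid)\n",
    "\n",
    "# best single-plane model per outcome, then the average across outcomes\n",
    "best = auc_df.loc[auc_df.groupby('outcome')['auc'].idxmax()]\n",
    "print(best)\n",
    "print('Average AUC across tasks: {:.2f}'.format(best['auc'].mean()))"
   ]
  },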
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}