[dc3c86]: / MRNet_EDA_ns.ipynb

Download this file

1216 lines (1215 with data), 39.9 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from fastai.core import *\n",
    "\n",
    "%matplotlib notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LICENSE         MRNet_EDA.ipynb README.md\r\n"
     ]
    }
   ],
   "source": [
    "! ls -R "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = Path('../data')\n",
    "train_path, valid_path = (data_path/'smalltrain'/'train',\n",
    "                          data_path/'smallvalid'/'valid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "          Case\n",
      "Abnormal      \n",
      "0          217\n",
      "1          913\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Case</th>\n",
       "      <th>Abnormal</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0000</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0001</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0002</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0003</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0004</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Case  Abnormal\n",
       "0  0000         1\n",
       "1  0001         1\n",
       "2  0002         1\n",
       "3  0003         1\n",
       "4  0004         1"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_abnl = pd.read_csv(\n",
    "    data_path/'train-abnormal.csv',\n",
    "    header=None,\n",
    "    names=['Case', 'Abnormal'],\n",
    "    dtype={'Case': str, 'Abnormal': np.int64},\n",
    ")\n",
    "print(train_abnl.groupby('Abnormal').count())\n",
    "train_abnl.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "          Case\n",
      "ACL_tear      \n",
      "0          922\n",
      "1          208\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Case</th>\n",
       "      <th>ACL_tear</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0000</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0001</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0002</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0003</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0004</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Case  ACL_tear\n",
       "0  0000         0\n",
       "1  0001         1\n",
       "2  0002         0\n",
       "3  0003         0\n",
       "4  0004         0"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_acl = pd.read_csv(\n",
    "    data_path/'train-acl.csv',\n",
    "    header=None,\n",
    "    names=['Case', 'ACL_tear'],\n",
    "    dtype={'Case': str, 'ACL_tear': np.int64},\n",
    ")\n",
    "print(train_acl.groupby('ACL_tear').count())\n",
    "train_acl.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               Case\n",
      "Meniscus_tear      \n",
      "0               733\n",
      "1               397\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Case</th>\n",
       "      <th>Meniscus_tear</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0000</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0001</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0002</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0003</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0004</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Case  Meniscus_tear\n",
       "0  0000              0\n",
       "1  0001              1\n",
       "2  0002              0\n",
       "3  0003              1\n",
       "4  0004              0"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_meniscus = pd.read_csv(\n",
    "    data_path/'train-meniscus.csv',\n",
    "    header=None,\n",
    "    names=['Case', 'Meniscus_tear'],\n",
    "    dtype={'Case': str, 'Meniscus_tear': np.int64},\n",
    ")\n",
    "print(train_meniscus.groupby('Meniscus_tear').count())\n",
    "train_meniscus.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Co-occurrence of ACL and Meniscus tears"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# validate= raises a MergeError if a Case appears more than once in either\n",
    "# label file, instead of silently duplicating rows in the inner join.\n",
    "train = pd.merge(train_abnl, train_acl, on='Case', validate='one_to_one')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the combined label table from the raw label frames rather than from\n",
    "# `train` itself, so re-running this cell never merges the meniscus labels\n",
    "# twice (which would create Meniscus_tear_x / Meniscus_tear_y columns).\n",
    "# validate= guards against duplicated Case ids in any label file.\n",
    "train = (pd.merge(train_abnl, train_acl, on='Case', validate='one_to_one')\n",
    "           .merge(train_meniscus, on='Case', validate='one_to_one'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Case</th>\n",
       "      <th>Abnormal</th>\n",
       "      <th>ACL_tear</th>\n",
       "      <th>Meniscus_tear</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0001</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0002</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0003</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0004</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Case  Abnormal  ACL_tear  Meniscus_tear\n",
       "0  0000         1         0              0\n",
       "1  0001         1         1              1\n",
       "2  0002         1         0              0\n",
       "3  0003         1         0              1\n",
       "4  0004         1         0              0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>Case</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Abnormal</th>\n",
       "      <th>ACL_tear</th>\n",
       "      <th>Meniscus_tear</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>217</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">1</th>\n",
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
       "      <th>0</th>\n",
       "      <td>433</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>272</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">1</th>\n",
       "      <th>0</th>\n",
       "      <td>83</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>125</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 Case\n",
       "Abnormal ACL_tear Meniscus_tear      \n",
       "0        0        0               217\n",
       "1        0        0               433\n",
       "                  1               272\n",
       "         1        0                83\n",
       "                  1               125"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(train.head(),\n",
    "        train.groupby(['Abnormal', 'ACL_tear', 'Meniscus_tear']).count())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that abnormal cases with neither an ACL nor a meniscus tear are the most common category (433 cases), while ACL tears without a meniscus tear are the least common (83 cases) in the training labels."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load stacks/sequences of images from each plane\n",
    "Files are saved as NumPy arrays. Each case was scanned in three planes (axial, coronal, and sagittal), and each plane yields a stack of images.  \n",
    "\n",
    "First, let's check for variation in the number of images per sequence, and in the image dimensions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_stack_dims(case_df, data_path=train_path):\n",
    "    \"\"\"Tabulate the stack dimensions of every case in `case_df`.\n",
    "\n",
    "    For each case and each plane (axial, coronal, sagittal), reads the shape of\n",
    "    `data_path/<plane>/<case>.npy` and records it as `<plane>_s` (number of\n",
    "    images in the stack) plus the two image dimensions `<plane>_w`, `<plane>_h`.\n",
    "    Cases without a file for every plane (e.g. cases not included in the small\n",
    "    sample directory) are skipped.\n",
    "\n",
    "    The arrays are opened memory-mapped, so only the .npy header is read rather\n",
    "    than every pixel of every stack.\n",
    "\n",
    "    Returns a DataFrame with columns Case, axial_s, axial_w, ..., sagittal_h.\n",
    "    \"\"\"\n",
    "    planes = ['axial', 'coronal', 'sagittal']\n",
    "    columns = ['Case'] + ['{}_{}'.format(plane, dim)\n",
    "                          for plane in planes for dim in 'swh']\n",
    "    rows = []\n",
    "    for case in case_df.Case:\n",
    "        row = [case]\n",
    "        for plane in planes:\n",
    "            fpath = data_path/plane/'{}.npy'.format(case)\n",
    "            try:\n",
    "                s, w, h = np.load(fpath, mmap_mode='r').shape\n",
    "            except FileNotFoundError:\n",
    "                break  # one plane is missing, so this case can't be complete\n",
    "            row.extend([s, w, h])\n",
    "        if len(row) == len(columns):\n",
    "            rows.append(row)\n",
    "    return pd.DataFrame(rows, columns=columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "dimdf = collect_stack_dims(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>axial_s</th>\n",
       "      <th>axial_w</th>\n",
       "      <th>axial_h</th>\n",
       "      <th>coronal_s</th>\n",
       "      <th>coronal_w</th>\n",
       "      <th>coronal_h</th>\n",
       "      <th>sagittal_s</th>\n",
       "      <th>sagittal_w</th>\n",
       "      <th>sagittal_h</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>50.000000</td>\n",
       "      <td>50.0</td>\n",
       "      <td>50.0</td>\n",
       "      <td>50.000000</td>\n",
       "      <td>50.0</td>\n",
       "      <td>50.0</td>\n",
       "      <td>50.00000</td>\n",
       "      <td>50.0</td>\n",
       "      <td>50.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>35.860000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>31.360000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>31.72000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>7.050865</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.899264</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>6.35687</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>22.000000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>18.000000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>19.00000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>32.000000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>24.000000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>26.25000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>37.500000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>32.000000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>32.00000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>40.000000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>37.750000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>36.00000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>51.000000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>46.000000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>46.00000</td>\n",
       "      <td>256.0</td>\n",
       "      <td>256.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         axial_s  axial_w  axial_h  coronal_s  coronal_w  coronal_h  \\\n",
       "count  50.000000     50.0     50.0  50.000000       50.0       50.0   \n",
       "mean   35.860000    256.0    256.0  31.360000      256.0      256.0   \n",
       "std     7.050865      0.0      0.0   7.899264        0.0        0.0   \n",
       "min    22.000000    256.0    256.0  18.000000      256.0      256.0   \n",
       "25%    32.000000    256.0    256.0  24.000000      256.0      256.0   \n",
       "50%    37.500000    256.0    256.0  32.000000      256.0      256.0   \n",
       "75%    40.000000    256.0    256.0  37.750000      256.0      256.0   \n",
       "max    51.000000    256.0    256.0  46.000000      256.0      256.0   \n",
       "\n",
       "       sagittal_s  sagittal_w  sagittal_h  \n",
       "count    50.00000        50.0        50.0  \n",
       "mean     31.72000       256.0       256.0  \n",
       "std       6.35687         0.0         0.0  \n",
       "min      19.00000       256.0       256.0  \n",
       "25%      26.25000       256.0       256.0  \n",
       "50%      32.00000       256.0       256.0  \n",
       "75%      36.00000       256.0       256.0  \n",
       "max      46.00000       256.0       256.0  "
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dimdf.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of images in a stack varies from case (patient) to case, while every image has the same dimensions, 256x256. In the 50-case sample examined here, axial stacks range in length from 22 to 51 images; coronal, from 18 to 46; and sagittal, from 19 to 46."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_one_stack(case, data_path=train_path, plane='coronal'):\n",
    "    \"\"\"Load the image stack stored at `data_path/<plane>/<case>.npy`.\"\"\"\n",
    "    fpath = data_path/plane/'{}.npy'.format(case)\n",
    "    return np.load(fpath)\n",
    "\n",
    "def load_stacks(case, data_path=train_path, planes=('axial', 'coronal', 'sagittal')):\n",
    "    \"\"\"Load the stacks of all `planes` for one `case`, keyed by plane name.\n",
    "\n",
    "    `data_path` is forwarded to `load_one_stack`, so validation cases can be\n",
    "    loaded with `load_stacks(case, data_path=valid_path)`.\n",
    "    \"\"\"\n",
    "    return {plane: load_one_stack(case, data_path=data_path, plane=plane)\n",
    "            for plane in planes}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(36, 256, 256)\n",
      "255\n"
     ]
    }
   ],
   "source": [
    "case = train_abnl.Case.iloc[0]\n",
    "x = load_one_stack(case, plane='coronal')\n",
    "print(x.shape, x.max(), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'axial': array([[[ 0,  0,  0,  0, ...,  4,  5,  4,  3],\n",
       "         [ 0,  0,  0,  0, ...,  8,  8,  6,  8],\n",
       "         [ 0,  0,  0,  0, ..., 14, 14, 11, 11],\n",
       "         [ 0,  0,  0,  0, ..., 16, 16, 14, 15],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ..., 14, 15, 18, 16],\n",
       "         [ 0,  0,  0,  0, ..., 15, 16, 15, 12],\n",
       "         [ 0,  0,  0,  0, ..., 11, 12, 13, 12],\n",
       "         [ 0,  0,  0,  0, ...,  8, 11,  7,  9]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  4,  3,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  5,  9,  7,  7],\n",
       "         [ 0,  0,  0,  0, ..., 10, 13, 10, 10],\n",
       "         [ 0,  0,  0,  0, ..., 14, 14, 19, 17],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ..., 18, 16, 16, 17],\n",
       "         [ 0,  0,  0,  0, ..., 13, 12, 15, 13],\n",
       "         [ 0,  0,  0,  0, ..., 16, 14, 12, 12],\n",
       "         [ 0,  0,  0,  0, ...,  8,  6,  5,  7]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  7,  8,  6,  6],\n",
       "         [ 0,  0,  0,  0, ..., 12, 11, 13, 10],\n",
       "         [ 0,  0,  0,  0, ..., 12, 18, 18, 16],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ..., 16, 18, 16, 17],\n",
       "         [ 0,  0,  0,  0, ..., 15, 13, 13, 16],\n",
       "         [ 0,  0,  0,  0, ..., 10, 10, 10, 12],\n",
       "         [ 0,  0,  0,  0, ...,  6,  6,  6,  5]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  5,  7,  5,  4],\n",
       "         [ 0,  0,  0,  0, ..., 11, 10, 11, 12],\n",
       "         [ 0,  0,  0,  0, ..., 16, 16, 15, 14],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ..., 17, 21, 20, 18],\n",
       "         [ 0,  0,  0,  0, ..., 14, 15, 18, 14],\n",
       "         [ 0,  0,  0,  0, ..., 11,  9,  8, 10],\n",
       "         [ 0,  0,  0,  0, ...,  5,  5,  5,  5]],\n",
       " \n",
       "        ...,\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  4,  4,  4,  4],\n",
       "         [ 0,  0,  0,  0, ..., 10,  9,  9, 10],\n",
       "         [ 0,  0,  0,  0, ..., 12, 13, 16, 14],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ..., 16, 12, 14, 15],\n",
       "         [ 0,  0,  0,  0, ..., 11, 10, 13, 10],\n",
       "         [ 0,  0,  0,  0, ...,  9, 12,  9,  9],\n",
       "         [ 0,  0,  0,  0, ...,  6,  6,  4,  5]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  5,  5,  4,  4],\n",
       "         [ 0,  0,  0,  0, ..., 12,  8,  9, 11],\n",
       "         [ 0,  0,  0,  0, ..., 16, 17, 12, 13],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ..., 13, 14, 14, 17],\n",
       "         [ 0,  0,  0,  0, ..., 12, 14, 16, 13],\n",
       "         [ 0,  0,  0,  0, ...,  9, 12, 12,  8],\n",
       "         [ 0,  0,  0,  0, ...,  6,  5,  4,  5]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  3,  2,  3,  2],\n",
       "         [ 0,  0,  0,  0, ...,  5,  4,  5,  6],\n",
       "         [ 0,  0,  0,  0, ..., 13, 11, 11,  9],\n",
       "         [ 0,  0,  0,  0, ..., 13, 13, 16, 16],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ..., 17, 17, 14, 16],\n",
       "         [ 0,  0,  0,  0, ..., 16, 14, 16, 15],\n",
       "         [ 0,  0,  0,  0, ...,  9,  8, 10,  8],\n",
       "         [ 0,  0,  0,  0, ...,  6,  6,  6,  6]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  5,  4,  4,  4],\n",
       "         [ 0,  0,  0,  0, ...,  9,  8,  7,  6],\n",
       "         [ 0,  0,  0,  0, ..., 11,  9, 10, 11],\n",
       "         [ 0,  0,  0,  0, ..., 17, 15, 12, 13],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ..., 12, 16, 16, 15],\n",
       "         [ 0,  0,  0,  0, ..., 16, 12, 15, 16],\n",
       "         [ 0,  0,  0,  0, ..., 10, 14, 12, 12],\n",
       "         [ 0,  0,  0,  0, ...,  6,  8,  7,  7]]], dtype=uint8),\n",
       " 'coronal': array([[[ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  1,  2],\n",
       "         [ 0,  0,  0,  0, ...,  3,  3,  1,  2],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ...,  2,  3,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  1,  0]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  1,  1,  1,  0],\n",
       "         [ 0,  0,  0,  0, ...,  2,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ...,  3,  3,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  2],\n",
       "         [ 0,  0,  0,  0, ...,  0,  1,  1,  1]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  1,  1,  0,  1],\n",
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  1,  2],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ...,  3,  2,  1,  2],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  2,  1,  1,  2],\n",
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  2,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ...,  2,  3,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  3,  2,  1,  2],\n",
       "         [ 0,  0,  0,  0, ...,  2,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1]],\n",
       " \n",
       "        ...,\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  1,  1,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  3,  3,  3,  3],\n",
       "         [ 0,  0,  0,  0, ...,  5,  4,  4,  5],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ...,  5,  5,  4,  5],\n",
       "         [ 0,  0,  0,  0, ...,  3,  3,  3,  3],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  1],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  1,  0,  0,  1],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  1,  2],\n",
       "         [ 0,  0,  0,  0, ...,  4,  5,  4,  4],\n",
       "         [ 0,  0,  0,  0, ...,  7,  8,  8,  7],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ...,  7,  7,  6,  6],\n",
       "         [ 0,  0,  0,  0, ...,  5,  3,  3,  5],\n",
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  0,  0,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  2,  3,  3,  3],\n",
       "         [ 0,  0,  0,  0, ...,  4,  4,  5,  4],\n",
       "         [ 0,  0,  0,  0, ...,  9,  6,  8, 10],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ...,  8,  9,  8, 11],\n",
       "         [ 0,  0,  0,  0, ...,  5,  6,  3,  4],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  1,  0],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
       "         [ 0,  0,  0,  0, ...,  3,  3,  5,  5],\n",
       "         [ 0,  0,  0,  0, ...,  9, 10,  7,  6],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  0, ...,  9,  9,  9,  8],\n",
       "         [ 0,  0,  0,  0, ...,  5,  5,  4,  4],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]]], dtype=uint8),\n",
       " 'sagittal': array([[[ 0,  0,  0,  0, ...,  2,  1,  1,  2],\n",
       "         [ 0,  0,  0,  0, ...,  8,  8,  8,  7],\n",
       "         [ 0,  0,  0,  0, ..., 11, 15, 16, 14],\n",
       "         [ 7,  5,  5,  7, ..., 15, 14, 19, 16],\n",
       "         ...,\n",
       "         [ 7,  5,  7,  5, ..., 15, 17, 15, 14],\n",
       "         [ 0,  1,  1,  1, ..., 10, 15, 12, 11],\n",
       "         [ 0,  0,  0,  0, ...,  7,  7,  6,  5],\n",
       "         [ 0,  0,  0,  0, ...,  3,  2,  2,  1]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  4,  3,  1,  1],\n",
       "         [ 0,  0,  0,  0, ..., 10,  9,  7,  8],\n",
       "         [ 0,  0,  0,  0, ..., 17, 17, 15, 17],\n",
       "         [ 7,  5,  5,  6, ..., 20, 22, 21, 17],\n",
       "         ...,\n",
       "         [ 6,  7,  6,  4, ..., 18, 14, 19, 19],\n",
       "         [ 0,  1,  0,  2, ..., 16, 13, 15, 16],\n",
       "         [ 0,  0,  0,  0, ..., 10,  8,  5,  6],\n",
       "         [ 0,  0,  0,  0, ...,  2,  1,  1,  4]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  3,  4,  3,  4],\n",
       "         [ 0,  0,  0,  0, ...,  9, 12,  9, 10],\n",
       "         [ 0,  0,  0,  0, ..., 19, 20, 12, 11],\n",
       "         [ 7,  6,  7,  6, ..., 19, 17, 13, 16],\n",
       "         ...,\n",
       "         [ 3,  6,  6,  4, ..., 27, 27, 14, 19],\n",
       "         [ 2,  3,  1,  1, ..., 23, 19, 14, 15],\n",
       "         [ 0,  0,  0,  0, ...,  9,  4, 11,  9],\n",
       "         [ 0,  0,  0,  0, ...,  2,  1,  2,  3]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  3,  6,  4,  3],\n",
       "         [ 0,  0,  0,  0, ...,  8, 10,  8, 10],\n",
       "         [ 0,  0,  0,  0, ..., 21, 18, 14, 18],\n",
       "         [ 5,  9,  6,  6, ..., 26, 19, 18, 22],\n",
       "         ...,\n",
       "         [ 6,  4,  6,  9, ..., 22, 23, 24, 25],\n",
       "         [ 1,  1,  2,  1, ..., 14, 17, 19, 17],\n",
       "         [ 0,  0,  0,  0, ..., 10,  8,  7,  9],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  3]],\n",
       " \n",
       "        ...,\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  3,  3,  3,  2],\n",
       "         [ 0,  0,  0,  0, ..., 10,  8, 10, 10],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  3, ...,  7, 10,  8,  7],\n",
       "         [ 0,  0,  0,  0, ...,  4,  4,  5,  4],\n",
       "         [ 0,  0,  0,  0, ...,  2,  2,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  3,  3,  3,  2],\n",
       "         [ 0,  0,  0,  0, ..., 10, 11,  9, 11],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  4, ..., 10,  9, 10,  9],\n",
       "         [ 0,  0,  0,  1, ...,  5,  6,  5,  6],\n",
       "         [ 0,  0,  0,  0, ...,  3,  2,  1,  2],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  3,  2,  2,  2],\n",
       "         [ 0,  0,  0,  1, ..., 11, 12, 11,  9],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  3, ...,  9, 11, 10,  9],\n",
       "         [ 0,  0,  0,  1, ...,  5,  6,  6,  5],\n",
       "         [ 0,  0,  0,  0, ...,  2,  3,  2,  1],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
       " \n",
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  8, 10,  8,  9],\n",
       "         ...,\n",
       "         [ 0,  0,  0,  2, ...,  7,  9,  8,  8],\n",
       "         [ 0,  0,  0,  0, ...,  4,  4,  4,  5],\n",
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]]], dtype=uint8)}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_multi = load_stacks(case)\n",
    "x_multi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ipywidgets import interactive\n",
    "from IPython.display import display\n",
    "\n",
    "plt.style.use('grayscale')\n",
    "\n",
    "class KneePlot():\n",
    "    def __init__(self, x, figsize=(10, 10)):\n",
    "        self.x = x\n",
    "        self.slice_range = (0, self.x.shape[0] - 1)\n",
    "        self.resize(figsize)\n",
    "    \n",
    "    def _plot_slice(self, im_slice):\n",
    "        fig, ax = plt.subplots(1, 1, figsize=self.figsize)\n",
    "        ax.imshow(self.x[im_slice, :, :])\n",
    "        plt.show()\n",
    "\n",
    "    def resize(self, figsize):\n",
    "        self.figsize = figsize\n",
    "        self.interactive_plot = interactive(self._plot_slice, im_slice=self.slice_range)\n",
    "        self.output = self.interactive_plot.children[-1]\n",
    "        self.output.layout.height = '{}px'.format(60 * self.figsize[1])\n",
    "\n",
    "    def show(self):\n",
    "        display(self.interactive_plot)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb94b9e5a31b44c5abd287a8cdb12fe9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(IntSlider(value=17, description='im_slice', max=35), Output(layout=Layout(height='600px'…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot = KneePlot(x)\n",
    "plot.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d55829c45294dbd90f82665f799a8ca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(IntSlider(value=17, description='im_slice', max=35), Output(layout=Layout(height='720px'…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot.resize(figsize=(12, 12))\n",
    "plot.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ipywidgets import interact, Dropdown, IntSlider\n",
    "\n",
    "class MultiKneePlot():\n",
    "    def __init__(self, x_multi, figsize=(10, 10)):\n",
    "        self.x = x_multi\n",
    "        self.planes = ['coronal', 'sagittal', 'axial']\n",
    "        self.slice_nums = {plane: self.x[plane].shape[0] for plane in self.planes}\n",
    "        self.figsize = figsize\n",
    "    \n",
    "    def _plot_slices(self, plane, im_slice): \n",
    "        fig, ax = plt.subplots(1, 1, figsize=self.figsize)\n",
    "        ax.imshow(self.x[plane][im_slice, :, :])\n",
    "        plt.show()\n",
    "    \n",
    "    def draw(self):\n",
    "        planes_widget = Dropdown(options=self.planes)\n",
    "        plane_init = self.planes[0]\n",
    "        slice_init = self.slice_nums[plane_init] - 1\n",
    "        slices_widget = IntSlider(min=0, max=slice_init, value=slice_init//2)\n",
    "        def update_slices_widget(*args):\n",
    "            slices_widget.max = self.slice_nums[planes_widget.value] - 1\n",
    "            slices_widget.value = slices_widget.max // 2\n",
    "        planes_widget.observe(update_slices_widget, 'value')\n",
    "        interact(self._plot_slices, plane=planes_widget, im_slice=slices_widget)\n",
    "    \n",
    "    def resize(self, figsize): self.figsize = figsize\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d7a1c151c3440c59bfa0a2b55e6180e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(Dropdown(description='plane', options=('coronal', 'sagittal', 'axial'), value='coronal')…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_multi = MultiKneePlot(x_multi)\n",
    "plot_multi.draw()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}