{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# RSNA Intracranial Hemorrhage Detection "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<b>Competition Overview</b><br/><br/>\n",
"Intracranial hemorrhage, bleeding that occurs inside the cranium, is a serious health problem requiring rapid and often intensive medical treatment. For example, intracranial hemorrhages account for approximately 10% of strokes in the U.S., where stroke is the fifth-leading cause of death. Identifying the location and type of any hemorrhage present is a critical step in treating the patient.\n",
"\n",
"Diagnosis requires an urgent procedure. When a patient shows acute neurological symptoms such as severe headache or loss of consciousness, highly trained specialists review medical images of the patient’s cranium to look for the presence, location and type of hemorrhage. The process is complicated and often time consuming."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<b>What am i predicting?</b><br/><br/>\n",
"In this competition our goal is to predict intracranial hemorrhage and its subtypes. Given an image the we need to predict probablity of each subtype. This indicates its a multilabel classification problem."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<b>Competition Evaluation Metric</b><br/><br/>\n",
"Evaluation metric is weighted multi-label logarithmic loss. So for given image we need to predict probality for each subtype. There is also an any label, which indicates that a hemorrhage of ANY kind exists in the image. The any label is weighted more highly than specific hemorrhage sub-types.\n",
"\n",
"<b>Note:</b>The weights for each subtype for calculating weighted multi-label logarithmic loss is **not** given as part of the competition. We will be using binary cross entropy loss as weights are not available"
]
},
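{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the weighted multi-label log loss can be sketched in NumPy. This is a minimal sketch under stated assumptions, not the official scorer: the per-label weights below are **hypothetical** (the competition does not publish them) and only illustrate how a heavier weight on `any` enters the formula. With equal weights the function reduces to the mean binary cross-entropy used for training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# label order follows the pivoted columns used in this notebook\n",
"LABELS = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']\n",
"# ASSUMPTION: illustrative weights only; the official weights are not given\n",
"WEIGHTS = np.array([2.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n",
"\n",
"def weighted_multilabel_logloss(y_true, y_pred, weights=WEIGHTS, eps=1e-7):\n",
"    y_true = np.asarray(y_true, dtype=np.float64)\n",
"    weights = np.asarray(weights, dtype=np.float64)\n",
"    # clip predictions so that log(0) never occurs\n",
"    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1 - eps)\n",
"    # element-wise binary cross-entropy, shape (n_images, n_labels)\n",
"    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))\n",
"    # weighted mean over the labels of each image, then mean over images\n",
"    per_image = (bce * weights).sum(axis=1) / weights.sum()\n",
"    return float(per_image.mean())\n",
"\n",
"# a constant 0.5 prediction scores log(2), about 0.693, whatever the weights\n",
"y_true = np.array([[1, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]])\n",
"y_pred = np.full((2, 6), 0.5)\n",
"print(weighted_multilabel_logloss(y_true, y_pred))"
]
},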
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<b>Dataset Description</b>\n",
"\n",
"The dataset is divided into two parts\n",
"\n",
"1. Train\n",
"2. Test\n",
"\n",
"**1. Train**\n",
"Number of rows: 40,45,548 records.\n",
"Number of columns: 2\n",
"\n",
"Columns:\n",
"\n",
"**Id**: An image Id. Each Id corresponds to a unique image, and will contain an underscore.\n",
"\n",
"Example: ID_28fbab7eb_epidural. So the Id consists of two parts one is image file id ID_28fbab7eb and the other is sub type name\n",
"\n",
"**Label**: The target label whether that sub-type of hemorrhage (or any hemorrhage in the case of any) exists in the indicated image. 1 --> Exists and 0 --> Doesn't exist.\n",
"\n",
"**2. Test**\n",
"Number of rows: 4,71,270 records.\n",
"\n",
"Columns:\n",
"\n",
"**Id**: An image Id. Each Id corresponds to a unique image, and will contain an underscore.\n",
"\n",
"Example: ID_28fbab7eb_epidural. So the Id consists of two parts one is image file id ID_28fbab7eb and the other is sub type name"
]
},
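{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two-part ID can also be split with Python's `str.rsplit`, which cuts only at the last underscore and therefore keeps the `ID_` prefix attached to the image ID. A minimal sketch; `parse_id` is a hypothetical helper name, and the notebook's own cells use `split('_')` to the same effect."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def parse_id(row_id):\n",
"    # split at the last underscore only:\n",
"    # 'ID_63eb1e259_epidural' -> ('ID_63eb1e259.dcm', 'epidural')\n",
"    image_id, sub_type = row_id.rsplit('_', 1)\n",
"    return image_id + '.dcm', sub_type\n",
"\n",
"print(parse_id('ID_63eb1e259_epidural'))"
]
},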
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import pydicom\n",
"import os\n",
"import glob\n",
"import random\n",
"import cv2\n",
"import tensorflow as tf\n",
"from math import ceil, floor\n",
"from tqdm import tqdm\n",
"from imgaug import augmenters as iaa\n",
"import matplotlib.pyplot as plt\n",
"from math import ceil, floor\n",
"import keras\n",
"import keras.backend as K\n",
"from keras.callbacks import Callback, ModelCheckpoint\n",
"from keras.layers import Dense, Flatten, Dropout\n",
"from keras.models import Model, load_model\n",
"from keras.utils import Sequence\n",
"from keras.losses import binary_crossentropy\n",
"from keras.optimizers import Adam"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Random Seed\n",
"SEED = 42\n",
"np.random.seed(SEED)\n",
"\n",
"# some constants\n",
"TEST_SIZE = 0.06\n",
"HEIGHT = 256\n",
"WIDTH = 256\n",
"TRAIN_BATCH_SIZE = 32\n",
"VALID_BATCH_SIZE = 64\n",
"\n",
"# Train and Test folders\n",
"input_folder = '../input/rsna-intracranial-hemorrhage-detection/'\n",
"path_train_img = input_folder + 'stage_1_train_images/'\n",
"path_test_img = input_folder + 'stage_1_test_images/'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ID</th>\n",
" <th>Label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ID_63eb1e259_epidural</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ID_63eb1e259_intraparenchymal</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ID_63eb1e259_intraventricular</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ID_63eb1e259_subarachnoid</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ID_63eb1e259_subdural</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" ID Label\n",
"0 ID_63eb1e259_epidural 0\n",
"1 ID_63eb1e259_intraparenchymal 0\n",
"2 ID_63eb1e259_intraventricular 0\n",
"3 ID_63eb1e259_subarachnoid 0\n",
"4 ID_63eb1e259_subdural 0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_df = pd.read_csv(input_folder + 'stage_1_train.csv')\n",
"train_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ID</th>\n",
" <th>Label</th>\n",
" <th>sub_type</th>\n",
" <th>file_name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ID_63eb1e259_epidural</td>\n",
" <td>0</td>\n",
" <td>epidural</td>\n",
" <td>ID_63eb1e259.dcm</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ID_63eb1e259_intraparenchymal</td>\n",
" <td>0</td>\n",
" <td>intraparenchymal</td>\n",
" <td>ID_63eb1e259.dcm</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ID_63eb1e259_intraventricular</td>\n",
" <td>0</td>\n",
" <td>intraventricular</td>\n",
" <td>ID_63eb1e259.dcm</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ID_63eb1e259_subarachnoid</td>\n",
" <td>0</td>\n",
" <td>subarachnoid</td>\n",
" <td>ID_63eb1e259.dcm</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ID_63eb1e259_subdural</td>\n",
" <td>0</td>\n",
" <td>subdural</td>\n",
" <td>ID_63eb1e259.dcm</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" ID Label sub_type file_name\n",
"0 ID_63eb1e259_epidural 0 epidural ID_63eb1e259.dcm\n",
"1 ID_63eb1e259_intraparenchymal 0 intraparenchymal ID_63eb1e259.dcm\n",
"2 ID_63eb1e259_intraventricular 0 intraventricular ID_63eb1e259.dcm\n",
"3 ID_63eb1e259_subarachnoid 0 subarachnoid ID_63eb1e259.dcm\n",
"4 ID_63eb1e259_subdural 0 subdural ID_63eb1e259.dcm"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extract subtype\n",
"train_df['sub_type'] = train_df['ID'].apply(lambda x: x.split('_')[-1])\n",
"# extract filename\n",
"train_df['file_name'] = train_df['ID'].apply(lambda x: '_'.join(x.split('_')[:2]) + '.dcm')\n",
"train_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4045572, 4)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_df.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4045548, 4)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# remove duplicates\n",
"train_df.drop_duplicates(['Label', 'sub_type', 'file_name'], inplace=True)\n",
"train_df.shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of train images availabe: 674258\n"
]
}
],
"source": [
"print(\"Number of train images availabe:\", len(os.listdir(path_train_img)))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>sub_type</th>\n",
" <th>any</th>\n",
" <th>epidural</th>\n",
" <th>intraparenchymal</th>\n",
" <th>intraventricular</th>\n",
" <th>subarachnoid</th>\n",
" <th>subdural</th>\n",
" </tr>\n",
" <tr>\n",
" <th>file_name</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>ID_000039fa0.dcm</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ID_00005679d.dcm</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ID_00008ce3c.dcm</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ID_0000950d7.dcm</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ID_0000aee4b.dcm</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"sub_type any epidural intraparenchymal intraventricular \\\n",
"file_name \n",
"ID_000039fa0.dcm 0 0 0 0 \n",
"ID_00005679d.dcm 0 0 0 0 \n",
"ID_00008ce3c.dcm 0 0 0 0 \n",
"ID_0000950d7.dcm 0 0 0 0 \n",
"ID_0000aee4b.dcm 0 0 0 0 \n",
"\n",
"sub_type subarachnoid subdural \n",
"file_name \n",
"ID_000039fa0.dcm 0 0 \n",
"ID_00005679d.dcm 0 0 \n",
"ID_00008ce3c.dcm 0 0 \n",
"ID_0000950d7.dcm 0 0 \n",
"ID_0000aee4b.dcm 0 0 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_final_df = pd.pivot_table(train_df.drop(columns='ID'), index=\"file_name\", \\\n",
" columns=\"sub_type\", values=\"Label\")\n",
"train_final_df.head()"
]
},
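{
"cell_type": "markdown",
"metadata": {},
"source": [
"After pivoting, each image has one row with six labels. Because `any` flags a hemorrhage of any type, it should equal the row-wise maximum of the five subtype columns, which gives a cheap sanity check on the pivot. A sketch on a toy frame with made-up labels (`any_label_is_consistent` is a hypothetical helper); on the real data you would call it on `train_final_df`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"SUBTYPES = ['epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']\n",
"\n",
"def any_label_is_consistent(df):\n",
"    # 'any' should be 1 exactly when at least one subtype is 1\n",
"    return bool((df['any'] == df[SUBTYPES].max(axis=1)).all())\n",
"\n",
"# toy pivoted frame with made-up labels\n",
"toy = pd.DataFrame([[1, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]],\n",
"                   index=['ID_a.dcm', 'ID_b.dcm'], columns=['any'] + SUBTYPES)\n",
"print(any_label_is_consistent(toy))  # True"
]
},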
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(674258, 6)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_final_df.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Invalid image ID_6431af929.dcm\n",
"train_final_df.drop('ID_6431af929.dcm', inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting efficientnet\n",
" Downloading https://files.pythonhosted.org/packages/97/82/f3ae07316f0461417dc54affab6e86ab188a5a22f33176d35271628b96e0/efficientnet-1.0.0-py3-none-any.whl\n",
"Requirement already satisfied: keras-applications<=1.0.8,>=1.0.7 in /opt/conda/lib/python3.6/site-packages (from efficientnet) (1.0.8)\n",
"Requirement already satisfied: scikit-image in /opt/conda/lib/python3.6/site-packages (from efficientnet) (0.16.1)\n",
"Requirement already satisfied: numpy>=1.9.1 in /opt/conda/lib/python3.6/site-packages (from keras-applications<=1.0.8,>=1.0.7->efficientnet) (1.16.4)\n",
"Requirement already satisfied: h5py in /opt/conda/lib/python3.6/site-packages (from keras-applications<=1.0.8,>=1.0.7->efficientnet) (2.9.0)\n",
"Requirement already satisfied: PyWavelets>=0.4.0 in /opt/conda/lib/python3.6/site-packages (from scikit-image->efficientnet) (1.0.3)\n",
"Requirement already satisfied: scipy>=0.19.0 in /opt/conda/lib/python3.6/site-packages (from scikit-image->efficientnet) (1.2.1)\n",
"Requirement already satisfied: networkx>=2.0 in /opt/conda/lib/python3.6/site-packages (from scikit-image->efficientnet) (2.4)\n",
"Requirement already satisfied: imageio>=2.3.0 in /opt/conda/lib/python3.6/site-packages (from scikit-image->efficientnet) (2.6.0)\n",
"Requirement already satisfied: pillow>=4.3.0 in /opt/conda/lib/python3.6/site-packages (from scikit-image->efficientnet) (5.4.1)\n",
"Requirement already satisfied: matplotlib!=3.0.0,>=2.0.0 in /opt/conda/lib/python3.6/site-packages (from scikit-image->efficientnet) (3.0.3)\n",
"Requirement already satisfied: six in /opt/conda/lib/python3.6/site-packages (from h5py->keras-applications<=1.0.8,>=1.0.7->efficientnet) (1.12.0)\n",
"Requirement already satisfied: decorator>=4.3.0 in /opt/conda/lib/python3.6/site-packages (from networkx>=2.0->scikit-image->efficientnet) (4.4.0)\n",
"Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.6/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (0.10.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/lib/python3.6/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (1.1.0)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/lib/python3.6/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (2.4.2)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/lib/python3.6/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (2.8.0)\n",
"Requirement already satisfied: setuptools in /opt/conda/lib/python3.6/site-packages (from kiwisolver>=1.0.1->matplotlib!=3.0.0,>=2.0.0->scikit-image->efficientnet) (41.4.0)\n",
"Installing collected packages: efficientnet\n",
"Successfully installed efficientnet-1.0.0\n",
"Collecting iterative-stratification\n",
" Downloading https://files.pythonhosted.org/packages/9d/79/9ba64c8c07b07b8b45d80725b2ebd7b7884701c1da34f70d4749f7b45f9a/iterative_stratification-0.1.6-py3-none-any.whl\n",
"Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.6/site-packages (from iterative-stratification) (0.21.3)\n",
"Requirement already satisfied: scipy in /opt/conda/lib/python3.6/site-packages (from iterative-stratification) (1.2.1)\n",
"Requirement already satisfied: numpy in /opt/conda/lib/python3.6/site-packages (from iterative-stratification) (1.16.4)\n",
"Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.6/site-packages (from scikit-learn->iterative-stratification) (0.13.2)\n",
"Installing collected packages: iterative-stratification\n",
"Successfully installed iterative-stratification-0.1.6\n"
]
}
],
"source": [
"# Install Efficient Net as it is not part of Keras\n",
"!pip install efficientnet\n",
"!pip install iterative-stratification"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import efficientnet.keras as efn \n",
"from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import HTML\n",
"\n",
"def create_download_link(title = \"Download CSV file\", filename = \"data.csv\"): \n",
" \"\"\"\n",
" Helper function to generate download link to files in kaggle kernel \n",
" \"\"\"\n",
" html = '<a href={filename}>{title}</a>'\n",
" html = html.format(title=title,filename=filename)\n",
" return HTML(html)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def get_dicom_field_value(val):\n",
" \"\"\"\n",
" Helper function to get value of dicom field in dicom file\n",
" \"\"\"\n",
" if type(val) == pydicom.multival.MultiValue:\n",
" return int(val[0])\n",
" else:\n",
" return int(val)\n",
"\n",
"def get_windowing(data):\n",
" \"\"\"\n",
" Helper function to extract meta data features in dicom file\n",
" return: window center, window width, slope, intercept\n",
" \"\"\"\n",
" dicom_fields = [data.WindowCenter, data.WindowWidth, data.RescaleSlope, data.RescaleIntercept]\n",
" return [get_dicom_field_value(x) for x in dicom_fields]\n",
"\n",
"\n",
"def get_windowed_image(image, wc, ww, slope, intercept):\n",
" \"\"\"\n",
" Helper function to construct windowed image from meta data features\n",
" return: windowed image\n",
" \"\"\"\n",
" img = (image*slope +intercept)\n",
" img_min = wc - ww//2\n",
" img_max = wc + ww//2\n",
" img[img<img_min] = img_min\n",
" img[img>img_max] = img_max\n",
" return img \n",
"\n",
"\n",
"def _normalize(img):\n",
" if img.max() == img.min():\n",
" return np.zeros(img.shape)\n",
" return 2 * (img - img.min())/(img.max() - img.min()) - 1\n",
"\n",
"def _read(path, desired_size=(224, 224)):\n",
" \"\"\"\n",
" Helper function to generate windowed image \n",
" \"\"\"\n",
" # 1. read dicom file\n",
" dcm = pydicom.dcmread(path)\n",
" \n",
" # 2. Extract meta data features\n",
" # window center, window width, slope, intercept\n",
" window_params = get_windowing(dcm)\n",
"\n",
" try:\n",
" # 3. Generate windowed image\n",
" img = get_windowed_image(dcm.pixel_array, *window_params)\n",
" except:\n",
" img = np.zeros(desired_size)\n",
"\n",
" img = _normalize(img)\n",
"\n",
" if desired_size != (512, 512):\n",
" # resize image\n",
" img = cv2.resize(img, desired_size, interpolation = cv2.INTER_LINEAR)\n",
" return img[:,:,np.newaxis]"
]
},
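{
"cell_type": "markdown",
"metadata": {},
"source": [
"A worked example of the windowing step: with a rescale slope of 1 and an intercept of -1024, stored values are shifted into Hounsfield units (HU), so 1064 becomes 40 HU. A window with center 40 and width 80 keeps only [0, 80] HU and clips everything else, and `_normalize` then maps that range to [-1, 1]. The sketch below (`window_and_normalize` is a hypothetical helper) repeats those two steps on a synthetic 2 x 2 array; the window and rescale values are illustrative, since the notebook reads them from each file's DICOM tags."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def window_and_normalize(pixels, wc, ww, slope, intercept):\n",
"    # stored values -> Hounsfield units\n",
"    hu = pixels.astype(np.float32) * slope + intercept\n",
"    # keep only [wc - ww//2, wc + ww//2]\n",
"    img = np.clip(hu, wc - ww // 2, wc + ww // 2)\n",
"    # min-max scale to [-1, 1], as _normalize does\n",
"    if img.max() == img.min():\n",
"        return np.zeros(img.shape)\n",
"    return 2 * (img - img.min()) / (img.max() - img.min()) - 1\n",
"\n",
"# synthetic stored values: -1024, 40, 60 and 1000 HU after rescaling\n",
"pixels = np.array([[0, 1064], [1084, 2024]], dtype=np.uint16)\n",
"out = window_and_normalize(pixels, wc=40, ww=80, slope=1, intercept=-1024)\n",
"print(out)  # values: [[-1, 0], [0.5, 1]]"
]
},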
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(128, 128, 1)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_read(path_train_img + 'ID_ffff922b9.dcm', (128, 128)).shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f821e7b95c0>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(\n",
" _read(path_train_img + 'ID_ffff922b9.dcm', (128, 128))[:, :, 0]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Augmentations\n",
"# Flip Left Right\n",
"# Cropping\n",
"sometimes = lambda aug: iaa.Sometimes(0.25, aug)\n",
"augmentation = iaa.Sequential([ \n",
" iaa.Fliplr(0.25),\n",
" sometimes(iaa.Crop(px=(0, 25), keep_size = True, \n",
" sample_independently = False)) \n",
" ], random_order = True)"
]
},
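{
"cell_type": "markdown",
"metadata": {},
"source": [
"For readers new to imgaug, the `Fliplr(0.25)` step amounts to mirroring an image left-to-right with probability 0.25. A NumPy-only sketch of that behavior for one (H, W, 1) image; `random_fliplr` is a hypothetical helper, not imgaug's implementation, and the crop step is not reproduced."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def random_fliplr(image, p, rng):\n",
"    # mirror an (H, W, C) image along its width axis with probability p\n",
"    if rng.random_sample() < p:\n",
"        return image[:, ::-1, :]\n",
"    return image\n",
"\n",
"rng = np.random.RandomState(42)\n",
"img = np.arange(6, dtype=np.float32).reshape(2, 3, 1)\n",
"# the top-left pixel becomes 2 only when the image was flipped\n",
"n_flipped = sum(random_fliplr(img, 0.25, rng)[0, 0, 0] == 2 for _ in range(10000))\n",
"print(n_flipped / 10000)  # close to 0.25"
]
},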
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Train Data Generator\n",
"class TrainDataGenerator(keras.utils.Sequence):\n",
"\n",
" def __init__(self, dataset, labels, batch_size=16, img_size=(512, 512), img_dir = path_train_img, \\\n",
" augment = False, *args, **kwargs):\n",
" self.dataset = dataset\n",
" self.ids = dataset.index\n",
" self.labels = labels\n",
" self.batch_size = batch_size\n",
" self.img_size = img_size\n",
" self.img_dir = img_dir\n",
" self.augment = augment\n",
" self.on_epoch_end()\n",
"\n",
" def __len__(self):\n",
" return int(ceil(len(self.ids) / self.batch_size))\n",
"\n",
" def __getitem__(self, index):\n",
" indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]\n",
" X, Y = self.__data_generation(indices)\n",
" return X, Y\n",
"\n",
" def augmentor(self, image):\n",
" augment_img = augmentation \n",
" image_aug = augment_img.augment_image(image)\n",
" return image_aug\n",
"\n",
" def on_epoch_end(self):\n",
" self.indices = np.arange(len(self.ids))\n",
" np.random.shuffle(self.indices)\n",
" \n",
" def __data_generation(self, indices):\n",
" X = np.empty((self.batch_size, *self.img_size, 3))\n",
" Y = np.empty((self.batch_size, 6), dtype=np.float32)\n",
" \n",
" for i, index in enumerate(indices):\n",
" ID = self.ids[index]\n",
" image = _read(self.img_dir + ID, self.img_size)\n",
" if self.augment:\n",
" X[i,] = self.augmentor(image)\n",
" else:\n",
" X[i,] = image \n",
" Y[i,] = self.labels.iloc[index].values \n",
" return X, Y\n",
" \n",
"class TestDataGenerator(keras.utils.Sequence):\n",
" def __init__(self, ids, labels, batch_size = 5, img_size = (512, 512), img_dir = path_test_img, \\\n",
" *args, **kwargs):\n",
" self.ids = ids\n",
" self.labels = labels\n",
" self.batch_size = batch_size\n",
" self.img_size = img_size\n",
" self.img_dir = img_dir\n",
" self.on_epoch_end()\n",
"\n",
" def __len__(self):\n",
" return int(ceil(len(self.ids) / self.batch_size))\n",
"\n",
" def __getitem__(self, index):\n",
" indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]\n",
" list_IDs_temp = [self.ids[k] for k in indices]\n",
" X = self.__data_generation(list_IDs_temp)\n",
" return X\n",
"\n",
" def on_epoch_end(self):\n",
" self.indices = np.arange(len(self.ids))\n",
"\n",
" def __data_generation(self, list_IDs_temp):\n",
" X = np.empty((self.batch_size, *self.img_size, 3))\n",
" for i, ID in enumerate(list_IDs_temp):\n",
" image = _read(self.img_dir + ID, self.img_size)\n",
" X[i,] = image \n",
" return X"
]
},
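{
"cell_type": "markdown",
"metadata": {},
"source": [
"Keras requests `ceil(n / batch_size)` batches from a `Sequence`, so unless `n` is a multiple of `batch_size`, the last batch holds fewer samples. A small sketch of the index slicing used in `__getitem__` (the helper `batch_slices` is hypothetical) shows the resulting batch sizes. Arrays allocated as `np.empty((batch_size, ...))` would leave the unused rows of that final batch uninitialized, which is why sizing them by the number of indices is the safer choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from math import ceil\n",
"import numpy as np\n",
"\n",
"def batch_slices(n_samples, batch_size):\n",
"    # same slicing as __getitem__: indices[i*bs:(i+1)*bs] for batch i\n",
"    indices = np.arange(n_samples)\n",
"    n_batches = int(ceil(n_samples / batch_size))\n",
"    return [indices[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]\n",
"\n",
"sizes = [len(b) for b in batch_slices(100, 32)]\n",
"print(sizes)  # [32, 32, 32, 4]"
]
},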
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we have seen in EDA notebook that we have very few epidural subtypes so we need oversample this sub type"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Shape: (677018, 6)\n"
]
}
],
"source": [
"# Oversampling\n",
"epidural_df = train_final_df[train_final_df.epidural == 1]\n",
"train_final_df = pd.concat([train_final_df, epidural_df])\n",
"print('Train Shape: {}'.format(train_final_df.shape))"
]
},
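{
"cell_type": "markdown",
"metadata": {},
"source": [
"Appending one extra copy of every positive row doubles the positives: if a fraction p of the images are epidural-positive, their share after oversampling becomes 2p / (1 + p). A toy pandas sketch (the helper `oversample_positive` and the data are hypothetical) with p = 0.2 shows the share rising to 1/3. The duplicated rows keep their original index labels, so a later train/validation split is best made on unique file names to keep copies of the same image from landing on both sides."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def oversample_positive(df, column, copies=1):\n",
"    # append `copies` extra duplicates of every row where df[column] == 1\n",
"    positives = df[df[column] == 1]\n",
"    return pd.concat([df] + [positives] * copies)\n",
"\n",
"# toy data: 1 positive out of 5 images (p = 0.2)\n",
"toy = pd.DataFrame({'epidural': [1, 0, 0, 0, 0]},\n",
"                   index=['ID_a.dcm', 'ID_b.dcm', 'ID_c.dcm', 'ID_d.dcm', 'ID_e.dcm'])\n",
"out = oversample_positive(toy, 'epidural')\n",
"print(len(out), out['epidural'].mean())  # 6 rows, share 2/6 = 1/3"
]
},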
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ID</th>\n",
" <th>Label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ID_28fbab7eb_epidural</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ID_28fbab7eb_intraparenchymal</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ID_28fbab7eb_intraventricular</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ID_28fbab7eb_subarachnoid</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ID_28fbab7eb_subdural</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" ID Label\n",
"0 ID_28fbab7eb_epidural 0.5\n",
"1 ID_28fbab7eb_intraparenchymal 0.5\n",
"2 ID_28fbab7eb_intraventricular 0.5\n",
"3 ID_28fbab7eb_subarachnoid 0.5\n",
"4 ID_28fbab7eb_subdural 0.5"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# load test set\n",
"test_df = pd.read_csv(input_folder + 'stage_1_sample_submission.csv')\n",
"test_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(78545, 6)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extract subtype\n",
"test_df['sub_type'] = test_df['ID'].apply(lambda x: x.split('_')[-1])\n",
"# extract filename\n",
"test_df['file_name'] = test_df['ID'].apply(lambda x: '_'.join(x.split('_')[:2]) + '.dcm')\n",
"\n",
"test_df = pd.pivot_table(test_df.drop(columns='ID'), index=\"file_name\", \\\n",
" columns=\"sub_type\", values=\"Label\")\n",
"test_df.head()\n",
"\n",
"test_df.shape"
]
},
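{
"cell_type": "markdown",
"metadata": {},
"source": [
"The pivot is reversible: once predictions are stored per image (one column per subtype), they can be turned back into the submission's long format, one `ID_<image>_<subtype>` row per label. A minimal sketch using `DataFrame.melt`; the helper name `to_submission` and the toy probabilities are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def to_submission(pred_df):\n",
"    # pred_df: index = file names like 'ID_000012eaf.dcm', columns = subtypes\n",
"    long_df = (pred_df.rename_axis('file_name').reset_index()\n",
"               .melt(id_vars='file_name', var_name='sub_type', value_name='Label'))\n",
"    # rebuild IDs such as 'ID_000012eaf_any'\n",
"    image_ids = long_df['file_name'].str.replace('.dcm', '', regex=False)\n",
"    long_df['ID'] = image_ids + '_' + long_df['sub_type']\n",
"    return long_df[['ID', 'Label']]\n",
"\n",
"preds = pd.DataFrame([[0.9, 0.1]], index=['ID_000012eaf.dcm'], columns=['any', 'epidural'])\n",
"print(to_submission(preds))"
]
},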
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>sub_type</th>\n",
" <th>any</th>\n",
" <th>epidural</th>\n",
" <th>intraparenchymal</th>\n",
" <th>intraventricular</th>\n",
" <th>subarachnoid</th>\n",
" <th>subdural</th>\n",
" </tr>\n",
" <tr>\n",
" <th>file_name</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>ID_000012eaf.dcm</th>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ID_0000ca2f6.dcm</th>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ID_000259ccf.dcm</th>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ID_0002d438a.dcm</th>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ID_00032d440.dcm</th>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"sub_type any epidural intraparenchymal intraventricular \\\n",
"file_name \n",
"ID_000012eaf.dcm 0.5 0.5 0.5 0.5 \n",
"ID_0000ca2f6.dcm 0.5 0.5 0.5 0.5 \n",
"ID_000259ccf.dcm 0.5 0.5 0.5 0.5 \n",
"ID_0002d438a.dcm 0.5 0.5 0.5 0.5 \n",
"ID_00032d440.dcm 0.5 0.5 0.5 0.5 \n",
"\n",
"sub_type subarachnoid subdural \n",
"file_name \n",
"ID_000012eaf.dcm 0.5 0.5 \n",
"ID_0000ca2f6.dcm 0.5 0.5 \n",
"ID_000259ccf.dcm 0.5 0.5 \n",
"ID_0002d438a.dcm 0.5 0.5 \n",
"ID_00032d440.dcm 0.5 0.5 "
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://github.com/Callidior/keras-applications/releases/download/efficientnet/efficientnet-b0_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5\n",
"16809984/16804768 [==============================] - 1s 0us/step\n",
"Model: \"model_1\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input_1 (InputLayer) (None, 256, 256, 3) 0 \n",
"__________________________________________________________________________________________________\n",
"stem_conv (Conv2D) (None, 128, 128, 32) 864 input_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"stem_bn (BatchNormalization) (None, 128, 128, 32) 128 stem_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"stem_activation (Activation) (None, 128, 128, 32) 0 stem_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1a_dwconv (DepthwiseConv2D (None, 128, 128, 32) 288 stem_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1a_bn (BatchNormalization) (None, 128, 128, 32) 128 block1a_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1a_activation (Activation) (None, 128, 128, 32) 0 block1a_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1a_se_squeeze (GlobalAvera (None, 32) 0 block1a_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1a_se_reshape (Reshape) (None, 1, 1, 32) 0 block1a_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1a_se_reduce (Conv2D) (None, 1, 1, 8) 264 block1a_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1a_se_expand (Conv2D) (None, 1, 1, 32) 288 block1a_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1a_se_excite (Multiply) (None, 128, 128, 32) 0 block1a_activation[0][0] \n",
" block1a_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1a_project_conv (Conv2D) (None, 128, 128, 16) 512 block1a_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block1a_project_bn (BatchNormal (None, 128, 128, 16) 64 block1a_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_expand_conv (Conv2D) (None, 128, 128, 96) 1536 block1a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_expand_bn (BatchNormali (None, 128, 128, 96) 384 block2a_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_expand_activation (Acti (None, 128, 128, 96) 0 block2a_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_dwconv (DepthwiseConv2D (None, 64, 64, 96) 864 block2a_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_bn (BatchNormalization) (None, 64, 64, 96) 384 block2a_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_activation (Activation) (None, 64, 64, 96) 0 block2a_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_se_squeeze (GlobalAvera (None, 96) 0 block2a_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_se_reshape (Reshape) (None, 1, 1, 96) 0 block2a_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_se_reduce (Conv2D) (None, 1, 1, 4) 388 block2a_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_se_expand (Conv2D) (None, 1, 1, 96) 480 block2a_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_se_excite (Multiply) (None, 64, 64, 96) 0 block2a_activation[0][0] \n",
" block2a_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_project_conv (Conv2D) (None, 64, 64, 24) 2304 block2a_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2a_project_bn (BatchNormal (None, 64, 64, 24) 96 block2a_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_expand_conv (Conv2D) (None, 64, 64, 144) 3456 block2a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_expand_bn (BatchNormali (None, 64, 64, 144) 576 block2b_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_expand_activation (Acti (None, 64, 64, 144) 0 block2b_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_dwconv (DepthwiseConv2D (None, 64, 64, 144) 1296 block2b_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_bn (BatchNormalization) (None, 64, 64, 144) 576 block2b_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_activation (Activation) (None, 64, 64, 144) 0 block2b_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_se_squeeze (GlobalAvera (None, 144) 0 block2b_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_se_reshape (Reshape) (None, 1, 1, 144) 0 block2b_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_se_reduce (Conv2D) (None, 1, 1, 6) 870 block2b_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_se_expand (Conv2D) (None, 1, 1, 144) 1008 block2b_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_se_excite (Multiply) (None, 64, 64, 144) 0 block2b_activation[0][0] \n",
" block2b_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_project_conv (Conv2D) (None, 64, 64, 24) 3456 block2b_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_project_bn (BatchNormal (None, 64, 64, 24) 96 block2b_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_drop (FixedDropout) (None, 64, 64, 24) 0 block2b_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block2b_add (Add) (None, 64, 64, 24) 0 block2b_drop[0][0] \n",
" block2a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_expand_conv (Conv2D) (None, 64, 64, 144) 3456 block2b_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_expand_bn (BatchNormali (None, 64, 64, 144) 576 block3a_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_expand_activation (Acti (None, 64, 64, 144) 0 block3a_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_dwconv (DepthwiseConv2D (None, 32, 32, 144) 3600 block3a_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_bn (BatchNormalization) (None, 32, 32, 144) 576 block3a_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_activation (Activation) (None, 32, 32, 144) 0 block3a_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_se_squeeze (GlobalAvera (None, 144) 0 block3a_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_se_reshape (Reshape) (None, 1, 1, 144) 0 block3a_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_se_reduce (Conv2D) (None, 1, 1, 6) 870 block3a_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_se_expand (Conv2D) (None, 1, 1, 144) 1008 block3a_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_se_excite (Multiply) (None, 32, 32, 144) 0 block3a_activation[0][0] \n",
" block3a_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_project_conv (Conv2D) (None, 32, 32, 40) 5760 block3a_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3a_project_bn (BatchNormal (None, 32, 32, 40) 160 block3a_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_expand_conv (Conv2D) (None, 32, 32, 240) 9600 block3a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_expand_bn (BatchNormali (None, 32, 32, 240) 960 block3b_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_expand_activation (Acti (None, 32, 32, 240) 0 block3b_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_dwconv (DepthwiseConv2D (None, 32, 32, 240) 6000 block3b_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_bn (BatchNormalization) (None, 32, 32, 240) 960 block3b_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_activation (Activation) (None, 32, 32, 240) 0 block3b_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_se_squeeze (GlobalAvera (None, 240) 0 block3b_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_se_reshape (Reshape) (None, 1, 1, 240) 0 block3b_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_se_reduce (Conv2D) (None, 1, 1, 10) 2410 block3b_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_se_expand (Conv2D) (None, 1, 1, 240) 2640 block3b_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_se_excite (Multiply) (None, 32, 32, 240) 0 block3b_activation[0][0] \n",
" block3b_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_project_conv (Conv2D) (None, 32, 32, 40) 9600 block3b_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_project_bn (BatchNormal (None, 32, 32, 40) 160 block3b_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_drop (FixedDropout) (None, 32, 32, 40) 0 block3b_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block3b_add (Add) (None, 32, 32, 40) 0 block3b_drop[0][0] \n",
" block3a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_expand_conv (Conv2D) (None, 32, 32, 240) 9600 block3b_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_expand_bn (BatchNormali (None, 32, 32, 240) 960 block4a_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_expand_activation (Acti (None, 32, 32, 240) 0 block4a_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_dwconv (DepthwiseConv2D (None, 16, 16, 240) 2160 block4a_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_bn (BatchNormalization) (None, 16, 16, 240) 960 block4a_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_activation (Activation) (None, 16, 16, 240) 0 block4a_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_se_squeeze (GlobalAvera (None, 240) 0 block4a_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_se_reshape (Reshape) (None, 1, 1, 240) 0 block4a_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_se_reduce (Conv2D) (None, 1, 1, 10) 2410 block4a_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_se_expand (Conv2D) (None, 1, 1, 240) 2640 block4a_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_se_excite (Multiply) (None, 16, 16, 240) 0 block4a_activation[0][0] \n",
" block4a_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_project_conv (Conv2D) (None, 16, 16, 80) 19200 block4a_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4a_project_bn (BatchNormal (None, 16, 16, 80) 320 block4a_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_expand_conv (Conv2D) (None, 16, 16, 480) 38400 block4a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_expand_bn (BatchNormali (None, 16, 16, 480) 1920 block4b_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_expand_activation (Acti (None, 16, 16, 480) 0 block4b_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_dwconv (DepthwiseConv2D (None, 16, 16, 480) 4320 block4b_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_bn (BatchNormalization) (None, 16, 16, 480) 1920 block4b_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_activation (Activation) (None, 16, 16, 480) 0 block4b_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_se_squeeze (GlobalAvera (None, 480) 0 block4b_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_se_reshape (Reshape) (None, 1, 1, 480) 0 block4b_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_se_reduce (Conv2D) (None, 1, 1, 20) 9620 block4b_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_se_expand (Conv2D) (None, 1, 1, 480) 10080 block4b_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_se_excite (Multiply) (None, 16, 16, 480) 0 block4b_activation[0][0] \n",
" block4b_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_project_conv (Conv2D) (None, 16, 16, 80) 38400 block4b_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_project_bn (BatchNormal (None, 16, 16, 80) 320 block4b_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_drop (FixedDropout) (None, 16, 16, 80) 0 block4b_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4b_add (Add) (None, 16, 16, 80) 0 block4b_drop[0][0] \n",
" block4a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_expand_conv (Conv2D) (None, 16, 16, 480) 38400 block4b_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_expand_bn (BatchNormali (None, 16, 16, 480) 1920 block4c_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_expand_activation (Acti (None, 16, 16, 480) 0 block4c_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_dwconv (DepthwiseConv2D (None, 16, 16, 480) 4320 block4c_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_bn (BatchNormalization) (None, 16, 16, 480) 1920 block4c_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_activation (Activation) (None, 16, 16, 480) 0 block4c_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_se_squeeze (GlobalAvera (None, 480) 0 block4c_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_se_reshape (Reshape) (None, 1, 1, 480) 0 block4c_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_se_reduce (Conv2D) (None, 1, 1, 20) 9620 block4c_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_se_expand (Conv2D) (None, 1, 1, 480) 10080 block4c_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_se_excite (Multiply) (None, 16, 16, 480) 0 block4c_activation[0][0] \n",
" block4c_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_project_conv (Conv2D) (None, 16, 16, 80) 38400 block4c_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_project_bn (BatchNormal (None, 16, 16, 80) 320 block4c_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_drop (FixedDropout) (None, 16, 16, 80) 0 block4c_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block4c_add (Add) (None, 16, 16, 80) 0 block4c_drop[0][0] \n",
" block4b_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_expand_conv (Conv2D) (None, 16, 16, 480) 38400 block4c_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_expand_bn (BatchNormali (None, 16, 16, 480) 1920 block5a_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_expand_activation (Acti (None, 16, 16, 480) 0 block5a_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_dwconv (DepthwiseConv2D (None, 16, 16, 480) 12000 block5a_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_bn (BatchNormalization) (None, 16, 16, 480) 1920 block5a_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_activation (Activation) (None, 16, 16, 480) 0 block5a_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_se_squeeze (GlobalAvera (None, 480) 0 block5a_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_se_reshape (Reshape) (None, 1, 1, 480) 0 block5a_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_se_reduce (Conv2D) (None, 1, 1, 20) 9620 block5a_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_se_expand (Conv2D) (None, 1, 1, 480) 10080 block5a_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_se_excite (Multiply) (None, 16, 16, 480) 0 block5a_activation[0][0] \n",
" block5a_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_project_conv (Conv2D) (None, 16, 16, 112) 53760 block5a_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5a_project_bn (BatchNormal (None, 16, 16, 112) 448 block5a_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_expand_conv (Conv2D) (None, 16, 16, 672) 75264 block5a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_expand_bn (BatchNormali (None, 16, 16, 672) 2688 block5b_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_expand_activation (Acti (None, 16, 16, 672) 0 block5b_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_dwconv (DepthwiseConv2D (None, 16, 16, 672) 16800 block5b_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_bn (BatchNormalization) (None, 16, 16, 672) 2688 block5b_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_activation (Activation) (None, 16, 16, 672) 0 block5b_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_se_squeeze (GlobalAvera (None, 672) 0 block5b_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_se_reshape (Reshape) (None, 1, 1, 672) 0 block5b_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_se_reduce (Conv2D) (None, 1, 1, 28) 18844 block5b_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_se_expand (Conv2D) (None, 1, 1, 672) 19488 block5b_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_se_excite (Multiply) (None, 16, 16, 672) 0 block5b_activation[0][0] \n",
" block5b_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_project_conv (Conv2D) (None, 16, 16, 112) 75264 block5b_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_project_bn (BatchNormal (None, 16, 16, 112) 448 block5b_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_drop (FixedDropout) (None, 16, 16, 112) 0 block5b_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5b_add (Add) (None, 16, 16, 112) 0 block5b_drop[0][0] \n",
" block5a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_expand_conv (Conv2D) (None, 16, 16, 672) 75264 block5b_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_expand_bn (BatchNormali (None, 16, 16, 672) 2688 block5c_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_expand_activation (Acti (None, 16, 16, 672) 0 block5c_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_dwconv (DepthwiseConv2D (None, 16, 16, 672) 16800 block5c_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_bn (BatchNormalization) (None, 16, 16, 672) 2688 block5c_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_activation (Activation) (None, 16, 16, 672) 0 block5c_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_se_squeeze (GlobalAvera (None, 672) 0 block5c_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_se_reshape (Reshape) (None, 1, 1, 672) 0 block5c_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_se_reduce (Conv2D) (None, 1, 1, 28) 18844 block5c_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_se_expand (Conv2D) (None, 1, 1, 672) 19488 block5c_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_se_excite (Multiply) (None, 16, 16, 672) 0 block5c_activation[0][0] \n",
" block5c_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_project_conv (Conv2D) (None, 16, 16, 112) 75264 block5c_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_project_bn (BatchNormal (None, 16, 16, 112) 448 block5c_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_drop (FixedDropout) (None, 16, 16, 112) 0 block5c_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block5c_add (Add) (None, 16, 16, 112) 0 block5c_drop[0][0] \n",
" block5b_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_expand_conv (Conv2D) (None, 16, 16, 672) 75264 block5c_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_expand_bn (BatchNormali (None, 16, 16, 672) 2688 block6a_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_expand_activation (Acti (None, 16, 16, 672) 0 block6a_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_dwconv (DepthwiseConv2D (None, 8, 8, 672) 16800 block6a_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_bn (BatchNormalization) (None, 8, 8, 672) 2688 block6a_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_activation (Activation) (None, 8, 8, 672) 0 block6a_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_se_squeeze (GlobalAvera (None, 672) 0 block6a_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_se_reshape (Reshape) (None, 1, 1, 672) 0 block6a_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_se_reduce (Conv2D) (None, 1, 1, 28) 18844 block6a_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_se_expand (Conv2D) (None, 1, 1, 672) 19488 block6a_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_se_excite (Multiply) (None, 8, 8, 672) 0 block6a_activation[0][0] \n",
" block6a_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_project_conv (Conv2D) (None, 8, 8, 192) 129024 block6a_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6a_project_bn (BatchNormal (None, 8, 8, 192) 768 block6a_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_expand_conv (Conv2D) (None, 8, 8, 1152) 221184 block6a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_expand_bn (BatchNormali (None, 8, 8, 1152) 4608 block6b_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_expand_activation (Acti (None, 8, 8, 1152) 0 block6b_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_dwconv (DepthwiseConv2D (None, 8, 8, 1152) 28800 block6b_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_bn (BatchNormalization) (None, 8, 8, 1152) 4608 block6b_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_activation (Activation) (None, 8, 8, 1152) 0 block6b_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_se_squeeze (GlobalAvera (None, 1152) 0 block6b_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_se_reshape (Reshape) (None, 1, 1, 1152) 0 block6b_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_se_reduce (Conv2D) (None, 1, 1, 48) 55344 block6b_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_se_expand (Conv2D) (None, 1, 1, 1152) 56448 block6b_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_se_excite (Multiply) (None, 8, 8, 1152) 0 block6b_activation[0][0] \n",
" block6b_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_project_conv (Conv2D) (None, 8, 8, 192) 221184 block6b_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_project_bn (BatchNormal (None, 8, 8, 192) 768 block6b_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_drop (FixedDropout) (None, 8, 8, 192) 0 block6b_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6b_add (Add) (None, 8, 8, 192) 0 block6b_drop[0][0] \n",
" block6a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_expand_conv (Conv2D) (None, 8, 8, 1152) 221184 block6b_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_expand_bn (BatchNormali (None, 8, 8, 1152) 4608 block6c_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_expand_activation (Acti (None, 8, 8, 1152) 0 block6c_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_dwconv (DepthwiseConv2D (None, 8, 8, 1152) 28800 block6c_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_bn (BatchNormalization) (None, 8, 8, 1152) 4608 block6c_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_activation (Activation) (None, 8, 8, 1152) 0 block6c_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_se_squeeze (GlobalAvera (None, 1152) 0 block6c_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_se_reshape (Reshape) (None, 1, 1, 1152) 0 block6c_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_se_reduce (Conv2D) (None, 1, 1, 48) 55344 block6c_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_se_expand (Conv2D) (None, 1, 1, 1152) 56448 block6c_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_se_excite (Multiply) (None, 8, 8, 1152) 0 block6c_activation[0][0] \n",
" block6c_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_project_conv (Conv2D) (None, 8, 8, 192) 221184 block6c_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_project_bn (BatchNormal (None, 8, 8, 192) 768 block6c_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_drop (FixedDropout) (None, 8, 8, 192) 0 block6c_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6c_add (Add) (None, 8, 8, 192) 0 block6c_drop[0][0] \n",
" block6b_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_expand_conv (Conv2D) (None, 8, 8, 1152) 221184 block6c_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_expand_bn (BatchNormali (None, 8, 8, 1152) 4608 block6d_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_expand_activation (Acti (None, 8, 8, 1152) 0 block6d_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_dwconv (DepthwiseConv2D (None, 8, 8, 1152) 28800 block6d_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_bn (BatchNormalization) (None, 8, 8, 1152) 4608 block6d_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_activation (Activation) (None, 8, 8, 1152) 0 block6d_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_se_squeeze (GlobalAvera (None, 1152) 0 block6d_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_se_reshape (Reshape) (None, 1, 1, 1152) 0 block6d_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_se_reduce (Conv2D) (None, 1, 1, 48) 55344 block6d_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_se_expand (Conv2D) (None, 1, 1, 1152) 56448 block6d_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_se_excite (Multiply) (None, 8, 8, 1152) 0 block6d_activation[0][0] \n",
" block6d_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_project_conv (Conv2D) (None, 8, 8, 192) 221184 block6d_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_project_bn (BatchNormal (None, 8, 8, 192) 768 block6d_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_drop (FixedDropout) (None, 8, 8, 192) 0 block6d_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block6d_add (Add) (None, 8, 8, 192) 0 block6d_drop[0][0] \n",
" block6c_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_expand_conv (Conv2D) (None, 8, 8, 1152) 221184 block6d_add[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_expand_bn (BatchNormali (None, 8, 8, 1152) 4608 block7a_expand_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_expand_activation (Acti (None, 8, 8, 1152) 0 block7a_expand_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_dwconv (DepthwiseConv2D (None, 8, 8, 1152) 10368 block7a_expand_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_bn (BatchNormalization) (None, 8, 8, 1152) 4608 block7a_dwconv[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_activation (Activation) (None, 8, 8, 1152) 0 block7a_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_se_squeeze (GlobalAvera (None, 1152) 0 block7a_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_se_reshape (Reshape) (None, 1, 1, 1152) 0 block7a_se_squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_se_reduce (Conv2D) (None, 1, 1, 48) 55344 block7a_se_reshape[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_se_expand (Conv2D) (None, 1, 1, 1152) 56448 block7a_se_reduce[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_se_excite (Multiply) (None, 8, 8, 1152) 0 block7a_activation[0][0] \n",
" block7a_se_expand[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_project_conv (Conv2D) (None, 8, 8, 320) 368640 block7a_se_excite[0][0] \n",
"__________________________________________________________________________________________________\n",
"block7a_project_bn (BatchNormal (None, 8, 8, 320) 1280 block7a_project_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"top_conv (Conv2D) (None, 8, 8, 1280) 409600 block7a_project_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"top_bn (BatchNormalization) (None, 8, 8, 1280) 5120 top_conv[0][0] \n",
"__________________________________________________________________________________________________\n",
"top_activation (Activation) (None, 8, 8, 1280) 0 top_bn[0][0] \n",
"__________________________________________________________________________________________________\n",
"avg_pool (GlobalAveragePooling2 (None, 1280) 0 top_activation[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_1 (Dropout) (None, 1280) 0 avg_pool[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_1 (Dense) (None, 6) 7686 dropout_1[0][0] \n",
"==================================================================================================\n",
"Total params: 4,057,250\n",
"Trainable params: 4,015,234\n",
"Non-trainable params: 42,016\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"base_model = efn.EfficientNetB0(weights = 'imagenet', include_top = False, \\\n",
" pooling = 'avg', input_shape = (HEIGHT, WIDTH, 3))\n",
"x = base_model.output\n",
"x = Dropout(0.125)(x)\n",
"output_layer = Dense(6, activation = 'sigmoid')(x)\n",
"model = Model(inputs=base_model.input, outputs=output_layer)\n",
"model.compile(optimizer = Adam(learning_rate = 0.0001), \n",
" loss = 'binary_crossentropy',\n",
" metrics = ['acc', tf.keras.metrics.AUC()])\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(636396, 40622)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# https://github.com/trent-b/iterative-stratification\n",
"# Multilabel stratification\n",
"splits = MultilabelStratifiedShuffleSplit(n_splits = 2, test_size = TEST_SIZE, random_state = SEED)\n",
"file_names = train_final_df.index\n",
"labels = train_final_df.values\n",
"# Let's take only the first split\n",
"split = next(splits.split(file_names, labels))\n",
"train_idx = split[0]\n",
"valid_idx = split[1]\n",
"submission_predictions = []\n",
"len(train_idx), len(valid_idx)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# train data generator\n",
"data_generator_train = TrainDataGenerator(train_final_df.iloc[train_idx], \n",
" train_final_df.iloc[train_idx], \n",
" TRAIN_BATCH_SIZE, \n",
" (WIDTH, HEIGHT),\n",
" augment = True)\n",
"\n",
"# validation data generator\n",
"data_generator_val = TrainDataGenerator(train_final_df.iloc[valid_idx], \n",
" train_final_df.iloc[valid_idx], \n",
" VALID_BATCH_SIZE, \n",
" (WIDTH, HEIGHT),\n",
" augment = False)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(19888, 635)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(data_generator_train), len(data_generator_val)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The competition is scored with a weighted multi-label log loss, but the per-subtype weights were not published. According to this discussion thread https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/discussion/109526#latest-630190, the `any` label has a weight of 2 and every other subtype a weight of 1. The functions below are taken from that thread. They are defined here for reference; the training cells further down still compile the model with `binary_crossentropy`."
]
},
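{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the weight of 2 on `any` does, here is a small NumPy sketch of the same weighted multi-label log loss that `weighted_log_loss_metric` computes below: a binary log loss per label, averaged across the six labels with weights `[2, 1, 1, 1, 1, 1]`, then averaged across images. It assumes the label order `[any, epidural, intraparenchymal, intraventricular, subarachnoid, subdural]` used in `train_final_df`; the function name `weighted_multilabel_log_loss` and the toy predictions are hypothetical. It compares two equally confident mistakes on a healthy image: with these weights the mistake on `any` costs about 0.73, against about 0.42 for the same mistake on `epidural`, because the `any` term carries 2/7 of each image's weight instead of 1/7."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Assumption: label order matches train_final_df.columns, with 'any' first.\n",
"CLASS_WEIGHTS = np.array([2., 1., 1., 1., 1., 1.])\n",
"\n",
"\n",
"def weighted_multilabel_log_loss(y_true, y_pred, weights=CLASS_WEIGHTS, eps=1e-7):\n",
"    # Clip so that log(0) never occurs.\n",
"    y_pred = np.clip(y_pred, eps, 1 - eps)\n",
"    # Binary log loss for every (image, label) pair.\n",
"    per_label = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))\n",
"    # Weighted average over labels (np.average divides by the weight sum, here 7).\n",
"    per_image = np.average(per_label, axis=1, weights=weights)\n",
"    # Plain average over images.\n",
"    return float(per_image.mean())\n",
"\n",
"\n",
"# One healthy image: every label is 0.\n",
"y_true = np.zeros((1, 6))\n",
"wrong_any = np.array([[0.9, 0.1, 0.1, 0.1, 0.1, 0.1]])       # confident mistake on 'any'\n",
"wrong_epidural = np.array([[0.1, 0.9, 0.1, 0.1, 0.1, 0.1]])  # same mistake on 'epidural'\n",
"\n",
"loss_any = weighted_multilabel_log_loss(y_true, wrong_any)\n",
"loss_epidural = weighted_multilabel_log_loss(y_true, wrong_epidural)\n",
"print(round(loss_any, 4), round(loss_epidural, 4))  # 0.7331 0.4192"
]
},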
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"from keras import backend as K\n",
"\n",
"def weighted_log_loss(y_true, y_pred):\n",
" \"\"\"\n",
" Can be used as the loss function in model.compile()\n",
" ---------------------------------------------------\n",
" \"\"\"\n",
" \n",
" class_weights = np.array([2., 1., 1., 1., 1., 1.])\n",
" \n",
" eps = K.epsilon()\n",
" \n",
" y_pred = K.clip(y_pred, eps, 1.0-eps)\n",
"\n",
" out = -( y_true * K.log( y_pred) * class_weights\n",
" + (1.0 - y_true) * K.log(1.0 - y_pred) * class_weights)\n",
" \n",
" return K.mean(out, axis=-1)\n",
"\n",
"\n",
"def _normalized_weighted_average(arr, weights=None):\n",
" \"\"\"\n",
" A simple Keras implementation that mimics that of \n",
" numpy.average(), specifically for this competition\n",
" \"\"\"\n",
" \n",
" if weights is not None:\n",
" scl = K.sum(weights)\n",
" weights = K.expand_dims(weights, axis=1)\n",
" return K.sum(K.dot(arr, weights), axis=1) / scl\n",
" return K.mean(arr, axis=1)\n",
"\n",
"\n",
"def weighted_loss(y_true, y_pred):\n",
" \"\"\"\n",
" Will be used as the metric in model.compile()\n",
" ---------------------------------------------\n",
" \n",
" Similar to the custom loss function 'weighted_log_loss()' above\n",
" but with normalized weights, which should be very similar \n",
" to the official competition metric:\n",
" https://www.kaggle.com/kambarakun/lb-probe-weights-n-of-positives-scoring\n",
" and hence:\n",
" sklearn.metrics.log_loss with sample weights\n",
" \"\"\"\n",
" \n",
" class_weights = K.variable([2., 1., 1., 1., 1., 1.])\n",
" \n",
" eps = K.epsilon()\n",
" \n",
" y_pred = K.clip(y_pred, eps, 1.0-eps)\n",
"\n",
" loss = -( y_true * K.log( y_pred)\n",
" + (1.0 - y_true) * K.log(1.0 - y_pred))\n",
" \n",
" loss_samples = _normalized_weighted_average(loss, class_weights)\n",
" \n",
" return K.mean(loss_samples)\n",
"\n",
"\n",
"def weighted_log_loss_metric(trues, preds):\n",
" \"\"\"\n",
" Will be used to calculate the log loss \n",
" of the validation set in PredictionCheckpoint()\n",
" ------------------------------------------\n",
" \"\"\"\n",
" class_weights = [2., 1., 1., 1., 1., 1.]\n",
" \n",
" epsilon = 1e-7\n",
" \n",
" preds = np.clip(preds, epsilon, 1-epsilon)\n",
" loss = trues * np.log(preds) + (1 - trues) * np.log(1 - preds)\n",
" loss_samples = np.average(loss, axis=1, weights=class_weights)\n",
"\n",
" return - loss_samples.mean()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"filepath=\"model.h5\"\n",
"checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, \\\n",
" save_best_only=True, mode='min')\n",
"\n",
"callbacks_list = [checkpoint]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
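"Training is done in two stages. In the first stage, every layer except the last five (`top_bn`, `top_activation`, `avg_pool`, `dropout_1` and `dense_1` in the summary above) is frozen and the model is trained on the whole training set; the cell below runs this stage for 2 epochs. With roughly 600k training images, training all layers on the full dataset would be preferable, but it needs more compute than is available. So in the second stage all layers are unfrozen and the model is trained for 10 more epochs, each of which covers only one sixth of the training batches. Both stages run only when `train = True`; otherwise pretrained weights are downloaded and loaded further below."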
]
},
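{
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-epoch sample in the second stage comes from passing one sixth of `len(data_generator_train)` as `steps_per_epoch`. Here is a quick arithmetic sketch of what that means, assuming `TRAIN_BATCH_SIZE = 32` (its value is not shown in this section, but it is consistent with 636,396 training images giving 19,888 batches) and a generator that counts the final partial batch. The helper name `steps_per_epoch_for_fraction` is hypothetical. Keras expects an integer number of steps, hence the floor division. Under these assumptions each second-stage epoch sees roughly 106k of the roughly 636k training images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"def steps_per_epoch_for_fraction(n_samples, batch_size, divisor):\n",
"    # Batches in one full pass, counting a final partial batch (assumed generator behaviour).\n",
"    batches_per_full_pass = math.ceil(n_samples / batch_size)\n",
"    # Integer number of steps covering about 1/divisor of those batches.\n",
"    return batches_per_full_pass // divisor\n",
"\n",
"\n",
"# 636,396 training images come from the split above; batch size 32 is an assumption.\n",
"steps = steps_per_epoch_for_fraction(636396, 32, 6)\n",
"images_per_epoch = steps * 32\n",
"print(steps, images_per_epoch)  # 3314 106048"
]
},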
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"train = False"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"if train:\n",
" if not os.path.isfile('../input/orginal-087-eff/model.h5'):\n",
" for layer in model.layers[:-5]:\n",
" layer.trainable = False\n",
" model.compile(optimizer = Adam(learning_rate = 0.0001), \n",
" loss = 'binary_crossentropy',\n",
" metrics = ['acc'])\n",
"\n",
" model.fit_generator(generator = data_generator_train,\n",
" validation_data = data_generator_val,\n",
" epochs = 2,\n",
" callbacks = callbacks_list,\n",
" verbose = 1)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"if train:\n",
" for base_layer in model.layers[:-1]:\n",
" base_layer.trainable = True\n",
"\n",
" model.load_weights('model.h5')\n",
"\n",
" model.compile(optimizer = Adam(learning_rate = 0.0004), \n",
" loss = 'binary_crossentropy',\n",
" metrics = ['acc'])\n",
" model.fit_generator(generator = data_generator_train,\n",
" validation_data = data_generator_val,\n",
"                        steps_per_epoch = len(data_generator_train) // 6,\n",
" epochs = 10,\n",
" callbacks = callbacks_list,\n",
" verbose = 1)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting gdown\n",
" Downloading https://files.pythonhosted.org/packages/b0/b4/a8e9d0b02bca6aa53087001abf064cc9992bda11bd6840875b8098d93573/gdown-3.8.3.tar.gz\n",
"Requirement already satisfied: filelock in /opt/conda/lib/python3.6/site-packages (from gdown) (3.0.12)\n",
"Requirement already satisfied: requests in /opt/conda/lib/python3.6/site-packages (from gdown) (2.22.0)\n",
"Requirement already satisfied: six in /opt/conda/lib/python3.6/site-packages (from gdown) (1.12.0)\n",
"Requirement already satisfied: tqdm in /opt/conda/lib/python3.6/site-packages (from gdown) (4.36.1)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/lib/python3.6/site-packages (from requests->gdown) (1.24.2)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/lib/python3.6/site-packages (from requests->gdown) (3.0.4)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/lib/python3.6/site-packages (from requests->gdown) (2.8)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.6/site-packages (from requests->gdown) (2019.9.11)\n",
"Building wheels for collected packages: gdown\n",
" Building wheel for gdown (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for gdown: filename=gdown-3.8.3-cp36-none-any.whl size=8850 sha256=ca7bf131547dd1503032ee6ec7567ff06fb7ddad8d44a32f00f874aadbd01a5e\n",
" Stored in directory: /tmp/.cache/pip/wheels/a7/9d/16/9e0bda9a327ff2cddaee8de48a27553fb1efce73133593d066\n",
"Successfully built gdown\n",
"Installing collected packages: gdown\n",
"Successfully installed gdown-3.8.3\n"
]
}
],
"source": [
"!pip install gdown"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading...\n",
"From: https://drive.google.com/uc?id=1kZmMCCBOWSjCZjz2XWaouDIj5gFn2D-q\n",
"To: /kaggle/working/model (4).h5\n",
"49.2MB [00:03, 14.6MB/s]\n"
]
}
],
"source": [
"!gdown https://drive.google.com/uc?id=1kZmMCCBOWSjCZjz2XWaouDIj5gFn2D-q"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"!cp \"model (4).h5\" model.h5"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1228/1228 [==============================] - 856s 697ms/step\n"
]
},
{
"data": {
"text/plain": [
"(78592, 6)"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.load_weights('model.h5')\n",
"\n",
"preds = model.predict_generator(TestDataGenerator(test_df.index, None, VALID_BATCH_SIZE, \\\n",
" (WIDTH, HEIGHT), path_test_img), \n",
" verbose=1)\n",
"preds.shape"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"cols = list(train_final_df.columns)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████▉| 78545/78592 [00:01<00:00, 51807.96it/s]\n"
]
}
],
"source": [
"# We have one prediction vector (6 probabilities) per image,\n",
"# but the submission needs one row per (image, subtype) pair.\n",
"ids = []\n",
"values = []\n",
"for i, j in tqdm(zip(preds, test_df.index.to_list()), total=preds.shape[0]):\n",
" # i=[any_prob, epidural_prob, intraparenchymal_prob, intraventricular_prob, subarachnoid_prob, subdural_prob]\n",
" # j = filename ==> ID_xyz.dcm\n",
" for k in range(i.shape[0]):\n",
" ids.append([j.replace('.dcm', '_' + cols[k])])\n",
" values.append(i[k]) "
]
},
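{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above builds the long submission format one row at a time. The same reshaping can be written as a flatten plus a list comprehension. This is a minimal sketch, assuming (as the loop does) that the rows of `preds` follow `test_df.index` and its columns follow `cols`; the function name `predictions_to_submission` and the toy inputs are hypothetical. With the notebook's objects it would be called as `predictions_to_submission(preds, test_df.index.to_list(), cols)` and should yield the same `ID`/`Label` rows as `df` below, because a row-major flatten puts `preds[i, k]` at position `i * 6 + k`, which is exactly where the loop appends it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"\n",
"def predictions_to_submission(preds, image_files, label_cols):\n",
"    # 'ID_xyz.dcm' -> 'ID_xyz'\n",
"    stems = [f.replace('.dcm', '') for f in image_files]\n",
"    # Image in the outer loop, subtype in the inner loop, as in the cell above.\n",
"    ids = [s + '_' + c for s in stems for c in label_cols]\n",
"    # Row-major flatten: preds[i, k] lands at position i * len(label_cols) + k.\n",
"    labels = np.asarray(preds).reshape(-1)\n",
"    return pd.DataFrame({'ID': ids, 'Label': labels})\n",
"\n",
"\n",
"# Toy example with two images and two subtypes (hypothetical values).\n",
"toy = predictions_to_submission(np.array([[0.1, 0.2], [0.3, 0.4]]),\n",
"                                ['ID_a.dcm', 'ID_b.dcm'],\n",
"                                ['any', 'epidural'])\n",
"print(toy)"
]
},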
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ID_000012eaf_any</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ID_000012eaf_epidural</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ID_000012eaf_intraparenchymal</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ID_000012eaf_intraventricular</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ID_000012eaf_subarachnoid</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0\n",
"0 ID_000012eaf_any\n",
"1 ID_000012eaf_epidural\n",
"2 ID_000012eaf_intraparenchymal\n",
"3 ID_000012eaf_intraventricular\n",
"4 ID_000012eaf_subarachnoid"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.DataFrame(data=ids)\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ID</th>\n",
" <th>Label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ID_28fbab7eb_epidural</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ID_28fbab7eb_intraparenchymal</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ID_28fbab7eb_intraventricular</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ID_28fbab7eb_subarachnoid</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ID_28fbab7eb_subdural</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" ID Label\n",
"0 ID_28fbab7eb_epidural 0.5\n",
"1 ID_28fbab7eb_intraparenchymal 0.5\n",
"2 ID_28fbab7eb_intraventricular 0.5\n",
"3 ID_28fbab7eb_subarachnoid 0.5\n",
"4 ID_28fbab7eb_subdural 0.5"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sample_df = pd.read_csv(input_folder + 'stage_1_sample_submission.csv')\n",
"sample_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ID</th>\n",
" <th>Label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ID_000012eaf_any</td>\n",
" <td>0.008506</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ID_000012eaf_epidural</td>\n",
" <td>0.000114</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ID_000012eaf_intraparenchymal</td>\n",
" <td>0.001682</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ID_000012eaf_intraventricular</td>\n",
" <td>0.000329</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ID_000012eaf_subarachnoid</td>\n",
" <td>0.000926</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" ID Label\n",
"0 ID_000012eaf_any 0.008506\n",
"1 ID_000012eaf_epidural 0.000114\n",
"2 ID_000012eaf_intraparenchymal 0.001682\n",
"3 ID_000012eaf_intraventricular 0.000329\n",
"4 ID_000012eaf_subarachnoid 0.000926"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df['Label'] = values\n",
"df.columns = sample_df.columns\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"df.to_csv('submission.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<a href=submission.csv>Download CSV file</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"create_download_link(filename='submission.csv')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}