--- a
+++ b/test_preds_function.ipynb
@@ -0,0 +1,333 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Third-party dependencies, pinned to the versions this notebook was last run with\n",
+    "# (efficientnet 1.0.0, iterative-stratification 0.1.6, gdown 3.8.1).\n",
+    "!pip install -q efficientnet==1.0.0\n",
+    "!pip install -q iterative-stratification==0.1.6\n",
+    "!pip install -q gdown==3.8.1\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "# Fetch the trained weights from Google Drive only when they are not already on disk.\n",
+    "# NOTE(review): assumes the Drive file is saved as model_effnet_bo_087.h5 - confirm.\n",
+    "# The URL is quoted so the shell does not treat '?' as a glob character.\n",
+    "if not os.path.isfile('model_effnet_bo_087.h5'):\n",
+    "    !gdown 'https://drive.google.com/uc?id=1FXF1HymYbRf3OlThMTXAa74TRup3AhD_'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import pydicom\n",
+    "import cv2\n",
+    "import tensorflow as tf\n",
+    "import multiprocessing\n",
+    "from math import ceil, floor\n",
+    "import keras\n",
+    "import keras.backend as K\n",
+    "from keras.callbacks import Callback, ModelCheckpoint\n",
+    "from keras.layers import Dense, Flatten, Dropout\n",
+    "from keras.models import Model, load_model\n",
+    "from keras.utils import Sequence\n",
+    "from keras.losses import binary_crossentropy\n",
+    "from keras.optimizers import Adam\n",
+    "import efficientnet.keras as efn \n",
+    "from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit\n",
+    "import matplotlib.pyplot as plt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
+    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
+   },
+   "outputs": [],
+   "source": [
+    "# Network input geometry: height x width x channels.\n",
+    "# The three channels are the brain, subdural and soft-tissue windows built by bsb_window.\n",
+    "HEIGHT, WIDTH, CHANNELS = 256, 256, 3\n",
+    "SHAPE = (HEIGHT, WIDTH, CHANNELS)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def correct_dcm(dcm):\n",
+    "    \"\"\"Re-encode the raw pixels of ``dcm`` in place.\n",
+    "\n",
+    "    Adds 1000 to every raw pixel, subtracts 4096 from values that end up\n",
+    "    >= 4096, writes the result back to ``dcm.PixelData`` and sets\n",
+    "    ``dcm.RescaleIntercept`` to -1000. ``window_image`` calls this only for\n",
+    "    12-bit unsigned slices whose intercept is > -100 (presumably a known\n",
+    "    mis-encoding in the source data; confirm).\n",
+    "    \"\"\"\n",
+    "    x = dcm.pixel_array + 1000\n",
+    "    px_mode = 4096  # 2**12: wrap-around point for 12-bit pixel values\n",
+    "    x[x >= px_mode] = x[x >= px_mode] - px_mode\n",
+    "    dcm.PixelData = x.tobytes()\n",
+    "    dcm.RescaleIntercept = -1000\n",
+    "\n",
+    "\n",
+    "def window_image(dcm, window_center, window_width):\n",
+    "    \"\"\"Rescale, resize and clip one DICOM slice to an intensity window.\n",
+    "\n",
+    "    Args:\n",
+    "        dcm: pydicom dataset; may be modified in place by ``correct_dcm``.\n",
+    "        window_center: centre of the intensity window (rescaled units).\n",
+    "        window_width: width of the intensity window (rescaled units).\n",
+    "\n",
+    "    Returns:\n",
+    "        2-D array of shape (HEIGHT, WIDTH) holding\n",
+    "        ``pixel_array * RescaleSlope + RescaleIntercept`` clipped to\n",
+    "        [window_center - window_width // 2, window_center + window_width // 2].\n",
+    "    \"\"\"\n",
+    "    if (dcm.BitsStored == 12) and (dcm.PixelRepresentation == 0) and (int(dcm.RescaleIntercept) > -100):\n",
+    "        correct_dcm(dcm)\n",
+    "    img = dcm.pixel_array * dcm.RescaleSlope + dcm.RescaleIntercept\n",
+    "\n",
+    "    # cv2.resize expects dsize as (width, height), the reverse of SHAPE's (height, width) order.\n",
+    "    img = cv2.resize(img, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)\n",
+    "\n",
+    "    img_min = window_center - window_width // 2\n",
+    "    img_max = window_center + window_width // 2\n",
+    "    return np.clip(img, img_min, img_max)\n",
+    "\n",
+    "\n",
+    "def bsb_window(dcm):\n",
+    "    \"\"\"Return a (HEIGHT, WIDTH, 3) image of brain, subdural and soft-tissue windows.\n",
+    "\n",
+    "    Each channel is shifted by its window's lower bound and divided by the\n",
+    "    window width, so all values lie in [0, 1].\n",
+    "    \"\"\"\n",
+    "    brain_img = window_image(dcm, 40, 80)      # clipped to [0, 80]\n",
+    "    subdural_img = window_image(dcm, 80, 200)  # clipped to [-20, 180]\n",
+    "    soft_img = window_image(dcm, 40, 380)      # clipped to [-150, 230]\n",
+    "\n",
+    "    brain_img = (brain_img - 0) / 80\n",
+    "    subdural_img = (subdural_img - (-20)) / 200\n",
+    "    soft_img = (soft_img - (-150)) / 380\n",
+    "    return np.stack([brain_img, subdural_img, soft_img], axis=-1)\n",
+    "\n",
+    "\n",
+    "def _read(path, SHAPE):\n",
+    "    \"\"\"Load a DICOM file and return its windowed 3-channel image.\n",
+    "\n",
+    "    Args:\n",
+    "        path: path to the .dcm file; errors from ``pydicom.dcmread`` propagate.\n",
+    "        SHAPE: shape of the all-zero image returned if windowing fails. Pass the\n",
+    "            model input shape (HEIGHT, WIDTH, CHANNELS) so the fallback matches a\n",
+    "            successful read, whose size is fixed by the global HEIGHT / WIDTH.\n",
+    "\n",
+    "    Returns:\n",
+    "        ``bsb_window(dcm)`` on success, otherwise ``np.zeros(SHAPE)``.\n",
+    "    \"\"\"\n",
+    "    dcm = pydicom.dcmread(path)\n",
+    "    try:\n",
+    "        img = bsb_window(dcm)\n",
+    "    except Exception:\n",
+    "        # Unreadable or unsupported pixel data: fall back to a blank image, without\n",
+    "        # also swallowing KeyboardInterrupt / SystemExit as a bare ``except:`` would.\n",
+    "        img = np.zeros(SHAPE)\n",
+    "    return img"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Build the EfficientNet-B0 classifier and load its trained weights."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Must match the architecture the weights file was trained with: EfficientNet-B0 backbone\n",
+    "# with global average pooling, dropout, and a 6-unit sigmoid head.\n",
+    "# weights=None skips the ImageNet download; load_weights() in the next cell\n",
+    "# (by_name=False) must restore every layer from the .h5 file anyway.\n",
+    "base_model = efn.EfficientNetB0(weights=None, include_top=False,\n",
+    "                                pooling='avg', input_shape=SHAPE)\n",
+    "\n",
+    "x = base_model.output\n",
+    "x = Dropout(0.125)(x)\n",
+    "output_layer = Dense(6, activation='sigmoid')(x)\n",
+    "model = Model(inputs=base_model.input, outputs=output_layer)\n",
+    "model.compile(optimizer=Adam(lr=0.0001),\n",
+    "              loss='binary_crossentropy',\n",
+    "              metrics=['acc'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Restore the trained parameters (file fetched by the download cell at the top).\n",
+    "model.load_weights(filepath='model_effnet_bo_087.h5')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(256, 256, 3)"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Pass the full model input shape so _read's zero-image fallback also has 3 channels;\n",
+    "# a 2-D (256, 256) fallback would not match the (HEIGHT, WIDTH, CHANNELS) model input.\n",
+    "img_data = _read('ID_000000e27.dcm', SHAPE)\n",
+    "img_data.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "\n",
+      "text/plain": [
+       "<Figure size 432x288 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Show the three windows as the R, G and B channels of one image (all values are in [0, 1]).\n",
+    "fig, ax = plt.subplots()\n",
+    "ax.imshow(img_data)\n",
+    "ax.set_title('ID_000000e27.dcm - brain / subdural / soft-tissue windows as RGB')\n",
+    "ax.axis('off');"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(1, 256, 256, 3)"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Add a leading batch axis of size 1: the model takes (batch, HEIGHT, WIDTH, CHANNELS).\n",
+    "X = np.expand_dims(img_data, axis=0)\n",
+    "assert X.shape == (1,) + SHAPE, X.shape\n",
+    "X.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([[0.05049768, 0.00077422, 0.00192805, 0.00035753, 0.0012597 ,\n",
+       "        0.02818136]], dtype=float32)"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# One sigmoid score per output unit for each image in X.\n",
+    "# NOTE(review): the meaning/order of the 6 columns is fixed by training and not recorded here - confirm.\n",
+    "preds = model.predict(x=X)\n",
+    "preds"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(1, 6)"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# One row per input image, one column per unit of the Dense(6) sigmoid head.\n",
+    "assert preds.shape == (X.shape[0], 6), preds.shape\n",
+    "preds.shape"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}