Notebook / Model / InceptionV3_2.ipynb


{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import glob\n",
    "import datetime\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import cv2\n",
    "import pydicom\n",
    "from tqdm import tqdm\n",
    "from joblib import delayed, Parallel\n",
    "import zipfile\n",
    "from pydicom.filebase import DicomBytesIO\n",
    "import sys\n",
    "from PIL import Image\n",
    "import cv2\n",
    "#from focal_loss import sparse_categorical_focal_loss\n",
    "import keras\n",
    "#import tensorflow_addons as tfa\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from keras.models import model_from_json\n",
    "import tensorflow as tf\n",
    "import keras\n",
    "from keras.models import Sequential, Model\n",
    "from keras.layers import Dense, Conv2D, Flatten, MaxPooling2D, GlobalAveragePooling2D, Dropout\n",
    "from keras.applications.inception_v3 import InceptionV3\n",
    "\n",
    "# importing pyplot and image from matplotlib \n",
    "import matplotlib.pyplot as plt \n",
    "import matplotlib.image as mpimg \n",
    "from keras.optimizers import SGD\n",
    "from keras import backend\n",
    "from keras.models import load_model\n",
    "\n",
    "from keras.preprocessing import image\n",
    "import albumentations as A\n",
    "\n",
    "\n",
    "from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.losses import Reduction\n",
    "\n",
    "from tensorflow_addons.losses import SigmoidFocalCrossEntropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_url = '/home/ubuntu/kaggle/rsna-intracranial-hemorrhage-detection/'\n",
    "TRAIN_DIR = '/home/ubuntu/kaggle/rsna-intracranial-hemorrhage-detection/stage_2_train/'\n",
    "TEST_DIR = '/home/ubuntu/kaggle/rsna-intracranial-hemorrhage-detection/stage_2_test/'\n",
    "image_dir = '/home/ubuntu/kaggle/rsna-intracranial-hemorrhage-detection/png/train/adjacent-brain-cropped/'\n",
    "save_dir = 'home/ubuntu/kaggle/models/'\n",
    "os.listdir(base_url)\n",
    "\n",
    "def png(image): \n",
    "    return image + '.png'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning rate scheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_learning_rate = 1e-2\n",
    "first_decay_steps = 1000\n",
    "lr_decayed_fn = (\n",
    "  tf.keras.experimental.CosineDecayRestarts(\n",
    "      initial_learning_rate,\n",
    "      first_decay_steps))\n",
    "opt = tf.keras.optimizers.Adam(learning_rate=lr_decayed_fn)"
   ]
  },
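  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick look at the schedule (illustrative sketch, not part of the original run): evaluating `lr_decayed_fn` at a few global steps shows the cosine decay and the restart after `first_decay_steps`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: print the scheduled learning rate at a few global steps.\n",
    "for step in [0, 250, 500, 999, 1000, 1500, 2999]:\n",
    "    print(step, float(lr_decayed_fn(step)))"
   ]
  },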
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Weighted_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import backend as K\n",
    "def _normalized_weighted_average(arr, weights=None):\n",
    "    \"\"\"\n",
    "    A simple Keras implementation that mimics that of \n",
    "    numpy.average(), specifically for this competition\n",
    "    \"\"\"\n",
    "    \n",
    "    if weights is not None:\n",
    "        scl = K.sum(weights)\n",
    "        weights = K.expand_dims(weights, axis=1)\n",
    "        return K.sum(K.dot(arr, weights), axis=1) / scl\n",
    "    return K.mean(arr, axis=1)\n",
    "\n",
    "\n",
    "def weighted_loss(y_true, y_pred):\n",
    "    \"\"\"\n",
    "    Will be used as the metric in model.compile()\n",
    "    ---------------------------------------------\n",
    "    \n",
    "    Similar to the custom loss function 'weighted_log_loss()' above\n",
    "    but with normalized weights, which should be very similar \n",
    "    to the official competition metric:\n",
    "        https://www.kaggle.com/kambarakun/lb-probe-weights-n-of-positives-scoring\n",
    "    and hence:\n",
    "        sklearn.metrics.log_loss with sample weights\n",
    "    \"\"\"\n",
    "    \n",
    "    class_weights = K.variable([2., 1., 1., 1., 1., 1.])\n",
    "    \n",
    "    eps = K.epsilon()\n",
    "    \n",
    "    y_pred = K.clip(y_pred, eps, 1.0-eps)\n",
    "\n",
    "    loss = -(        y_true  * K.log(      y_pred)\n",
    "            + (1.0 - y_true) * K.log(1.0 - y_pred))\n",
    "    \n",
    "    loss_samples = _normalized_weighted_average(loss, class_weights)\n",
    "    \n",
    "    return K.mean(loss_samples)\n",
    "\n"
   ]
  },
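  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sanity check of `weighted_loss` (illustrative sketch, not part of the original run): the toy batch below is made up; the 'any' class carries weight 2 while the five subtype classes carry weight 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy example: 2 samples x 6 classes, evaluated with the weighted metric above.\n",
    "y_true_toy = K.constant([[1., 0., 0., 0., 0., 0.],\n",
    "                         [0., 0., 1., 0., 0., 1.]])\n",
    "y_pred_toy = K.constant([[0.9, 0.1, 0.1, 0.1, 0.1, 0.1],\n",
    "                         [0.2, 0.1, 0.7, 0.1, 0.1, 0.6]])\n",
    "print(K.eval(weighted_loss(y_true_toy, y_pred_toy)))"
   ]
  },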
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 604739 validated image filenames.\n",
      "Found 148063 validated image filenames.\n"
     ]
    }
   ],
   "source": [
    "train_idg = ImageDataGenerator(\n",
    "        featurewise_center=False,  # set input mean to 0 over the dataset\n",
    "        samplewise_center=False,  # set each sample mean to 0\n",
    "        featurewise_std_normalization=False,  # divide inputs by std of the dataset\n",
    "        samplewise_std_normalization=False,  # divide each input by its std\n",
    "        zca_whitening=False,  # apply ZCA whitening\n",
    "        rotation_range=50,  # randomly rotate images in the range (degrees, 0 to 180)\n",
    "        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)\n",
    "        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)\n",
    "        horizontal_flip=True,\n",
    "        rescale=1./255)\n",
    "valid_idg = ImageDataGenerator(rescale=1./255)\n",
    "training_data = pd.read_csv(f'train_1.csv') \n",
    "training_data['Image'] = training_data['Image'].apply(png)\n",
    "\n",
    "validation_data = pd.read_csv(f'valid_1.csv')\n",
    "validation_data['Image'] = validation_data['Image'].apply(png)\n",
    "\n",
    "columns=['any','epidural','intraparenchymal','intraventricular', 'subarachnoid','subdural']\n",
    "\n",
    "train_data_generator = train_idg.flow_from_dataframe(training_data, directory = image_dir,\n",
    "                           x_col = \"Image\", y_col = columns,batch_size=64,\n",
    "                           class_mode=\"raw\", target_size=(224,224), shuffle = True)\n",
    "valid_data_generator  = valid_idg.flow_from_dataframe(validation_data, directory = image_dir,\n",
    "                        x_col = \"Image\", y_col = columns,batch_size=64,\n",
    "                        class_mode = \"raw\",target_size=(224,224), shuffle = False)"
   ]
  },
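  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check (sketch, assuming the generators above were built successfully): pull one batch from `train_data_generator`, confirm the image tensor and multi-label target shapes, and show the first augmented image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_images, batch_labels = next(train_data_generator)\n",
    "print(batch_images.shape)  # expected: (64, 224, 224, 3)\n",
    "print(batch_labels.shape)  # expected: (64, 6), one column per hemorrhage type\n",
    "plt.imshow(batch_images[0])\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },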
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"functional_1\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            \n",
      "__________________________________________________________________________________________________\n",
      "conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "conv1_conv (Conv2D)             (None, 112, 112, 64) 9472        conv1_pad[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv1_bn (BatchNormalization)   (None, 112, 112, 64) 256         conv1_conv[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv1_relu (Activation)         (None, 112, 112, 64) 0           conv1_bn[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "pool1_pad (ZeroPadding2D)       (None, 114, 114, 64) 0           conv1_relu[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "pool1_pool (MaxPooling2D)       (None, 56, 56, 64)   0           pool1_pad[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_1_conv (Conv2D)    (None, 56, 56, 64)   4160        pool1_pool[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_1_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block1_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_1_relu (Activation (None, 56, 56, 64)   0           conv2_block1_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_2_conv (Conv2D)    (None, 56, 56, 64)   36928       conv2_block1_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_2_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block1_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_2_relu (Activation (None, 56, 56, 64)   0           conv2_block1_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_0_conv (Conv2D)    (None, 56, 56, 256)  16640       pool1_pool[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_3_conv (Conv2D)    (None, 56, 56, 256)  16640       conv2_block1_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_0_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block1_0_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_3_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block1_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_add (Add)          (None, 56, 56, 256)  0           conv2_block1_0_bn[0][0]          \n",
      "                                                                 conv2_block1_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block1_out (Activation)   (None, 56, 56, 256)  0           conv2_block1_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_1_conv (Conv2D)    (None, 56, 56, 64)   16448       conv2_block1_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_1_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block2_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_1_relu (Activation (None, 56, 56, 64)   0           conv2_block2_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_2_conv (Conv2D)    (None, 56, 56, 64)   36928       conv2_block2_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_2_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block2_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_2_relu (Activation (None, 56, 56, 64)   0           conv2_block2_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_3_conv (Conv2D)    (None, 56, 56, 256)  16640       conv2_block2_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_3_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block2_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_add (Add)          (None, 56, 56, 256)  0           conv2_block1_out[0][0]           \n",
      "                                                                 conv2_block2_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block2_out (Activation)   (None, 56, 56, 256)  0           conv2_block2_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_1_conv (Conv2D)    (None, 56, 56, 64)   16448       conv2_block2_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_1_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block3_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_1_relu (Activation (None, 56, 56, 64)   0           conv2_block3_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_2_conv (Conv2D)    (None, 56, 56, 64)   36928       conv2_block3_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_2_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block3_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_2_relu (Activation (None, 56, 56, 64)   0           conv2_block3_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_3_conv (Conv2D)    (None, 56, 56, 256)  16640       conv2_block3_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_3_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block3_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_add (Add)          (None, 56, 56, 256)  0           conv2_block2_out[0][0]           \n",
      "                                                                 conv2_block3_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv2_block3_out (Activation)   (None, 56, 56, 256)  0           conv2_block3_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_1_conv (Conv2D)    (None, 28, 28, 128)  32896       conv2_block3_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block1_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_1_relu (Activation (None, 28, 28, 128)  0           conv3_block1_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block1_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block1_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_2_relu (Activation (None, 28, 28, 128)  0           conv3_block1_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_0_conv (Conv2D)    (None, 28, 28, 512)  131584      conv2_block3_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block1_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_0_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block1_0_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block1_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_add (Add)          (None, 28, 28, 512)  0           conv3_block1_0_bn[0][0]          \n",
      "                                                                 conv3_block1_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block1_out (Activation)   (None, 28, 28, 512)  0           conv3_block1_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_1_conv (Conv2D)    (None, 28, 28, 128)  65664       conv3_block1_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block2_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_1_relu (Activation (None, 28, 28, 128)  0           conv3_block2_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block2_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block2_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_2_relu (Activation (None, 28, 28, 128)  0           conv3_block2_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block2_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block2_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_add (Add)          (None, 28, 28, 512)  0           conv3_block1_out[0][0]           \n",
      "                                                                 conv3_block2_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block2_out (Activation)   (None, 28, 28, 512)  0           conv3_block2_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_1_conv (Conv2D)    (None, 28, 28, 128)  65664       conv3_block2_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block3_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_1_relu (Activation (None, 28, 28, 128)  0           conv3_block3_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block3_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block3_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_2_relu (Activation (None, 28, 28, 128)  0           conv3_block3_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block3_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block3_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_add (Add)          (None, 28, 28, 512)  0           conv3_block2_out[0][0]           \n",
      "                                                                 conv3_block3_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block3_out (Activation)   (None, 28, 28, 512)  0           conv3_block3_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_1_conv (Conv2D)    (None, 28, 28, 128)  65664       conv3_block3_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block4_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_1_relu (Activation (None, 28, 28, 128)  0           conv3_block4_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block4_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block4_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_2_relu (Activation (None, 28, 28, 128)  0           conv3_block4_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block4_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block4_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_add (Add)          (None, 28, 28, 512)  0           conv3_block3_out[0][0]           \n",
      "                                                                 conv3_block4_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv3_block4_out (Activation)   (None, 28, 28, 512)  0           conv3_block4_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_1_conv (Conv2D)    (None, 14, 14, 256)  131328      conv3_block4_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block1_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_1_relu (Activation (None, 14, 14, 256)  0           conv4_block1_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block1_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block1_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_2_relu (Activation (None, 14, 14, 256)  0           conv4_block1_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_0_conv (Conv2D)    (None, 14, 14, 1024) 525312      conv3_block4_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block1_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_0_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block1_0_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block1_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_add (Add)          (None, 14, 14, 1024) 0           conv4_block1_0_bn[0][0]          \n",
      "                                                                 conv4_block1_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block1_out (Activation)   (None, 14, 14, 1024) 0           conv4_block1_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block1_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block2_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_1_relu (Activation (None, 14, 14, 256)  0           conv4_block2_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block2_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block2_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_2_relu (Activation (None, 14, 14, 256)  0           conv4_block2_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block2_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block2_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_add (Add)          (None, 14, 14, 1024) 0           conv4_block1_out[0][0]           \n",
      "                                                                 conv4_block2_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block2_out (Activation)   (None, 14, 14, 1024) 0           conv4_block2_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block2_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block3_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_1_relu (Activation (None, 14, 14, 256)  0           conv4_block3_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block3_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block3_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_2_relu (Activation (None, 14, 14, 256)  0           conv4_block3_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block3_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block3_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_add (Add)          (None, 14, 14, 1024) 0           conv4_block2_out[0][0]           \n",
      "                                                                 conv4_block3_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block3_out (Activation)   (None, 14, 14, 1024) 0           conv4_block3_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block3_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block4_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_1_relu (Activation (None, 14, 14, 256)  0           conv4_block4_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block4_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block4_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_2_relu (Activation (None, 14, 14, 256)  0           conv4_block4_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block4_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block4_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_add (Add)          (None, 14, 14, 1024) 0           conv4_block3_out[0][0]           \n",
      "                                                                 conv4_block4_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block4_out (Activation)   (None, 14, 14, 1024) 0           conv4_block4_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block4_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block5_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_1_relu (Activation (None, 14, 14, 256)  0           conv4_block5_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block5_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block5_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_2_relu (Activation (None, 14, 14, 256)  0           conv4_block5_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block5_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block5_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_add (Add)          (None, 14, 14, 1024) 0           conv4_block4_out[0][0]           \n",
      "                                                                 conv4_block5_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block5_out (Activation)   (None, 14, 14, 1024) 0           conv4_block5_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block5_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block6_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_1_relu (Activation (None, 14, 14, 256)  0           conv4_block6_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block6_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block6_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_2_relu (Activation (None, 14, 14, 256)  0           conv4_block6_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block6_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block6_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_add (Add)          (None, 14, 14, 1024) 0           conv4_block5_out[0][0]           \n",
      "                                                                 conv4_block6_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv4_block6_out (Activation)   (None, 14, 14, 1024) 0           conv4_block6_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_1_conv (Conv2D)    (None, 7, 7, 512)    524800      conv4_block6_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_1_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block1_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_1_relu (Activation (None, 7, 7, 512)    0           conv5_block1_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_2_conv (Conv2D)    (None, 7, 7, 512)    2359808     conv5_block1_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_2_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block1_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_2_relu (Activation (None, 7, 7, 512)    0           conv5_block1_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_0_conv (Conv2D)    (None, 7, 7, 2048)   2099200     conv4_block6_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_3_conv (Conv2D)    (None, 7, 7, 2048)   1050624     conv5_block1_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_0_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block1_0_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_3_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block1_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_add (Add)          (None, 7, 7, 2048)   0           conv5_block1_0_bn[0][0]          \n",
      "                                                                 conv5_block1_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block1_out (Activation)   (None, 7, 7, 2048)   0           conv5_block1_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_1_conv (Conv2D)    (None, 7, 7, 512)    1049088     conv5_block1_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_1_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block2_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_1_relu (Activation (None, 7, 7, 512)    0           conv5_block2_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_2_conv (Conv2D)    (None, 7, 7, 512)    2359808     conv5_block2_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_2_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block2_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_2_relu (Activation (None, 7, 7, 512)    0           conv5_block2_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_3_conv (Conv2D)    (None, 7, 7, 2048)   1050624     conv5_block2_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_3_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block2_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_add (Add)          (None, 7, 7, 2048)   0           conv5_block1_out[0][0]           \n",
      "                                                                 conv5_block2_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block2_out (Activation)   (None, 7, 7, 2048)   0           conv5_block2_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_1_conv (Conv2D)    (None, 7, 7, 512)    1049088     conv5_block2_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_1_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block3_1_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_1_relu (Activation (None, 7, 7, 512)    0           conv5_block3_1_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_2_conv (Conv2D)    (None, 7, 7, 512)    2359808     conv5_block3_1_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_2_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block3_2_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_2_relu (Activation (None, 7, 7, 512)    0           conv5_block3_2_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_3_conv (Conv2D)    (None, 7, 7, 2048)   1050624     conv5_block3_2_relu[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_3_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block3_3_conv[0][0]        \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_add (Add)          (None, 7, 7, 2048)   0           conv5_block2_out[0][0]           \n",
      "                                                                 conv5_block3_3_bn[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "conv5_block3_out (Activation)   (None, 7, 7, 2048)   0           conv5_block3_add[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "global_average_pooling2d (Globa (None, 2048)         0           conv5_block3_out[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "dropout (Dropout)               (None, 2048)         0           global_average_pooling2d[0][0]   \n",
      "__________________________________________________________________________________________________\n",
      "dense (Dense)                   (None, 6)            12294       dropout[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 23,600,006\n",
      "Trainable params: 12,294\n",
      "Non-trainable params: 23,587,712\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.keras.applications.resnet50 import ResNet50\n",
    "\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout\n",
    "\n",
    "METRICS = [\n",
    "      tf.keras.metrics.TruePositives(name='tp'),\n",
    "      tf.keras.metrics.FalsePositives(name='fp'),\n",
    "      tf.keras.metrics.TrueNegatives(name='tn'),\n",
    "      tf.keras.metrics.FalseNegatives(name='fn'), \n",
    "      tf.keras.metrics.BinaryAccuracy(name='accuracy'),\n",
    "      tf.keras.metrics.Precision(name='precision'),\n",
    "      tf.keras.metrics.Recall(name='recall'),\n",
    "      tf.keras.metrics.AUC(name='auc')\n",
    "      \n",
    "]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# create the base pre-trained model\n",
    "base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224,224,3))\n",
    "\n",
    "# add a global spatial average pooling layer\n",
    "x = base_model.output\n",
    "x = GlobalAveragePooling2D()(x)\n",
    "# let's add a fully-connected layer\n",
    "#x = Dense(256, activation='relu')(x)\n",
    "# and a logistic layer -- let's say we have 200 classes\n",
    "x = Dropout(0.3)(x)\n",
    "#x = Dense(100, activation=\"relu\")(x)\n",
    "#x = Dropout(0.3)(x)\n",
    "#pred = Dense(6,\n",
    "#             kernel_initializer=tf.keras.initializers.HeNormal(seed=11),\n",
    "#             kernel_regularizer=tf.keras.regularizers.l2(0.05),\n",
    "#             bias_regularizer=tf.keras.regularizers.l2(0.05), activation=\"softmax\")(x)\n",
    "#initializer = keras.initializers.GlorotUniform()\n",
    "#layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)\n",
    "\n",
    "predictions = Dense(6, activation='softmax')(x)\n",
    "#activation='sigmoid',kernel_initializer=keras.initializers.GlorotNormal()\n",
    "# this is the model we will train\n",
    "model = Model(inputs=base_model.input, outputs=predictions)\n",
    "\n",
    "# first: train only the top layers (which were randomly initialized)\n",
    "# i.e. freeze all convolutional InceptionV3 layers\n",
    "for layer in base_model.layers:\n",
    "    layer.trainable = False\n",
    "\n",
    "# compile the model (should be done *after* setting layers to non-trainable)\n",
    "model.compile(opt, loss=SigmoidFocalCrossEntropy(), metrics= METRICS)\n",
    "#model.compile(loss=\"binary_crossentropy\", optimizer=keras.optimizers.Adam(), metrics=METRICS)\n",
    "#model.compile(loss=loss_func,\n",
    "#          optimizer=opt,\n",
    "#          metrics=METRICS)\n",
    "model.summary()"
   ]
  },
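  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comments above train only the newly added head while the ResNet50 backbone stays frozen. A possible second stage (sketch only, not run here) would unfreeze the deepest ResNet50 stage and recompile with a small fixed learning rate before calling `model.fit` again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tuning sketch: unfreeze the conv5_* blocks of the frozen ResNet50 backbone.\n",
    "for layer in base_model.layers:\n",
    "    if layer.name.startswith('conv5_'):\n",
    "        layer.trainable = True\n",
    "\n",
    "# Recompile with a small fixed learning rate so the pretrained weights change slowly.\n",
    "model.compile(tf.keras.optimizers.Adam(learning_rate=1e-5),\n",
    "              loss=SigmoidFocalCrossEntropy(), metrics=METRICS)"
   ]
  },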
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.\n"
     ]
    }
   ],
   "source": [
    "from keras import backend as K\n",
    "\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint\n",
    "\n",
    "\n",
    "checkpoint = tf.keras.callbacks.ModelCheckpoint('resnet50{epoch:08d}.h5', period=1,mode= 'auto',save_best_only=True) \n",
    "\n",
    "learning_rate_reduction = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_acc', \n",
    "                                            patience=1, \n",
    "                                            verbose=1, \n",
    "                                            factor=0.5, \n",
    "                                            min_lr=0.00001)\n",
    "\n",
    "callback_list = [checkpoint]\n"
   ]
  },
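  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The deprecation warning above asks for `save_freq` instead of `period`. An equivalent checkpoint without the deprecated argument would look like the sketch below (not run here, to keep the original training output intact; the variable name is illustrative)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same checkpoint behaviour, expressed with save_freq='epoch' (checks once per epoch).\n",
    "checkpoint_savefreq = tf.keras.callbacks.ModelCheckpoint(\n",
    "    'resnet50{epoch:08d}.h5', save_freq='epoch', mode='auto', save_best_only=True)"
   ]
  },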
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-10-bdb3352f611a>:1: experimental_run_functions_eagerly (from tensorflow.python.eager.def_function) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.config.run_functions_eagerly` instead of the experimental version.\n"
     ]
    }
   ],
   "source": [
    "tf.config.experimental_run_functions_eagerly(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# model fit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py:3350: UserWarning: Even though the tf.config.experimental_run_functions_eagerly option is set, this option does not apply to tf.data functions. tf.data functions are still traced and executed as graphs.\n",
      "  \"Even though the tf.config.experimental_run_functions_eagerly \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "4724/4724 [==============================] - 5855s 1s/step - loss: 0.6288 - tp: 6640.0000 - fp: 51458.0000 - tn: 1658587.0000 - fn: 96965.0000 - accuracy: 0.9182 - precision: 0.1143 - recall: 0.0641 - auc: 0.5777 - val_loss: 0.6626 - val_tp: 1.0000 - val_fp: 30.0000 - val_tn: 419245.0000 - val_fn: 24628.0000 - val_accuracy: 0.9445 - val_precision: 0.0323 - val_recall: 4.0603e-05 - val_auc: 0.6240\n",
      "Epoch 2/3\n",
      "4724/4724 [==============================] - 5560s 1s/step - loss: 0.7789 - tp: 17593.0000 - fp: 138643.0000 - tn: 1572177.0000 - fn: 85237.0000 - accuracy: 0.8766 - precision: 0.1126 - recall: 0.1711 - auc: 0.6140 - val_loss: 0.8053 - val_tp: 8792.0000 - val_fp: 65190.0000 - val_tn: 354085.0000 - val_fn: 15837.0000 - val_accuracy: 0.8175 - val_precision: 0.1188 - val_recall: 0.3570 - val_auc: 0.6379\n",
      "Epoch 3/3\n",
      "4724/4724 [==============================] - 5500s 1s/step - loss: 0.8205 - tp: 38155.0000 - fp: 264178.0000 - tn: 1446569.0000 - fn: 65114.0000 - accuracy: 0.8185 - precision: 0.1262 - recall: 0.3695 - auc: 0.6422 - val_loss: 0.8036 - val_tp: 9835.0000 - val_fp: 64144.0000 - val_tn: 355131.0000 - val_fn: 14794.0000 - val_accuracy: 0.8222 - val_precision: 0.1329 - val_recall: 0.3993 - val_auc: 0.6529\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "\n",
    "num_epochs = 3\n",
    "\n",
    "batch_size = 128\n",
    "training_steps = len(training_data) // batch_size\n",
    "validation_step = len(validation_data) // batch_size\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# FIT THE MODEL\n",
    "history = model.fit(train_data_generator,\n",
    "            epochs=num_epochs,steps_per_epoch=training_steps,\n",
    "            callbacks=callback_list,\n",
    "            validation_data=valid_data_generator,\n",
    "            validation_steps= validation_step) \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "tf.keras.backend.clear_session()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.keras.backend.clear_session()\n",
    "#from keras.models import load_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.models.load_model('resnet5000000001.h5')"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "#from sklearn.neural_network import BernoulliRBM\n",
    "#rbm = BernoulliRBM(n_components=100, learning_rate=0.01, random_state=0, verbose=True)\n",
    "#rbm.fit(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-12-6ece0724cedc>:2: Model.predict_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use Model.predict, which supports generators.\n"
     ]
    }
   ],
   "source": [
    "X_val_sample, val_labels = next(valid_data_generator)\n",
    "predictions = model.predict_generator(valid_data_generator,steps = 128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = ['any','epidural','intraparenchymal','intraventricular', 'subarachnoid','subdural']\n",
    "cm = confusion_matrix(validation_data['any','epidural','intraparenchymal','intraventricular', 'subarachnoid','subdural'], prediction_, labels)\n",
    "print(cm)\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(111)\n",
    "cax = ax.matshow(cm)\n",
    "plt.title('Confusion matrix of the classifier')\n",
    "fig.colorbar(cax)\n",
    "ax.set_xticklabels([''] + labels)\n",
    "ax.set_yticklabels([''] + labels)\n",
    "plt.xlabel('Predicted')\n",
    "plt.ylabel('True')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "test_predictions_weighted = model.predict(valid_data_generator, batch_size=BATCH_SIZE)\n",
    "\n",
    "\n",
    "weighted_results = model.evaluate(test_features, test_labels,\n",
    "                                           batch_size=BATCH_SIZE, verbose=0)\n",
    "for name, value in zip(model.metrics_names, weighted_results):\n",
    "    print(name, ': ', value)\n",
    "print()\n",
    "\n",
    "plot_cm(test_labels, test_predictions_weighted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.6691614985466003, 1.0, 89.0, 838395.0, 49893.0, 0.9437377452850342, 0.011111111380159855, 2.004248926823493e-05, 0.6232836246490479]\n"
     ]
    }
   ],
   "source": [
    "valid_predict =  model.evaluate_generator(valid_data_generator)\n",
    "print(valid_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['loss', 'tp', 'fp', 'tn', 'fn', 'accuracy', 'precision', 'recall', 'auc']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.metrics_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot([0, 1], [0, 1], linestyle='--')\n",
    "# plot the roc curve for the model\n",
    "plt.plot(valid_predict[2], valid_predict[1], marker='.')\n",
    "# show the plot\n",
    "plt.show()"
   ]
  },
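  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Per-class ROC AUC (illustrative sketch, assuming `labels`, `y_true` and `predictions` from the confusion-matrix cell above): `roc_auc_score` from the top-level imports gives one AUC per hemorrhage type, which is more informative than the single aggregate `auc` metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One ROC AUC per label column.\n",
    "for i, label in enumerate(labels):\n",
    "    print(label, roc_auc_score(y_true[:, i], predictions[:, i]))"
   ]
  },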
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_training(H):\n",
    "    # construct a plot that plots and saves the training history\n",
    "    with plt.xkcd():\n",
    "        plt.figure(figsize = (10,10))\n",
    "        plt.plot(H.epoch,H.history[\"acc\"], label=\"train_acc\")\n",
    "        plt.plot(H.epoch,H.history[\"val_acc\"], label=\"val_acc\")\n",
    "        plt.title(\"Training Accuracy\")\n",
    "        plt.xlabel(\"Epoch #\")\n",
    "        plt.ylabel(\"Accuracy\")\n",
    "        plt.legend(loc=\"lower left\")\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_training(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_training(H):\n",
    "    # construct a plot that plots and saves the training history\n",
    "    with plt.xkcd():\n",
    "        plt.figure(figsize = (10,10))\n",
    "        plt.plot(H.epoch,H.history[\"loss\"], label=\"train_loss\")\n",
    "        plt.plot(H.epoch,H.history[\"val_loss\"], label=\"val_loss\")\n",
    "        plt.title(\"Training Loss\")\n",
    "        plt.xlabel(\"Epoch #\")\n",
    "        plt.ylabel(\"Loss\")\n",
    "        plt.legend(loc=\"lower left\")\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_training(history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metrics(history):\n",
    "    plt.figure(figsize = (15,15))\n",
    "    metrics =  ['loss', 'tp', 'fp', 'tn', 'fn', 'accuracy', 'precision', 'recall', 'auc']\n",
    "    for n, metric in enumerate(metrics):\n",
    "        \n",
    "        name = metric.replace(\"_\",\" \").capitalize()\n",
    "        plt.subplot(2,2,n+1)\n",
    "        \n",
    "        plt.plot(history.epoch,  history.history[metric], color='red', label='Train')\n",
    "        plt.plot(history.epoch, history.history['val_'+metric],\n",
    "                 color='orange', linestyle=\"--\", label='Val')\n",
    "        plt.xlabel('Epoch')\n",
    "        plt.ylabel(name)\n",
    "        if metric == 'loss':\n",
    "            plt.ylim([0, plt.ylim()[1]])\n",
    "        elif metric == 'auc':\n",
    "            plt.ylim([0.8,1])\n",
    "        else:\n",
    "            plt.ylim([0,1])\n",
    "\n",
    "        plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_training(history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}