[198e90]: / leukemia detection.ipynb

Download this file

943 lines (942 with data), 67.5 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "59654c10",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "modules loaded\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras.layers import Dense, Activation,Dropout,Conv2D, MaxPooling2D,BatchNormalization, Flatten\n",
    "from tensorflow.keras.optimizers import Adam, Adamax\n",
    "from tensorflow.keras.metrics import categorical_crossentropy\n",
    "from tensorflow.keras import regularizers\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.models import Model, load_model, Sequential\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import shutil\n",
    "import time\n",
    "import cv2 as cv2\n",
    "from tqdm import tqdm\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.pyplot import imshow\n",
    "import seaborn as sns\n",
    "sns.set_style('darkgrid')\n",
    "from PIL import Image\n",
    "from sklearn.metrics import confusion_matrix, classification_report\n",
    "from IPython.core.display import display, HTML\n",
    "# stop annoying tensorflow warning messages\n",
    "import logging\n",
    "logging.getLogger(\"tensorflow\").setLevel(logging.ERROR)\n",
    "print ('modules loaded')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fa7de5f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_image_samples(gen ):\n",
    "    t_dict=gen.class_indices\n",
    "    classes=list(t_dict.keys())    \n",
    "    images,labels=next(gen) # get a sample batch from the generator \n",
    "    plt.figure(figsize=(20, 20))\n",
    "    length=len(labels)\n",
    "    if length<25:   #show maximum of 25 images\n",
    "        r=length\n",
    "    else:\n",
    "        r=25\n",
    "    for i in range(r):\n",
    "        plt.subplot(5, 5, i + 1)\n",
    "        image=images[i]/255\n",
    "        plt.imshow(image)\n",
    "        index=np.argmax(labels[i])\n",
    "        class_name=classes[index]\n",
    "        plt.title(class_name, color='blue', fontsize=12)\n",
    "        plt.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3cc88fda",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_images(tdir):\n",
    "    classlist=os.listdir(tdir)\n",
    "    length=len(classlist)\n",
    "    columns=5\n",
    "    rows=int(np.ceil(length/columns))    \n",
    "    plt.figure(figsize=(20, rows * 4))\n",
    "    for i, klass in enumerate(classlist):    \n",
    "        classpath=os.path.join(tdir, klass)\n",
    "        imgpath=os.path.join(classpath, '1.jpg')\n",
    "        img=plt.imread(imgpath)\n",
    "        plt.subplot(rows, columns, i+1)\n",
    "        plt.axis('off')\n",
    "        plt.title(klass, color='blue', fontsize=12)\n",
    "        plt.imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "564a3b4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_in_color(txt_msg,fore_tupple,back_tupple,):\n",
    "    #prints the text_msg in the foreground color specified by fore_tupple with the background specified by back_tupple \n",
    "    #text_msg is the text, fore_tupple is foregroud color tupple (r,g,b), back_tupple is background tupple (r,g,b)\n",
    "    rf,gf,bf=fore_tupple\n",
    "    rb,gb,bb=back_tupple\n",
    "    msg='{0}' + txt_msg\n",
    "    mat='\\33[38;2;' + str(rf) +';' + str(gf) + ';' + str(bf) + ';48;2;' + str(rb) + ';' +str(gb) + ';' + str(bb) +'m' \n",
    "    print(msg .format(mat), flush=True)\n",
    "    print('\\33[0m', flush=True) # returns default print color to back to black\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9bbdb8ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LRA(keras.callbacks.Callback):\n",
    "    def __init__(self,model, base_model, patience,stop_patience, threshold, factor, dwell, batches, initial_epoch,epochs, ask_epoch):\n",
    "        super(LRA, self).__init__()\n",
    "        self.model=model\n",
    "        self.base_model=base_model\n",
    "        self.patience=patience # specifies how many epochs without improvement before learning rate is adjusted\n",
    "        self.stop_patience=stop_patience # specifies how many times to adjust lr without improvement to stop training\n",
    "        self.threshold=threshold # specifies training accuracy threshold when lr will be adjusted based on validation loss\n",
    "        self.factor=factor # factor by which to reduce the learning rate\n",
    "        self.dwell=dwell\n",
    "        self.batches=batches # number of training batch to runn per epoch\n",
    "        self.initial_epoch=initial_epoch\n",
    "        self.epochs=epochs\n",
    "        self.ask_epoch=ask_epoch\n",
    "        self.ask_epoch_initial=ask_epoch # save this value to restore if restarting training\n",
    "        # callback variables \n",
    "        self.count=0 # how many times lr has been reduced without improvement\n",
    "        self.stop_count=0        \n",
    "        self.best_epoch=1   # epoch with the lowest loss        \n",
    "        self.initial_lr=float(tf.keras.backend.get_value(model.optimizer.lr)) # get the initiallearning rate and save it         \n",
    "        self.highest_tracc=0.0 # set highest training accuracy to 0 initially\n",
    "        self.lowest_vloss=np.inf # set lowest validation loss to infinity initially\n",
    "        self.best_weights=self.model.get_weights() # set best weights to model's initial weights\n",
    "        self.initial_weights=self.model.get_weights()   # save initial weights if they have to get restored \n",
    "        \n",
    "    def on_train_begin(self, logs=None):        \n",
    "        if self.base_model != None:\n",
    "            status=base_model.trainable\n",
    "            if status:\n",
    "                msg=' initializing callback starting training with base_model trainable'\n",
    "            else:\n",
    "                msg='initializing callback starting training with base_model not trainable'\n",
    "        else:\n",
    "            msg='initialing callback and starting training'                        \n",
    "        print_in_color (msg, (244, 252, 3), (55,65,80)) \n",
    "        msg='{0:^8s}{1:^10s}{2:^9s}{3:^9s}{4:^9s}{5:^9s}{6:^9s}{7:^10s}{8:10s}{9:^8s}'.format('Epoch', 'Loss', 'Accuracy',\n",
    "                                                                                              'V_loss','V_acc', 'LR', 'Next LR', 'Monitor','% Improv', 'Duration')\n",
    "        print_in_color(msg, (244,252,3), (55,65,80)) \n",
    "        self.start_time= time.time()\n",
    "        \n",
    "    def on_train_end(self, logs=None):\n",
    "        stop_time=time.time()\n",
    "        tr_duration= stop_time- self.start_time            \n",
    "        hours = tr_duration // 3600\n",
    "        minutes = (tr_duration - (hours * 3600)) // 60\n",
    "        seconds = tr_duration - ((hours * 3600) + (minutes * 60))\n",
    "\n",
    "        self.model.set_weights(self.best_weights) # set the weights of the model to the best weights\n",
    "        msg=f'Training is completed - model is set with weights from epoch {self.best_epoch} '\n",
    "        print_in_color(msg, (0,255,0), (55,65,80))\n",
    "        msg = f'training elapsed time was {str(hours)} hours, {minutes:4.1f} minutes, {seconds:4.2f} seconds)'\n",
    "        print_in_color(msg, (0,255,0), (55,65,80))   \n",
    "        \n",
    "    def on_train_batch_end(self, batch, logs=None):\n",
    "        acc=logs.get('accuracy')* 100  # get training accuracy \n",
    "        loss=logs.get('loss')\n",
    "        msg='{0:20s}processing batch {1:4s} of {2:5s} accuracy= {3:8.3f}  loss: {4:8.5f}'.format(' ', str(batch), str(self.batches), acc, loss)\n",
    "        print(msg, '\\r', end='') # prints over on the same line to show running batch count        \n",
    "        \n",
    "    def on_epoch_begin(self,epoch, logs=None):\n",
    "        self.now= time.time()\n",
    "        \n",
    "    def on_epoch_end(self, epoch, logs=None):  # method runs on the end of each epoch\n",
    "        later=time.time()\n",
    "        duration=later-self.now \n",
    "        lr=float(tf.keras.backend.get_value(self.model.optimizer.lr)) # get the current learning rate\n",
    "        current_lr=lr\n",
    "        v_loss=logs.get('val_loss')  # get the validation loss for this epoch\n",
    "        acc=logs.get('accuracy')  # get training accuracy \n",
    "        v_acc=logs.get('val_accuracy')\n",
    "        loss=logs.get('loss')        \n",
    "        if acc < self.threshold: # if training accuracy is below threshold adjust lr based on training accuracy\n",
    "            monitor='accuracy'\n",
    "            if epoch ==0:\n",
    "                pimprov=0.0\n",
    "            else:\n",
    "                pimprov= (acc-self.highest_tracc )*100/self.highest_tracc\n",
    "            if acc>self.highest_tracc: # training accuracy improved in the epoch                \n",
    "                self.highest_tracc=acc # set new highest training accuracy\n",
    "                self.best_weights=self.model.get_weights() # traing accuracy improved so save the weights\n",
    "                self.count=0 # set count to 0 since training accuracy improved\n",
    "                self.stop_count=0 # set stop counter to 0\n",
    "                if v_loss<self.lowest_vloss:\n",
    "                    self.lowest_vloss=v_loss\n",
    "                color= (0,255,0)\n",
    "                self.best_epoch=epoch + 1  # set the value of best epoch for this epoch              \n",
    "            else: \n",
    "                # training accuracy did not improve check if this has happened for patience number of epochs\n",
    "                # if so adjust learning rate\n",
    "                if self.count>=self.patience -1: # lr should be adjusted\n",
    "                    color=(245, 170, 66)\n",
    "                    lr= lr* self.factor # adjust the learning by factor\n",
    "                    tf.keras.backend.set_value(self.model.optimizer.lr, lr) # set the learning rate in the optimizer\n",
    "                    self.count=0 # reset the count to 0\n",
    "                    self.stop_count=self.stop_count + 1 # count the number of consecutive lr adjustments\n",
    "                    self.count=0 # reset counter\n",
    "                    if self.dwell:\n",
    "                        self.model.set_weights(self.best_weights) # return to better point in N space                        \n",
    "                    else:\n",
    "                        if v_loss<self.lowest_vloss:\n",
    "                            self.lowest_vloss=v_loss                                    \n",
    "                else:\n",
    "                    self.count=self.count +1 # increment patience counter                    \n",
    "        else: # training accuracy is above threshold so adjust learning rate based on validation loss\n",
    "            monitor='val_loss'\n",
    "            if epoch ==0:\n",
    "                pimprov=0.0\n",
    "            else:\n",
    "                pimprov= (self.lowest_vloss- v_loss )*100/self.lowest_vloss\n",
    "            if v_loss< self.lowest_vloss: # check if the validation loss improved \n",
    "                self.lowest_vloss=v_loss # replace lowest validation loss with new validation loss                \n",
    "                self.best_weights=self.model.get_weights() # validation loss improved so save the weights\n",
    "                self.count=0 # reset count since validation loss improved  \n",
    "                self.stop_count=0  \n",
    "                color=(0,255,0)                \n",
    "                self.best_epoch=epoch + 1 # set the value of the best epoch to this epoch\n",
    "            else: # validation loss did not improve\n",
    "                if self.count>=self.patience-1: # need to adjust lr\n",
    "                    color=(245, 170, 66)\n",
    "                    lr=lr * self.factor # adjust the learning rate                    \n",
    "                    self.stop_count=self.stop_count + 1 # increment stop counter because lr was adjusted \n",
    "                    self.count=0 # reset counter\n",
    "                    tf.keras.backend.set_value(self.model.optimizer.lr, lr) # set the learning rate in the optimizer\n",
    "                    if self.dwell:\n",
    "                        self.model.set_weights(self.best_weights) # return to better point in N space\n",
    "                else: \n",
    "                    self.count =self.count +1 # increment the patience counter                    \n",
    "                if acc>self.highest_tracc:\n",
    "                    self.highest_tracc= acc\n",
    "        msg=f'{str(epoch+1):^3s}/{str(self.epochs):4s} {loss:^9.3f}{acc*100:^9.3f}{v_loss:^9.5f}{v_acc*100:^9.3f}{current_lr:^9.5f}{lr:^9.5f}{monitor:^11s}{pimprov:^10.2f}{duration:^8.2f}'\n",
    "        print_in_color (msg,color, (55,65,80))\n",
    "        if self.stop_count> self.stop_patience - 1: # check if learning rate has been adjusted stop_count times with no improvement\n",
    "            msg=f' training has been halted at epoch {epoch + 1} after {self.stop_patience} adjustments of learning rate with no improvement'\n",
    "            print_in_color(msg, (0,255,255), (55,65,80))\n",
    "            self.model.stop_training = True # stop training\n",
    "        else: \n",
    "            if self.ask_epoch !=None:\n",
    "                if epoch + 1 >= self.ask_epoch:\n",
    "                    if base_model.trainable:\n",
    "                        msg='enter H to halt training or an integer for number of epochs to run then ask again'\n",
    "                    else:\n",
    "                        msg='enter H to halt training ,F to fine tune model, or an integer for number of epochs to run then ask again'\n",
    "                    print_in_color(msg, (0,255,255), (55,65,80))\n",
    "                    ans=input('')\n",
    "                    if ans=='H' or ans=='h':\n",
    "                        msg=f'training has been halted at epoch {epoch + 1} due to user input'\n",
    "                        print_in_color(msg, (0,255,255), (55,65,80))\n",
    "                        self.model.stop_training = True # stop training\n",
    "                    elif ans == 'F' or ans=='f':\n",
    "                        if base_model.trainable:\n",
    "                            msg='base_model is already set as trainable'\n",
    "                        else:\n",
    "                            msg='setting base_model as trainable for fine tuning of model'\n",
    "                            self.base_model.trainable=True\n",
    "                        print_in_color(msg, (0, 255,255), (55,65,80))\n",
    "                        msg='{0:^8s}{1:^10s}{2:^9s}{3:^9s}{4:^9s}{5:^9s}{6:^9s}{7:^10s}{8:^8s}'.format('Epoch', 'Loss', 'Accuracy',\n",
    "                                                                                              'V_loss','V_acc', 'LR', 'Next LR', 'Monitor','% Improv', 'Duration')\n",
    "                        print_in_color(msg, (244,252,3), (55,65,80))                         \n",
    "                        self.count=0\n",
    "                        self.stop_count=0                        \n",
    "                        self.ask_epoch = epoch + 1 + self.ask_epoch_initial \n",
    "                        \n",
    "                    else:\n",
    "                        ans=int(ans)\n",
    "                        self.ask_epoch +=ans\n",
    "                        msg=f' training will continue until epoch ' + str(self.ask_epoch)                         \n",
    "                        print_in_color(msg, (0, 255,255), (55,65,80))\n",
    "                        msg='{0:^8s}{1:^10s}{2:^9s}{3:^9s}{4:^9s}{5:^9s}{6:^9s}{7:^10s}{8:10s}{9:^8s}'.format('Epoch', 'Loss', 'Accuracy',\n",
    "                                                                                              'V_loss','V_acc', 'LR', 'Next LR', 'Monitor','% Improv', 'Duration')\n",
    "                        print_in_color(msg, (244,252,3), (55,65,80))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e155beef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tr_plot(tr_data, start_epoch):\n",
    "    #Plot the training and validation data\n",
    "    tacc=tr_data.history['accuracy']\n",
    "    tloss=tr_data.history['loss']\n",
    "    vacc=tr_data.history['val_accuracy']\n",
    "    vloss=tr_data.history['val_loss']\n",
    "    Epoch_count=len(tacc)+ start_epoch\n",
    "    Epochs=[]\n",
    "    for i in range (start_epoch ,Epoch_count):\n",
    "        Epochs.append(i+1)   \n",
    "    index_loss=np.argmin(vloss)#  this is the epoch with the lowest validation loss\n",
    "    val_lowest=vloss[index_loss]\n",
    "    index_acc=np.argmax(vacc)\n",
    "    acc_highest=vacc[index_acc]\n",
    "    plt.style.use('fivethirtyeight')\n",
    "    sc_label='best epoch= '+ str(index_loss+1 +start_epoch)\n",
    "    vc_label='best epoch= '+ str(index_acc + 1+ start_epoch)\n",
    "    fig,axes=plt.subplots(nrows=1, ncols=2, figsize=(20,8))\n",
    "    axes[0].plot(Epochs,tloss, 'r', label='Training loss')\n",
    "    axes[0].plot(Epochs,vloss,'g',label='Validation loss' )\n",
    "    axes[0].scatter(index_loss+1 +start_epoch,val_lowest, s=150, c= 'blue', label=sc_label)\n",
    "    axes[0].set_title('Training and Validation Loss')\n",
    "    axes[0].set_xlabel('Epochs')\n",
    "    axes[0].set_ylabel('Loss')\n",
    "    axes[0].legend()\n",
    "    axes[1].plot (Epochs,tacc,'r',label= 'Training Accuracy')\n",
    "    axes[1].plot (Epochs,vacc,'g',label= 'Validation Accuracy')\n",
    "    axes[1].scatter(index_acc+1 +start_epoch,acc_highest, s=150, c= 'blue', label=vc_label)\n",
    "    axes[1].set_title('Training and Validation Accuracy')\n",
    "    axes[1].set_xlabel('Epochs')\n",
    "    axes[1].set_ylabel('Accuracy')\n",
    "    axes[1].legend()\n",
    "    plt.tight_layout\n",
    "    #plt.style.use('fivethirtyeight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "35f70802",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_info( test_gen, preds, print_code, save_dir, subject ):\n",
    "    class_dict=test_gen.class_indices\n",
    "    labels= test_gen.labels\n",
    "    file_names= test_gen.filenames \n",
    "    error_list=[]\n",
    "    true_class=[]\n",
    "    pred_class=[]\n",
    "    prob_list=[]\n",
    "    new_dict={}\n",
    "    error_indices=[]\n",
    "    y_pred=[]\n",
    "    for key,value in class_dict.items():\n",
    "        new_dict[value]=key             # dictionary {integer of class number: string of class name}\n",
    "    # store new_dict as a text fine in the save_dir\n",
    "    classes=list(new_dict.values())     # list of string of class names     \n",
    "    errors=0      \n",
    "    for i, p in enumerate(preds):\n",
    "        pred_index=np.argmax(p)         \n",
    "        true_index=labels[i]  # labels are integer values\n",
    "        if pred_index != true_index: # a misclassification has occurred\n",
    "            error_list.append(file_names[i])\n",
    "            true_class.append(new_dict[true_index])\n",
    "            pred_class.append(new_dict[pred_index])\n",
    "            prob_list.append(p[pred_index])\n",
    "            error_indices.append(true_index)            \n",
    "            errors=errors + 1\n",
    "        y_pred.append(pred_index) \n",
    "    tests=len(preds)\n",
    "    acc= (1-errors/tests) *100\n",
    "    msg= f'There were {errors} errors in {tests} test cases Model accuracy= {acc: 6.2f} %'\n",
    "    print_in_color(msg,(0,255,255),(55,65,80))\n",
    "    if print_code !=0:\n",
    "        if errors>0:\n",
    "            if print_code>errors:\n",
    "                r=errors\n",
    "            else:\n",
    "                r=print_code           \n",
    "            msg='{0:^28s}{1:^28s}{2:^28s}{3:^16s}'.format('Filename', 'Predicted Class' , 'True Class', 'Probability')\n",
    "            print_in_color(msg, (0,255,0),(55,65,80))\n",
    "            for i in range(r):                \n",
    "                split1=os.path.split(error_list[i])                \n",
    "                split2=os.path.split(split1[0])                \n",
    "                fname=split2[1] + '/' + split1[1]\n",
    "                msg='{0:^28s}{1:^28s}{2:^28s}{3:4s}{4:^6.4f}'.format(fname, pred_class[i],true_class[i], ' ', prob_list[i])\n",
    "                print_in_color(msg, (255,255,255), (55,65,60))\n",
    "                #print(error_list[i]  , pred_class[i], true_class[i], prob_list[i])               \n",
    "        else:\n",
    "            msg='With accuracy of 100 % there are no errors to print'\n",
    "            print_in_color(msg, (0,255,0),(55,65,80))\n",
    "    if errors>0:\n",
    "        plot_bar=[]\n",
    "        plot_class=[]\n",
    "        for  key, value in new_dict.items():        \n",
    "            count=error_indices.count(key) \n",
    "            if count!=0:\n",
    "                plot_bar.append(count) # list containg how many times a class c had an error\n",
    "                plot_class.append(value)   # stores the class \n",
    "        fig=plt.figure()\n",
    "        fig.set_figheight(len(plot_class)/3)\n",
    "        fig.set_figwidth(10)\n",
    "        plt.style.use('fivethirtyeight')\n",
    "        for i in range(0, len(plot_class)):\n",
    "            c=plot_class[i]\n",
    "            x=plot_bar[i]\n",
    "            plt.barh(c, x, )\n",
    "            plt.title( ' Errors by Class on Test Set')\n",
    "    y_true= np.array(labels)        \n",
    "    y_pred=np.array(y_pred)\n",
    "    if len(classes)<= 30:\n",
    "        # create a confusion matrix \n",
    "        cm = confusion_matrix(y_true, y_pred )        \n",
    "        length=len(classes)\n",
    "        if length<8:\n",
    "            fig_width=8\n",
    "            fig_height=8\n",
    "        else:\n",
    "            fig_width= int(length * .5)\n",
    "            fig_height= int(length * .5)\n",
    "        plt.figure(figsize=(fig_width, fig_height))\n",
    "        sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Blues', cbar=False)       \n",
    "        plt.xticks(np.arange(length)+.5, classes, rotation= 90)\n",
    "        plt.yticks(np.arange(length)+.5, classes, rotation=0)\n",
    "        plt.xlabel(\"Predicted\")\n",
    "        plt.ylabel(\"Actual\")\n",
    "        plt.title(\"Confusion Matrix\")\n",
    "        plt.show()\n",
    "    clr = classification_report(y_true, y_pred, target_names=classes, digits= 4)\n",
    "    print(\"Classification Report:\\n----------------------\\n\", clr)\n",
    "    return acc/100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e7d27934",
   "metadata": {},
   "outputs": [],
   "source": [
    "def saver(save_path, model, model_name, subject, accuracy,img_size, scalar, generator):    \n",
    "    # first save the model\n",
    "    save_id=str (model_name +  '-' + subject +'-'+ str(acc)[:str(acc).rfind('.')+3] + '.h5')\n",
    "    model_save_loc=os.path.join(save_path, save_id)\n",
    "    model.save(model_save_loc)\n",
    "    print_in_color ('model was saved as ' + model_save_loc, (0,255,0),(55,65,80)) \n",
    "    # now create the class_df and convert to csv file    \n",
    "    class_dict=generator.class_indices \n",
    "    height=[]\n",
    "    width=[]\n",
    "    scale=[]\n",
    "    for i in range(len(class_dict)):\n",
    "        height.append(img_size[0])\n",
    "        width.append(img_size[1])\n",
    "        scale.append(scalar)\n",
    "    Index_series=pd.Series(list(class_dict.values()), name='class_index')\n",
    "    Class_series=pd.Series(list(class_dict.keys()), name='class') \n",
    "    Height_series=pd.Series(height, name='height')\n",
    "    Width_series=pd.Series(width, name='width')\n",
    "    Scale_series=pd.Series(scale, name='scale by')\n",
    "    class_df=pd.concat([Index_series, Class_series, Height_series, Width_series, Scale_series], axis=1)    \n",
    "    csv_name='class_dict.csv'\n",
    "    csv_save_loc=os.path.join(save_path, csv_name)\n",
    "    class_df.to_csv(csv_save_loc, index=False) \n",
    "    print_in_color ('class csv file was saved as ' + csv_save_loc, (0,255,0),(55,65,80)) \n",
    "    return model_save_loc, csv_save_loc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6268a332",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predictor(sdir, csv_path,  model_path, averaged=True, verbose=True):    \n",
    "    # read in the csv file\n",
    "    class_df=pd.read_csv(csv_path)    \n",
    "    class_count=len(class_df['class'].unique())\n",
    "    img_height=int(class_df['height'].iloc[0])\n",
    "    img_width =int(class_df['width'].iloc[0])\n",
    "    img_size=(img_width, img_height)    \n",
    "    scale=class_df['scale by'].iloc[0]    \n",
    "    # determine value to scale image pixels by\n",
    "    try: \n",
    "        s=int(scale)\n",
    "        s2=1\n",
    "        s1=0\n",
    "    except:\n",
    "        split=scale.split('-')\n",
    "        s1=float(split[1])\n",
    "        s2=float(split[0].split('*')[1])\n",
    "    path_list=[]\n",
    "    paths=os.listdir(sdir)    \n",
    "    for f in paths:\n",
    "        path_list.append(os.path.join(sdir,f))\n",
    "    if verbose:\n",
    "        print (' Model is being loaded- this will take about 10 seconds')\n",
    "    model=load_model(model_path)\n",
    "    image_count=len(path_list) \n",
    "    image_list=[]\n",
    "    file_list=[]\n",
    "    good_image_count=0\n",
    "    for i in range (image_count):        \n",
    "        try:\n",
    "            img=cv2.imread(path_list[i])\n",
    "            img=cv2.resize(img, img_size)\n",
    "            img=cv2.cvtColor(img, cv2.COLOR_BGR2RGB)            \n",
    "            good_image_count +=1\n",
    "            img=img*s2 - s1             \n",
    "            image_list.append(img)\n",
    "            file_name=os.path.split(path_list[i])[1]\n",
    "            file_list.append(file_name)\n",
    "        except:\n",
    "            if verbose:\n",
    "                print ( path_list[i], ' is an invalid image file')\n",
    "    if good_image_count==1: # if only a single image need to expand dimensions\n",
    "        averaged=True\n",
    "    image_array=np.array(image_list)    \n",
    "    # make predictions on images, sum the probabilities of each class then find class index with\n",
    "    # highest probability\n",
    "    preds=model.predict(image_array)    \n",
    "    if averaged:\n",
    "        psum=[]\n",
    "        for i in range (class_count): # create all 0 values list\n",
    "            psum.append(0)    \n",
    "        for p in preds: # iterate over all predictions\n",
    "            for i in range (class_count):\n",
    "                psum[i]=psum[i] + p[i]  # sum the probabilities   \n",
    "        index=np.argmax(psum) # find the class index with the highest probability sum        \n",
    "        klass=class_df['class'].iloc[index] # get the class name that corresponds to the index\n",
    "        prob=psum[index]/good_image_count * 100  # get the probability average         \n",
    "        # to show the correct image run predict again and select first image that has same index\n",
    "        for img in image_array:  #iterate through the images    \n",
    "            test_img=np.expand_dims(img, axis=0) # since it is a single image expand dimensions \n",
    "            test_index=np.argmax(model.predict(test_img)) # for this image find the class index with highest probability\n",
    "            if test_index== index: # see if this image has the same index as was selected previously\n",
    "                if verbose: # show image and print result if verbose=1\n",
    "                    plt.axis('off')\n",
    "                    plt.imshow(img) # show the image\n",
    "                    print (f'predicted species is {klass} with a probability of {prob:6.4f} % ')\n",
    "                break # found an image that represents the predicted class      \n",
    "        return klass, prob, img, None\n",
    "    else: # create individual predictions for each image\n",
    "        pred_class=[]\n",
    "        prob_list=[]\n",
    "        for i, p in enumerate(preds):\n",
    "            index=np.argmax(p) # find the class index with the highest probability sum\n",
    "            klass=class_df['class'].iloc[index] # get the class name that corresponds to the index\n",
    "            image_file= file_list[i]\n",
    "            pred_class.append(klass)\n",
    "            prob_list.append(p[index])            \n",
    "        Fseries=pd.Series(file_list, name='image file')\n",
    "        Lseries=pd.Series(pred_class, name= 'species')\n",
    "        Pseries=pd.Series(prob_list, name='probability')\n",
    "        df=pd.concat([Fseries, Lseries, Pseries], axis=1)\n",
    "        if verbose:\n",
    "            length= len(df)\n",
    "            print (df.head(length))\n",
    "        return None, None, None, df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "516211e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def trim (df, max_size, min_size, column):\n",
    "    df=df.copy()\n",
    "    original_class_count= len(list(df[column].unique()))\n",
    "    print ('Original Number of classes in dataframe: ', original_class_count)\n",
    "    sample_list=[] \n",
    "    groups=df.groupby(column)\n",
    "    for label in df[column].unique():        \n",
    "        group=groups.get_group(label)\n",
    "        sample_count=len(group)         \n",
    "        if sample_count> max_size :\n",
    "            strat=group[column]\n",
    "            samples,_=train_test_split(group, train_size=max_size, shuffle=True, random_state=123, stratify=strat)            \n",
    "            sample_list.append(samples)\n",
    "        elif sample_count>= min_size:\n",
    "            sample_list.append(group)\n",
    "    df=pd.concat(sample_list, axis=0).reset_index(drop=True)\n",
    "    final_class_count= len(list(df[column].unique())) \n",
    "    if final_class_count != original_class_count:\n",
    "        print ('*** WARNING***  dataframe has a reduced number of classes' )\n",
    "    balance=list(df[column].value_counts())\n",
    "    print (balance)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0b66bd08",
   "metadata": {},
   "outputs": [],
   "source": [
    "def balance(train_df,max_samples, min_samples, column, working_dir, image_size):\n",
    "    \"\"\"Trim train_df with trim() and top up small classes with augmented images.\n",
    "\n",
    "    After trimming, every class with fewer than max_samples rows receives\n",
    "    max_samples - count augmented images, written as jpg files to\n",
    "    working_dir/aug/LABEL/. Any existing working_dir/aug directory is deleted\n",
    "    first. The paths of the augmented images are appended to the returned frame.\n",
    "\n",
    "    Args:\n",
    "        train_df: dataframe with a 'filepaths' column and a label column.\n",
    "        max_samples: per-class target size (also the trim() upper limit).\n",
    "        min_samples: classes smaller than this are dropped by trim().\n",
    "        column: name of the label column; used for grouping and for the rows\n",
    "            describing the augmented images.\n",
    "        working_dir: directory in which the 'aug' folder is created.\n",
    "        image_size: (height, width) the augmented images are resized to.\n",
    "\n",
    "    Returns:\n",
    "        The trimmed training frame followed by the augmented rows.\n",
    "    \"\"\"\n",
    "    train_df=train_df.copy()\n",
    "    train_df=trim (train_df, max_samples, min_samples, column)\n",
    "    # start from an empty augmentation tree with one sub-folder per class;\n",
    "    # makedirs also creates working_dir itself if it does not exist yet\n",
    "    aug_dir=os.path.join(working_dir, 'aug')\n",
    "    if os.path.isdir(aug_dir):\n",
    "        shutil.rmtree(aug_dir)\n",
    "    os.makedirs(aug_dir)\n",
    "    for label in train_df[column].unique():\n",
    "        os.mkdir(os.path.join(aug_dir,label))\n",
    "    # create and store the augmented images\n",
    "    total=0\n",
    "    gen=ImageDataGenerator(horizontal_flip=True,  rotation_range=20, width_shift_range=.2,\n",
    "                                  height_shift_range=.2, zoom_range=.2)\n",
    "    groups=train_df.groupby(column) # group by class\n",
    "    for label in train_df[column].unique():  # for every class\n",
    "        group=groups.get_group(label)  # a dataframe holding only rows with the specified label\n",
    "        sample_count=len(group)\n",
    "        if sample_count< max_samples: # if the class has less than target number of images\n",
    "            delta=max_samples-sample_count  # number of augmented images to create\n",
    "            target_dir=os.path.join(aug_dir, label)  # define where to write the images\n",
    "            aug_gen=gen.flow_from_dataframe( group,  x_col='filepaths', y_col=None, target_size=image_size,\n",
    "                                            class_mode=None, batch_size=1, shuffle=False,\n",
    "                                            save_to_dir=target_dir, save_prefix='aug-', color_mode='rgb',\n",
    "                                            save_format='jpg')\n",
    "            # with save_to_dir set, every image drawn from aug_gen is also written to target_dir\n",
    "            aug_img_count=0\n",
    "            while aug_img_count<delta:\n",
    "                images=next(aug_gen)\n",
    "                aug_img_count += len(images)\n",
    "            total +=aug_img_count\n",
    "    print('Total Augmented images created= ', total)\n",
    "    # create aug_df and merge with train_df to create the composite training set\n",
    "    if total>0:\n",
    "        aug_fpaths=[]\n",
    "        aug_labels=[]\n",
    "        # sorted() keeps the row order independent of the filesystem's listing order\n",
    "        for klass in sorted(os.listdir(aug_dir)):\n",
    "            classpath=os.path.join(aug_dir, klass)\n",
    "            for f in sorted(os.listdir(classpath)):\n",
    "                aug_fpaths.append(os.path.join(classpath,f))\n",
    "                aug_labels.append(klass)\n",
    "        Fseries=pd.Series(aug_fpaths, name='filepaths')\n",
    "        Lseries=pd.Series(aug_labels, name=column)\n",
    "        aug_df=pd.concat([Fseries, Lseries], axis=1)\n",
    "        train_df=pd.concat([train_df,aug_df], axis=0).reset_index(drop=True)\n",
    "    print (list(train_df[column].value_counts()) )\n",
    "    return train_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8fc60cf6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input image shape is  (450, 450, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x154c002b0a0>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# NOTE(review): hard-coded absolute local path; point it at your copy of the C-NMC_Leukemia dataset\n",
    "img_path=r'C:\\Users\\Vinay Wadhwa\\Downloads\\archive\\C-NMC_Leukemia\\training_data\\fold_0\\all\\UID_11_10_1_all.bmp'\n",
    "img=plt.imread(img_path)\n",
    "print ('Input image shape is ',img.shape)\n",
    "plt.axis('off')\n",
    "# the trailing ';' suppresses the '<matplotlib.image.AxesImage ...>' repr of the last expression\n",
    "imshow(img);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ed426a89",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess (sdir, trsplit, vsplit):\n",
    "    \"\"\"Build stratified train / test / validation dataframes from an image tree.\n",
    "\n",
    "    Every file found at sdir/FOLD/CLASS/FILE becomes one row holding its path in\n",
    "    'filepaths' and the CLASS folder name in 'labels'. The rows are split with\n",
    "    stratification on 'labels': a fraction trsplit goes to training, vsplit (of the\n",
    "    whole frame) to validation and the remainder to test. Entries of sdir or of a\n",
    "    fold that are not directories are ignored.\n",
    "\n",
    "    Args:\n",
    "        sdir: root directory whose sub-directories are the folds.\n",
    "        trsplit: training fraction, 0 < trsplit < 1.\n",
    "        vsplit: validation fraction of the whole data, 0 < vsplit < 1 - trsplit.\n",
    "\n",
    "    Returns:\n",
    "        (train_df, test_df, valid_df)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if trsplit or vsplit lies outside the ranges above.\n",
    "    \"\"\"\n",
    "    # validate first: the second split needs 0 < vsplit/(1-trsplit) < 1\n",
    "    if not 0 < trsplit < 1:\n",
    "        raise ValueError(f'trsplit must lie strictly between 0 and 1, got {trsplit}')\n",
    "    if not 0 < vsplit < 1-trsplit:\n",
    "        raise ValueError(f'vsplit must lie strictly between 0 and 1-trsplit={1-trsplit}, got {vsplit}')\n",
    "    filepaths=[]\n",
    "    labels=[]\n",
    "    # os.listdir returns entries in arbitrary order; sorting makes the seeded split below\n",
    "    # reproducible across machines. Stray files (e.g. a readme or desktop.ini) are skipped\n",
    "    # because calling os.listdir on them would raise.\n",
    "    for fold in sorted(os.listdir(sdir)):\n",
    "        foldpath=os.path.join(sdir,fold)\n",
    "        if not os.path.isdir(foldpath):\n",
    "            continue\n",
    "        for klass in sorted(os.listdir(foldpath)):\n",
    "            classpath=os.path.join(foldpath,klass)\n",
    "            if not os.path.isdir(classpath):\n",
    "                continue\n",
    "            for f in sorted(os.listdir(classpath)):\n",
    "                filepaths.append(os.path.join(classpath,f))\n",
    "                labels.append(klass)\n",
    "    Fseries=pd.Series(filepaths, name='filepaths')\n",
    "    Lseries=pd.Series(labels, name='labels')\n",
    "    df=pd.concat([Fseries, Lseries], axis=1)\n",
    "    dsplit=vsplit/(1-trsplit)  # validation share of the held-out rows\n",
    "    strat=df['labels']\n",
    "    train_df, dummy_df=train_test_split(df, train_size=trsplit, shuffle=True, random_state=123, stratify=strat)\n",
    "    strat=dummy_df['labels']\n",
    "    valid_df, test_df= train_test_split(dummy_df, train_size=dsplit, shuffle=True, random_state=123, stratify=strat)\n",
    "    print('train_df length: ', len(train_df), '  test_df length: ',len(test_df), '  valid_df length: ', len(valid_df))\n",
    "    # check that each dataframe has the same number of classes to prevent model.fit errors\n",
    "    trcount=len(train_df['labels'].unique())\n",
    "    tecount=len(test_df['labels'].unique())\n",
    "    vcount=len(valid_df['labels'].unique())\n",
    "    if trcount < tecount :\n",
    "        msg='** WARNING ** number of classes in training set is less than the number of classes in test set'\n",
    "        print_in_color(msg, (255,0,0), (55,65,80))\n",
    "        msg='This will throw an error in either model.evaluate or model.predict'\n",
    "        print_in_color(msg, (255,0,0), (55,65,80))\n",
    "    if trcount != vcount:\n",
    "        msg='** WARNING ** number of classes in training set not equal to number of classes in validation set' \n",
    "        print_in_color(msg, (255,0,0), (55,65,80))\n",
    "        msg=' this will throw an error in model.fit'\n",
    "        print_in_color(msg, (255,0,0), (55,65,80))\n",
    "        print ('train df class count: ', trcount, 'test df class count: ', tecount, ' valid df class count: ', vcount) \n",
    "        ans=input('Enter C to continue execution or H to halt execution')\n",
    "        if ans.strip().lower() == 'h':\n",
    "            print_in_color('Halting Execution', (255,0,0), (55,65,80))\n",
    "            import sys\n",
    "            sys.exit('program halted by user')\n",
    "    print(list(train_df['labels'].value_counts()))\n",
    "    return train_df, test_df, valid_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beb2abf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# root of the C-NMC training data (fold_0, fold_1, ... each holding one sub-folder per class).\n",
    "# Set the CNMC_TRAIN_DIR environment variable to use another location without editing this cell.\n",
    "sdir=os.environ.get('CNMC_TRAIN_DIR', r'C:\\Users\\Vinay Wadhwa\\Downloads\\archive\\C-NMC_Leukemia\\training_data')\n",
    "trsplit=.9   # fraction of all images used for training\n",
    "vsplit=.05   # fraction of all images used for validation; the rest form the test set\n",
    "train_df, test_df, valid_df= preprocess(sdir,trsplit, vsplit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c3cd360",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cap every class of the training set at max_samples rows;\n",
    "# min_samples=0 means no class is ever dropped for being too small\n",
    "column='labels'\n",
    "max_samples= 3050\n",
    "min_samples=0\n",
    "# settings consumed by later cells (image generators, model saving)\n",
    "working_dir = r'./'\n",
    "img_size=(300,300)\n",
    "train_df=trim(train_df, max_size=max_samples, min_size=min_samples, column=column)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e337f302",
   "metadata": {},
   "outputs": [],
   "source": [
    "# image / batch geometry\n",
    "channels=3\n",
    "batch_size=10\n",
    "img_shape=(img_size[0], img_size[1], channels)\n",
    "# test batch size = largest divisor of the test-set size that is <= 80, so that test_steps\n",
    "# batches visit every test image exactly once; n increases, so the first hit is the largest\n",
    "length=len(test_df)\n",
    "test_batch_size=[int(length/n) for n in range(1, length+1) if length % n == 0 and length/n <= 80][0]\n",
    "test_steps=length // test_batch_size\n",
    "print ( 'test batch size: ' ,test_batch_size, '  test steps: ', test_steps)\n",
    "def scalar(img):\n",
    "    # identity: EfficientNet expects pixels in the range 0 to 255, so no rescaling is applied\n",
    "    return img\n",
    "trgen=ImageDataGenerator(preprocessing_function=scalar, horizontal_flip=True)  # training: random horizontal flips\n",
    "tvgen=ImageDataGenerator(preprocessing_function=scalar)  # test / validation: no augmentation\n",
    "# each msg is space-padded and followed by '\\r' so that, on terminal-like output, the\n",
    "# 'Found ...' line Keras prints next overwrites the blank prefix and keeps the suffix\n",
    "msg='                                                              for the train generator'\n",
    "print(msg, '\\r', end='')\n",
    "train_gen=trgen.flow_from_dataframe(train_df, x_col='filepaths', y_col='labels',\n",
    "                                    target_size=img_size, class_mode='categorical',\n",
    "                                    color_mode='rgb', shuffle=True, batch_size=batch_size)\n",
    "msg='                                                              for the test generator'\n",
    "print(msg, '\\r', end='')\n",
    "# shuffle=False keeps test predictions aligned with test_gen.labels\n",
    "test_gen=tvgen.flow_from_dataframe(test_df, x_col='filepaths', y_col='labels',\n",
    "                                   target_size=img_size, class_mode='categorical',\n",
    "                                   color_mode='rgb', shuffle=False, batch_size=test_batch_size)\n",
    "msg='                                                             for the validation generator'\n",
    "print(msg, '\\r', end='')\n",
    "valid_gen=tvgen.flow_from_dataframe(valid_df, x_col='filepaths', y_col='labels',\n",
    "                                    target_size=img_size, class_mode='categorical',\n",
    "                                    color_mode='rgb', shuffle=True, batch_size=batch_size)\n",
    "classes=list(train_gen.class_indices)  # class names (keys of class_indices)\n",
    "class_count=len(classes)\n",
    "train_steps=int(np.ceil(len(train_gen.labels)/batch_size))  # batches per training epoch\n",
    "labels=test_gen.labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45192270",
   "metadata": {},
   "outputs": [],
   "source": [
    "# preview one batch drawn from the (horizontally flipped) training generator\n",
    "show_image_samples(gen=train_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be76b54c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name='EfficientNetB3'\n",
    "# ImageNet-pretrained EfficientNetB3 backbone without its classifier head; pooling='max'\n",
    "# collapses the final feature map into one feature vector per image\n",
    "base_model=tf.keras.applications.efficientnet.EfficientNetB3(include_top=False, weights=\"imagenet\",input_shape=img_shape, pooling='max') \n",
    "x=base_model.output\n",
    "x=BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001 )(x)\n",
    "# the l2 factor is passed positionally: the 'l=' keyword is only a legacy alias in tf.keras 2.x\n",
    "# and may be rejected by newer Keras releases\n",
    "x = Dense(32, kernel_regularizer = regularizers.l2(0.016),activity_regularizer=regularizers.l1(0.006),\n",
    "                bias_regularizer=regularizers.l1(0.006) ,activation='relu')(x)\n",
    "x=Dropout(rate=.45, seed=123)(x)\n",
    "# one softmax unit per class found by train_gen\n",
    "output=Dense(class_count, activation='softmax')(x)\n",
    "model=Model(inputs=base_model.input, outputs=output)\n",
    "model.compile(Adamax(learning_rate=.001), loss='categorical_crossentropy', metrics=['accuracy']) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a18c5d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# settings handed to the LRA callback (LRA is defined elsewhere in this notebook)\n",
    "epochs =20\n",
    "ask_epoch=5 # number of epochs to run before asking if you want to halt training\n",
    "patience= 1 # number of epochs to wait to adjust lr if monitored value does not improve\n",
    "factor=.5 # factor to reduce lr by\n",
    "threshold=.9 # if train accuracy is < threshold monitor accuracy, else monitor validation loss\n",
    "stop_patience =3 # number of epochs to wait before stopping training if monitored value does not improve\n",
    "dwell=True # experimental, if True and monitored metric does not improve on current epoch set model weights back to weights of previous epoch\n",
    "freeze=False # if True freeze the weights of the base model\n",
    "# NOTE(review): freeze is not passed to LRA below; confirm whether anything else reads it\n",
    "batches=train_steps\n",
    "callbacks=[LRA(model=model, base_model=base_model, epochs=epochs, initial_epoch=0,\n",
    "               batches=batches, ask_epoch=ask_epoch, patience=patience, factor=factor,\n",
    "               threshold=threshold, stop_patience=stop_patience, dwell=dwell)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af65bfd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# verbose=0 silences Keras' own progress bar; shuffle=False leaves batch ordering to\n",
    "# train_gen, which was built with shuffle=True\n",
    "history=model.fit(x=train_gen,\n",
    "                  validation_data=valid_gen,\n",
    "                  validation_steps=None,\n",
    "                  epochs=epochs,\n",
    "                  initial_epoch=0,\n",
    "                  callbacks=callbacks,\n",
    "                  shuffle=False,\n",
    "                  verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf5aff88",
   "metadata": {},
   "outputs": [],
   "source": [
    "subject='leukemia'\n",
    "print_code=0\n",
    "# test_gen was built with shuffle=False, so the rows of preds follow the order of test_gen.labels\n",
    "preds=model.predict(test_gen)\n",
    "acc=print_info(test_gen, preds, print_code, working_dir, subject)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09616c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# saver is defined elsewhere in this notebook; judging by the returned names it presumably\n",
    "# writes the model and a csv file. NOTE(review): confirm what the literal 1 means in saver's signature\n",
    "model_save_loc, csv_save_loc=saver(working_dir, model, model_name, subject, acc,\n",
    "                                   img_size, 1, train_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4926bf84",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c14e1862",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}