[541a77]: / ModelNotebook.ipynb

Download this file

321 lines (320 with data), 11.6 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "857a5f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing the necessary libraries\n",
    "\n",
    "import tensorflow as tf\n",
    "import albumentations as albu\n",
    "import numpy as np\n",
    "import gc\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "from keras.callbacks import CSVLogger\n",
    "from datetime import datetime\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import jaccard_score, precision_score, recall_score, accuracy_score, f1_score\n",
    "from ModelArchitecture.DiceLoss import dice_metric_loss\n",
    "from ModelArchitecture import DUCK_Net\n",
    "from ImageLoader import ImageLoader2D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7852dbbf-f3a4-4db6-8c5e-98889bf3abd4",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Checking the number of GPUs available\n",
    "\n",
    "print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9edbe667",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting the model parameters\n",
    "\n",
    "img_size = 352\n",
    "dataset_type = 'kvasir' # Options: kvasir/cvc-clinicdb/cvc-colondb/etis-laribpolypdb\n",
    "learning_rate = 1e-4\n",
    "seed_value = 58800\n",
    "filters = 17 # Number of filters, the paper presents the results with 17 and 34\n",
    "optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)\n",
    "\n",
    "ct = datetime.now()\n",
    "\n",
    "model_type = \"DuckNet\"\n",
    "\n",
    "progress_path = 'ProgressFull/' + dataset_type + '_progress_csv_' + model_type + '_filters_' + str(filters) +  '_' + str(ct) + '.csv'\n",
    "progressfull_path = 'ProgressFull/' + dataset_type + '_progress_' + model_type + '_filters_' + str(filters) + '_' + str(ct) + '.txt'\n",
    "plot_path = 'ProgressFull/' + dataset_type + '_progress_plot_' + model_type + '_filters_' + str(filters) + '_' + str(ct) + '.png'\n",
    "model_path = 'ModelSaveTensorFlow/' + dataset_type + '/' + model_type + '_filters_' + str(filters) + '_' + str(ct)\n",
    "\n",
    "EPOCHS = 600\n",
    "min_loss_for_saving = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "264050d8",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Loading the data\n",
    "\n",
    "X, Y = ImageLoader2D.load_data(img_size, img_size, -1, 'kvasir')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "75b6c029",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Splitting the data, seed for reproducibility\n",
    "\n",
    "x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.1, shuffle= True, random_state = seed_value)\n",
    "x_train, x_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.111, shuffle= True, random_state = seed_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "92a7b7d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defining the augmentations\n",
    "\n",
    "aug_train = albu.Compose([\n",
    "    albu.HorizontalFlip(),\n",
    "    albu.VerticalFlip(),\n",
    "    albu.ColorJitter(brightness=(0.6,1.6), contrast=0.2, saturation=0.1, hue=0.01, always_apply=True),\n",
    "    albu.Affine(scale=(0.5,1.5), translate_percent=(-0.125,0.125), rotate=(-180,180), shear=(-22.5,22), always_apply=True),\n",
    "])\n",
    "\n",
    "def augment_images():\n",
    "    x_train_out = []\n",
    "    y_train_out = []\n",
    "\n",
    "    for i in range (len(x_train)):\n",
    "        ug = aug_train(image=x_train[i], mask=y_train[i])\n",
    "        x_train_out.append(ug['image'])  \n",
    "        y_train_out.append(ug['mask'])\n",
    "\n",
    "    return np.array(x_train_out), np.array(y_train_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1609dd32",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Creating the model\n",
    "\n",
    "model = DUCK_Net.create_model(img_height=img_size, img_width=img_size, input_chanels=3, out_classes=1, starting_filters=filters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2e513d42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compiling the model\n",
    "\n",
    "model.compile(optimizer=optimizer, loss=dice_metric_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ef712b9",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Training the model\n",
    "\n",
    "step = 0\n",
    "\n",
    "for epoch in range(0, EPOCHS):\n",
    "    \n",
    "    print(f'Training, epoch {epoch}')\n",
    "    print('Learning Rate: ' + str(learning_rate))\n",
    "\n",
    "    step += 1\n",
    "        \n",
    "    image_augmented, mask_augmented = augment_images()\n",
    "    \n",
    "    csv_logger = CSVLogger(progress_path, append=True, separator=';')\n",
    "    \n",
    "    model.fit(x=image_augmented, y=mask_augmented, epochs=1, batch_size=4, validation_data=(x_valid, y_valid), verbose=1, callbacks=[csv_logger])\n",
    "    \n",
    "    prediction_valid = model.predict(x_valid, verbose=0)\n",
    "    loss_valid = dice_metric_loss(y_valid, prediction_valid)\n",
    "    \n",
    "    loss_valid = loss_valid.numpy()\n",
    "    print(\"Loss Validation: \" + str(loss_valid))\n",
    "        \n",
    "    prediction_test = model.predict(x_test, verbose=0)\n",
    "    loss_test = dice_metric_loss(y_test, prediction_test)\n",
    "    loss_test = loss_test.numpy()\n",
    "    print(\"Loss Test: \" + str(loss_test))\n",
    "        \n",
    "    with open(progressfull_path, 'a') as f:\n",
    "        f.write('epoch: ' + str(epoch) + '\\nval_loss: ' + str(loss_valid) + '\\ntest_loss: ' + str(loss_test) + '\\n\\n\\n')\n",
    "    \n",
    "    if min_loss_for_saving > loss_valid:\n",
    "        min_loss_for_saving = loss_valid\n",
    "        print(\"Saved model with val_loss: \", loss_valid)\n",
    "        model.save(model_path)\n",
    "        \n",
    "    del image_augmented\n",
    "    del mask_augmented\n",
    "\n",
    "    gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd52447",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Computing the metrics and saving the results\n",
    "\n",
    "print(\"Loading the model\")\n",
    "\n",
    "model = tf.keras.models.load_model(model_path, custom_objects={'dice_metric_loss':dice_metric_loss})\n",
    "\n",
    "prediction_train = model.predict(x_train, batch_size=4)\n",
    "prediction_valid = model.predict(x_valid, batch_size=4)\n",
    "prediction_test = model.predict(x_test, batch_size=4)\n",
    "\n",
    "print(\"Predictions done\")\n",
    "\n",
    "dice_train = f1_score(np.ndarray.flatten(np.array(y_train, dtype=bool)),\n",
    "                           np.ndarray.flatten(prediction_train > 0.5))\n",
    "dice_test = f1_score(np.ndarray.flatten(np.array(y_test, dtype=bool)),\n",
    "                          np.ndarray.flatten(prediction_test > 0.5))\n",
    "dice_valid = f1_score(np.ndarray.flatten(np.array(y_valid, dtype=bool)),\n",
    "                           np.ndarray.flatten(prediction_valid > 0.5))\n",
    "\n",
    "print(\"Dice finished\")\n",
    "\n",
    "\n",
    "miou_train = jaccard_score(np.ndarray.flatten(np.array(y_train, dtype=bool)),\n",
    "                           np.ndarray.flatten(prediction_train > 0.5))\n",
    "miou_test = jaccard_score(np.ndarray.flatten(np.array(y_test, dtype=bool)),\n",
    "                          np.ndarray.flatten(prediction_test > 0.5))\n",
    "miou_valid = jaccard_score(np.ndarray.flatten(np.array(y_valid, dtype=bool)),\n",
    "                           np.ndarray.flatten(prediction_valid > 0.5))\n",
    "\n",
    "print(\"Miou finished\")\n",
    "\n",
    "\n",
    "precision_train = precision_score(np.ndarray.flatten(np.array(y_train, dtype=bool)),\n",
    "                                  np.ndarray.flatten(prediction_train > 0.5))\n",
    "precision_test = precision_score(np.ndarray.flatten(np.array(y_test, dtype=bool)),\n",
    "                                 np.ndarray.flatten(prediction_test > 0.5))\n",
    "precision_valid = precision_score(np.ndarray.flatten(np.array(y_valid, dtype=bool)),\n",
    "                                  np.ndarray.flatten(prediction_valid > 0.5))\n",
    "\n",
    "print(\"Precision finished\")\n",
    "\n",
    "\n",
    "recall_train = recall_score(np.ndarray.flatten(np.array(y_train, dtype=bool)),\n",
    "                            np.ndarray.flatten(prediction_train > 0.5))\n",
    "recall_test = recall_score(np.ndarray.flatten(np.array(y_test, dtype=bool)),\n",
    "                           np.ndarray.flatten(prediction_test > 0.5))\n",
    "recall_valid = recall_score(np.ndarray.flatten(np.array(y_valid, dtype=bool)),\n",
    "                            np.ndarray.flatten(prediction_valid > 0.5))\n",
    "\n",
    "print(\"Recall finished\")\n",
    "\n",
    "\n",
    "accuracy_train = accuracy_score(np.ndarray.flatten(np.array(y_train, dtype=bool)),\n",
    "                                np.ndarray.flatten(prediction_train > 0.5))\n",
    "accuracy_test = accuracy_score(np.ndarray.flatten(np.array(y_test, dtype=bool)),\n",
    "                               np.ndarray.flatten(prediction_test > 0.5))\n",
    "accuracy_valid = accuracy_score(np.ndarray.flatten(np.array(y_valid, dtype=bool)),\n",
    "                                np.ndarray.flatten(prediction_valid > 0.5))\n",
    "\n",
    "\n",
    "print(\"Accuracy finished\")\n",
    "\n",
    "\n",
    "final_file = 'results_' + model_type + '_' + str(filters) + '_' + dataset_type + '.txt'\n",
    "print(final_file)\n",
    "\n",
    "with open(final_file, 'a') as f:\n",
    "    f.write(dataset_type + '\\n\\n')\n",
    "    f.write('dice_train: ' + str(dice_train) + ' dice_valid: ' + str(dice_valid) + ' dice_test: ' + str(dice_test) + '\\n\\n')\n",
    "    f.write('miou_train: ' + str(miou_train) + ' miou_valid: ' + str(miou_valid) + ' miou_test: ' + str(miou_test) + '\\n\\n')\n",
    "    f.write('precision_train: ' + str(precision_train) + ' precision_valid: ' + str(precision_valid) + ' precision_test: ' + str(precision_test) + '\\n\\n')\n",
    "    f.write('recall_train: ' + str(recall_train) + ' recall_valid: ' + str(recall_valid) + ' recall_test: ' + str(recall_test) + '\\n\\n')\n",
    "    f.write('accuracy_train: ' + str(accuracy_train) + ' accuracy_valid: ' + str(accuracy_valid) + ' accuracy_test: ' + str(accuracy_test) + '\\n\\n\\n\\n')\n",
    "\n",
    "print('File done')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}