Download this file

1695 lines (1694 with data), 413.2 kB

{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment and Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Libraries\n",
    "import cv2\n",
    "import os\n",
    "import glob\n",
    "import warnings\n",
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "import matplotlib.pyplot as plt\n",
    "# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'\n",
    "from random import randint\n",
    "from tqdm import tqdm\n",
    "\n",
    "# import keras.api._v2.keras as keras\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "# from tensorflow.keras import layers\n",
    "from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, BatchNormalization, Activation, Conv2DTranspose, concatenate, Dropout\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "from sklearn.model_selection import train_test_split\n",
    "from keras.models import Model, load_model\n",
    "from keras.optimizers import Adam\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "from scipy.spatial.distance import directed_hausdorff\n",
    "\n",
    "# Set USE_COLAB = True when running on Google Colab: it mounts Google Drive, switches the\n",
    "# working directory to the Colab notebooks folder and fails fast if no GPU runtime is attached.\n",
    "USE_COLAB = False\n",
    "\n",
    "if USE_COLAB:\n",
    "    # Google Drive\n",
    "    from google.colab import drive\n",
    "    drive.mount('/content/drive')\n",
    "\n",
    "    # Working directory for data\n",
    "    os.chdir('/content/drive/MyDrive/Colab Notebooks')\n",
    "\n",
    "    # GPU\n",
    "    device_name = tf.test.gpu_device_name()\n",
    "    if device_name != '/device:GPU:0':\n",
    "        raise SystemError('GPU device not found')\n",
    "    print('Found GPU at: {}'.format(device_name))\n",
    "\n",
    "    # TPU\n",
    "    # resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n",
    "    # tf.config.experimental_connect_to_cluster(resolver)\n",
    "    # # This is the TPU initialization code that has to be at the beginning.\n",
    "    # tf.tpu.experimental.initialize_tpu_system(resolver)\n",
    "    # print(\"All devices: \", tf.config.list_logical_devices('TPU'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(581, 512, 512, 1)\n",
      "Not Augmented:  [0. 1. 3.]\n",
      "Augmented:  [0. 1. 2. 3.]\n"
     ]
    }
   ],
   "source": [
    "import nibabel as nib\n",
    "import numpy as np\n",
    "\n",
    "def find_unique_numbers(arr):\n",
    "    \"\"\"Return the sorted, de-duplicated values of ``arr`` (e.g. the label values present in a mask slice).\"\"\"\n",
    "    return np.unique(arr)\n",
    "\n",
    "\n",
    "# Compare the label values present in one slice of an original multiclass mask and of its augmented copy.\n",
    "multiclass_mask_dir = \"D:/MRI - Tairawhiti (User POV)/nnUNet Data/multiclass_masks\"\n",
    "slice_index = 450\n",
    "\n",
    "nii_img_scan = nib.load(f\"{multiclass_mask_dir}/msk_004.nii.gz\")\n",
    "nii_img_scan_aug = nib.load(f\"{multiclass_mask_dir}/msk_004_aug0.nii.gz\")\n",
    "\n",
    "scan_data = nii_img_scan.get_fdata()\n",
    "scan_aug_data = nii_img_scan_aug.get_fdata()\n",
    "print(scan_aug_data.shape)\n",
    "\n",
    "# np.savetxt('scan.txt', scan_data[450,:,:,0], fmt=\"%d\", delimiter=\",\")\n",
    "# np.savetxt('scan_aug.txt', scan_aug_data[450,:,:,0], fmt=\"%d\", delimiter=\",\")\n",
    "\n",
    "scan_data = scan_data[slice_index, :, :, 0]\n",
    "scan_aug_data = scan_aug_data[slice_index, :, :, 0]\n",
    "\n",
    "for label, slice_2d in (('Not Augmented: ', scan_data), ('Augmented: ', scan_aug_data)):\n",
    "    unique_numbers = find_unique_numbers(slice_2d)\n",
    "    print(label, unique_numbers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['1_RR_fibula_15A.nii', '1_R_femur_15A.nii', '1_R_tibia_15A.nii']\n"
     ]
    }
   ],
   "source": [
    "# paitent_id = segmasks_fnames[0]\n",
    "import os \n",
    "# Collect the 3D Slicer mask files whose trailing ID token (e.g. '15A') matches the scan folder's.\n",
    "scan_data_folders = ['1_AutoBind_WaterWATER_450_15A']\n",
    "paitent_id = scan_data_folders[0].rsplit('_', 1)[-1]\n",
    "scans_path = 'D:/MRI - Tairawhiti (User POV)'\n",
    "raw_mask_dir = '{}/Raw NIFITI Segmentation Masks (3D Slicer Output)'.format(scans_path)\n",
    "files = os.listdir(raw_mask_dir)\n",
    "segmasks = [f for f in files if f.split('_')[-1].split('.')[0] == paitent_id]\n",
    "print(segmasks)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "Loading and preprocessing training data...\n",
      "------------------------------\n",
      "Patient Scan Data Folders Included in Run:  ['6_AutoBindWATER_650_9B']\n",
      "\n",
      "\n",
      "Segmentation Mask:  6_RR_fibula_9B\n",
      "Patient Scan Data:  6_AutoBindWATER_650_9B\n",
      "AOI Slice Start:  745\n",
      "AOI Slice End:  958\n",
      "AOI Slice Range:  214\n",
      "\n",
      "\n",
      "Scans Normalized! [0-1]\n",
      "Masks Binarised! [0,1]\n",
      "\n",
      "\n",
      "Training Scans Input Shape:  (1015, 512, 512, 1)\n",
      "Training Masks Input Shape:  (1015, 512, 512, 1)\n",
      "Conversion to NIfTI complete with shape:  (1015, 512, 512, 1)\n",
      "\n",
      "\n",
      "Successfully Exported Preprocessed Data!\n",
      "Regions of Interest for Segmentation:  ['FEMUR', 'FIBULA', 'TIBIA']\n",
      "Number of Segmentation Classes:  3\n",
      "(467, 289, 317, 0)\n",
      "(958, 331, 363, 0)\n",
      "AOI Slice Start (Final with Thresholding):  447\n",
      "AOI Slice End (Final with Thresholding):  978\n",
      "Final Training Scans Input Shape:  (531, 512, 512, 1)\n",
      "Final Training Masks Input Shape:  (531, 512, 512, 1)\n",
      "\n",
      "\n",
      "MSK Multiclass Mask Made Using ['FEMUR', 'FIBULA', 'TIBIA'], Saved To D:/MRI - Tairawhiti/nnUNet Data/multiclass_masks/msk_006.nii.gz, and Region of Interest Slice Thresholding = True !\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\GGPC\\OneDrive\\Desktop\\Part 4 Project\\Part4Project\\preprocessing.py:78: RuntimeWarning: invalid value encountered in true_divide\n",
      "  image2 = image2 / np.max(image2)\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "Completed Preprocessing Stage!\n",
      "------------------------------\n"
     ]
    }
   ],
   "source": [
    "from preprocessing import preprocessing, VisualValidationMSK\n",
    "from image_preprocessing import image_cropping, uniform_cropping, uniform_resizing\n",
    "from writeout_dataset import Export2CompressedNifiti\n",
    "from MSKMulticlass import CreateMasks4MulticlassMSK\n",
    "\n",
    "print('-'*30)\n",
    "print('Loading and preprocessing training data...')\n",
    "print('-'*30)\n",
    "\n",
    "# \\\\files.auckland.ac.nz\\research\\resmed202100086-tws ----> Address for raw data\n",
    "# \\\\files.auckland.ac.nz\\research\\resabi202200010-Friedlander\n",
    "total_slices_raw_data = 200\n",
    "DataOnlyAOI = False\n",
    "ExportDatasets = True\n",
    "Cropping = False\n",
    "# Only set multiclassSegmentation to True once all the AOI masks are preprocessed or its onto the final AOI\n",
    "multiclassSegmentation = True\n",
    "scans_path = 'D:/MRI - Tairawhiti'\n",
    "\n",
    "# segmasks_fnames = ['1_R_tibia_15A', '2_R_tibia_16A', '3_R_tibia_4A', '4_R_tibia_5A', '5_R_tibia_8A', '6_R_tibia_9B', '7_R_tibia_10B', '8_R_tibia_12A', '9_R_tibia_11A', '10_R_tibia_13A']\n",
    "# segmasks_fnames = ['1_R_femur_15A', '2_R_femur_16A', '3_R_femur_4A', '4_R_femur_5A', '5_R_femur_8A', '6_R_femur_9B']\n",
    "# segmasks_fnames = ['1_RR_fibula_15A', '2_RR_fibula_16A', '3_RR_fibula_4A', '4_RR_fibula_5A', '5_RR_fibula_8A', '6_RR_fibula_9B']\n",
    "# scan_data_folders = ['1_AutoBind_WaterWATER_450_15A', '2_AutoBindWATER_450_16A', '3_AutoBindWATER_650_4A', '4_AutoBindWATER_750_5A', '5_AutoBindWATER_1050_8A', '6_AutoBindWATER_650_9B', '7_AutoBindWATER_450_10B',\n",
    "#                     '8_AutoBindWATER_450_12A', '9_AutoBindWATER_550_11A', '10_AutoBindWATER_450_13A']\n",
    "\n",
    "segmasks_fnames = ['6_RR_fibula_9B']\n",
    "scan_data_folders = ['6_AutoBindWATER_650_9B']\n",
    "colab_fname = ['FIBULA_006']\n",
    "# Patient number parsed from the NIfTI name ('FIBULA_006' -> 6). The visual validation below\n",
    "# needs it even when multiclassSegmentation is False, so it is defined unconditionally.\n",
    "mask_index = int(colab_fname[0].split('_')[1])\n",
    "# Stays None unless the multiclass masks are created below (avoids a NameError / stale kernel value).\n",
    "multiclass_mask_output_dir = None\n",
    "\n",
    "imgs_train, imgs_mask_train, median_aoi_index = preprocessing(scans_path, segmasks_fnames, scan_data_folders, total_slices_raw_data, DataOnlyAOI)\n",
    "\n",
    "if Cropping:\n",
    "    crop_margins = dict(top = 56, bottom = 56, left = 56, right = 56)\n",
    "    imgs_train = image_cropping(imgs_train, **crop_margins)\n",
    "    imgs_mask_train = image_cropping(imgs_mask_train, **crop_margins)\n",
    "    print(\"\\n\")\n",
    "    print('Preprocessed Final Training Image Input Shape (After Image Preprocessing): ', imgs_train.shape)\n",
    "    print('Preprocessed Final Training Mask Input Shape (After Image Preprocessing): ', imgs_mask_train.shape)\n",
    "\n",
    "if ExportDatasets:\n",
    "    Export2CompressedNifiti(imgs_train, scans_path, colab_fname, imgs_mask_train)\n",
    "\n",
    "if multiclassSegmentation:\n",
    "    # User Input\n",
    "    individual_mask_directory = 'D:/MRI - Tairawhiti/nnUNet Data/masks'\n",
    "    multiclass_mask_output_dir = 'D:/MRI - Tairawhiti/nnUNet Data/multiclass_masks'\n",
    "    scan_dir = 'D:/MRI - Tairawhiti/nnUNet Data/scans'\n",
    "    \n",
    "    TIBIA_encoding = 1\n",
    "    FEMUR_encoding = 2\n",
    "    FIBULA_encoding = 3\n",
    "    AOIThresholding = True\n",
    "\n",
    "    # NOTE(review): the Friedlander cell further down calls CreateMasks4MulticlassMSK with\n",
    "    # PELVIS_encoding and FriedLanderDataset as extra arguments; confirm this 8-argument\n",
    "    # call still matches the current signature in MSKMulticlass.py.\n",
    "    CreateMasks4MulticlassMSK(scan_dir, individual_mask_directory, mask_index, TIBIA_encoding, FEMUR_encoding, FIBULA_encoding, multiclass_mask_output_dir, AOIThresholding)\n",
    "\n",
    "\n",
    "# Validate Scans & Masks (2D Slice)\n",
    "if multiclass_mask_output_dir is not None:\n",
    "    VisualValidationMSK(colab_fname, 400, 400, multiclass_mask_output_dir, mask_index)\n",
    "else:\n",
    "    VisualValidationMSK(colab_fname, 400, 400, False, mask_index)\n",
    "\n",
    "print('-'*30)\n",
    "print('Completed Preprocessing Stage!')\n",
    "print('-'*30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Multi-Class Segmentation Task Data Preparation!\n",
      "Regions of Interest for Segmentation:  ['FEMUR', 'FIBULA', 'PELVIS', 'TIBIA']\n",
      "Number of Segmentation Classes:  4\n",
      "(21, 160, 44, 0)\n",
      "(854, 104, 177, 0)\n",
      "AOI Slice Start (Final with Thresholding):  21\n",
      "AOI Slice End (Final with Thresholding):  854\n",
      "Final Training Scans Input Shape:  (833, 277, 230, 1)\n",
      "Final Training Masks Input Shape:  (833, 277, 230, 1)\n",
      "\n",
      "\n",
      "MSK Multiclass Mask Made Using ['FEMUR', 'FIBULA', 'PELVIS', 'TIBIA'], Saved To D:/MRI - Friedlander/nnUNet Data/multiclass_masks/msk_001.nii.gz, and Region of Interest Slice Thresholding = True !\n",
      "Regions of Interest for Segmentation:  ['FEMUR', 'FIBULA', 'PELVIS', 'TIBIA']\n",
      "Number of Segmentation Classes:  4\n",
      "(30, 154, 51, 0)\n",
      "(796, 115, 174, 0)\n",
      "AOI Slice Start (Final with Thresholding):  30\n",
      "AOI Slice End (Final with Thresholding):  796\n",
      "Final Training Scans Input Shape:  (766, 260, 232, 1)\n",
      "Final Training Masks Input Shape:  (766, 260, 232, 1)\n",
      "\n",
      "\n",
      "MSK Multiclass Mask Made Using ['FEMUR', 'FIBULA', 'PELVIS', 'TIBIA'], Saved To D:/MRI - Friedlander/nnUNet Data/multiclass_masks/msk_002.nii.gz, and Region of Interest Slice Thresholding = True !\n",
      "Regions of Interest for Segmentation:  ['FEMUR', 'FIBULA', 'PELVIS', 'TIBIA']\n",
      "Number of Segmentation Classes:  4\n",
      "(11, 109, 22, 0)\n",
      "(915, 155, 81, 0)\n",
      "AOI Slice Start (Final with Thresholding):  11\n",
      "AOI Slice End (Final with Thresholding):  915\n",
      "Final Training Scans Input Shape:  (904, 216, 120, 1)\n",
      "Final Training Masks Input Shape:  (904, 216, 120, 1)\n",
      "\n",
      "\n",
      "MSK Multiclass Mask Made Using ['FEMUR', 'FIBULA', 'PELVIS', 'TIBIA'], Saved To D:/MRI - Friedlander/nnUNet Data/multiclass_masks/msk_003.nii.gz, and Region of Interest Slice Thresholding = True !\n",
      "Regions of Interest for Segmentation:  ['FEMUR', 'FIBULA', 'PELVIS', 'TIBIA']\n",
      "Number of Segmentation Classes:  4\n",
      "(42, 121, 67, 0)\n",
      "(689, 130, 78, 0)\n",
      "AOI Slice Start (Final with Thresholding):  42\n",
      "AOI Slice End (Final with Thresholding):  689\n",
      "Final Training Scans Input Shape:  (647, 217, 118, 1)\n",
      "Final Training Masks Input Shape:  (647, 217, 118, 1)\n",
      "\n",
      "\n",
      "MSK Multiclass Mask Made Using ['FEMUR', 'FIBULA', 'PELVIS', 'TIBIA'], Saved To D:/MRI - Friedlander/nnUNet Data/multiclass_masks/msk_004.nii.gz, and Region of Interest Slice Thresholding = True !\n",
      "Regions of Interest for Segmentation:  ['FEMUR', 'FIBULA', 'PELVIS', 'TIBIA']\n",
      "Number of Segmentation Classes:  4\n",
      "(124, 120, 40, 0)\n",
      "(922, 161, 88, 0)\n",
      "AOI Slice Start (Final with Thresholding):  124\n",
      "AOI Slice End (Final with Thresholding):  922\n",
      "Final Training Scans Input Shape:  (798, 218, 120, 1)\n",
      "Final Training Masks Input Shape:  (798, 218, 120, 1)\n",
      "\n",
      "\n",
      "MSK Multiclass Mask Made Using ['FEMUR', 'FIBULA', 'PELVIS', 'TIBIA'], Saved To D:/MRI - Friedlander/nnUNet Data/multiclass_masks/msk_005.nii.gz, and Region of Interest Slice Thresholding = True !\n"
     ]
    },
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from skimage.transform import resize\n",
    "from MSKMulticlass import CreateMasks4MulticlassMSK\n",
    "\n",
    "def superimpose_images(image1, image2, alpha=0.5):\n",
    "    \"\"\"\n",
    "    Blend two images after scaling each one so that its maximum is 1.\n",
    "\n",
    "    Parameters:\n",
    "        image1, image2 (numpy.ndarray): Images of the same (or broadcastable) shape.\n",
    "        alpha (float): Weight given to image1; image2 gets (1 - alpha). Defaults to 0.5.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray: alpha * image1 / max(image1) + (1 - alpha) * image2 / max(image2).\n",
    "        An image whose maximum is 0 (e.g. an empty mask slice) is used unscaled instead of\n",
    "        being divided by zero, which would fill the blend with NaN/inf values.\n",
    "    \"\"\"\n",
    "    def _scale_to_unit_max(image):\n",
    "        # Scale by the maximum, skipping images whose maximum is 0 (0 / 0 -> NaN + RuntimeWarning).\n",
    "        image = np.asarray(image)\n",
    "        peak = np.max(image)\n",
    "        return image / peak if peak != 0 else image\n",
    "\n",
    "    return alpha * _scale_to_unit_max(image1) + (1 - alpha) * _scale_to_unit_max(image2)\n",
    "\n",
    "def uniform_resizing(images, new_size, order=1):\n",
    "    \"\"\"\n",
    "    Applies uniform resizing to every 2-D slice of a stack of single-channel images.\n",
    "\n",
    "    Each slice is resized with skimage's ``resize(..., mode='reflect', order=order)``.\n",
    "    The default ``order=1`` is bilinear interpolation (``order=3`` would be bicubic).\n",
    "\n",
    "    Parameters:\n",
    "        images (numpy.ndarray): Input array of images with shape (x, height, width, 1),\n",
    "                                e.g. (x, 512, 512, 1).\n",
    "        new_size (int or pair of ints): The desired size of the resized images. If it's an\n",
    "                                        int (Python or NumPy), the new size will be\n",
    "                                        (new_size, new_size); a tuple or list gives\n",
    "                                        (new_height, new_width).\n",
    "        order (int): Spline interpolation order passed to ``resize``. Defaults to 1.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray: float64 array of resized images with shape (x, new_height, new_width, 1).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``images`` is not 4-D or ``new_size`` is neither an int nor a pair of ints.\n",
    "    \"\"\"\n",
    "    if np.ndim(images) != 4:\n",
    "        raise ValueError(\"Invalid images argument. Expected an array of shape (x, height, width, 1).\")\n",
    "    n_images = images.shape[0]\n",
    "\n",
    "    # np.integer accepts sizes read out of NumPy arrays (e.g. np.int64), not only Python ints.\n",
    "    if isinstance(new_size, (int, np.integer)):\n",
    "        new_height, new_width = int(new_size), int(new_size)\n",
    "    elif isinstance(new_size, (tuple, list)) and len(new_size) == 2:\n",
    "        new_height, new_width = new_size\n",
    "    else:\n",
    "        raise ValueError(\"Invalid new_size argument. It should be an int or a tuple of two ints.\")\n",
    "\n",
    "    resized_images = np.zeros((n_images, new_height, new_width, 1))\n",
    "\n",
    "    for i in range(n_images):\n",
    "        resized_images[i, ..., 0] = resize(images[i, ..., 0], (new_height, new_width), mode='reflect', order=order)\n",
    "        # resized_images[i, ..., 0] = resize(images[i, ..., 0], (new_height, new_width), mode='constant', preserve_range=True)\n",
    "\n",
    "    return resized_images\n",
    "\n",
    "\n",
    "# Root folder of the Friedlander dataset; the nnUNet scans/masks/multiclass_masks folders below live under it.\n",
    "scans_path = \"D:/MRI - Friedlander\"\n",
    "\n",
    "# slice = 750\n",
    "# aoi = 'FEMUR'\n",
    "\n",
    "# tibia_PLB_02 = nib.load(('{}/{}/{}/{}_001.nii.gz').format(scans_path, 'nnUNet Data/masks', aoi, aoi))\n",
    "# tibia_PLB_02_data = tibia_PLB_02.get_fdata()\n",
    "# print(tibia_PLB_02_data.shape)\n",
    "\n",
    "# PLB_02 = nib.load(('{}/{}/msk_001.nii.gz').format(scans_path, 'nnUNet Data/scans'))\n",
    "# PLB_02_data = PLB_02.get_fdata()\n",
    "# # nii_img_mask = nib.Nifti1Image(PLB_02_data, affine=np.eye(4))\n",
    "# # output_file_path = ('D:/MRI - Tairawhiti/PLB-02.nii.gz')\n",
    "# # nib.save(nii_img_mask, output_file_path)   \n",
    "# print(PLB_02_data.shape)\n",
    "\n",
    "# image1 = PLB_02_data[:,slice,:]\n",
    "# image2 = tibia_PLB_02_data[:,slice,:]\n",
    "# superimposed_image = superimpose_images(image1, image2)\n",
    "\n",
    "\n",
    "# plt.imshow(PLB_02_data[:,slice,:], cmap='gray')\n",
    "# plt.axis('off')\n",
    "# plt.show()\n",
    "\n",
    "# plt.imshow(tibia_PLB_02_data[:,slice,:], cmap='gray')\n",
    "# plt.axis('off')\n",
    "# plt.show()\n",
    "\n",
    "# plt.imshow(superimposed_image, cmap='gray')\n",
    "# plt.axis('off')\n",
    "# plt.show()\n",
    "\n",
    "\n",
    "# PLB_02_data_512 = np.expand_dims(PLB_02_data, axis = -1)\n",
    "# PLB_02_data_512 = PLB_02_data_512.transpose(1, 0, 2, 3)\n",
    "# PLB_02_data_512 = uniform_resizing(PLB_02_data_512, 256)\n",
    "\n",
    "# tibia_PLB_02_data_512 = np.expand_dims(tibia_PLB_02_data, axis = -1)\n",
    "# tibia_PLB_02_data_512 = tibia_PLB_02_data_512.transpose(1, 0, 2, 3)\n",
    "# tibia_PLB_02_data_512 = uniform_resizing(tibia_PLB_02_data_512, 256)\n",
    "\n",
    "# plt.imshow(PLB_02_data_512[slice,:,:,0], cmap='gray')\n",
    "# plt.axis('off')\n",
    "# plt.show()\n",
    "\n",
    "# plt.imshow(tibia_PLB_02_data_512[slice,:,:,0], cmap='gray')\n",
    "# plt.axis('off')\n",
    "# plt.show()\n",
    "\n",
    "# superimposed_image = superimpose_images(PLB_02_data_512[slice,:,:,0], tibia_PLB_02_data_512[slice,:,:,0])\n",
    "\n",
    "# plt.imshow(superimposed_image, cmap='gray')\n",
    "# plt.axis('off')\n",
    "# plt.show()\n",
    "\n",
    "# nii_img_mask = nib.Nifti1Image(PLB_02_data_512, affine=np.eye(4))\n",
    "# output_file_path = ('D:/MRI - Tairawhiti/PLB_02_data_512.nii.gz')\n",
    "# nib.save(nii_img_mask, output_file_path)  \n",
    "\n",
    "# nii_img_mask = nib.Nifti1Image(tibia_PLB_02_data_512, affine=np.eye(4))\n",
    "# output_file_path = ('D:/MRI - Tairawhiti/tibia_PLB_02_data_512.nii.gz')\n",
    "# nib.save(nii_img_mask, output_file_path)   \n",
    "\n",
    "nnunet_root = '{}/nnUNet Data'.format(scans_path)\n",
    "individual_mask_directory = nnunet_root + '/masks'\n",
    "multiclass_mask_output_dir = nnunet_root + '/multiclass_masks'\n",
    "scan_dir = nnunet_root + '/scans'\n",
    "# Patient indices to process (their multiclass masks are written as msk_001 ... msk_005).\n",
    "mask_index = [1, 2, 3, 4, 5]\n",
    "\n",
    "# Label value of each bone, passed through to CreateMasks4MulticlassMSK.\n",
    "TIBIA_encoding, FEMUR_encoding, FIBULA_encoding, PELVIS_encoding = 1, 2, 3, 4\n",
    "AOIThresholding = True\n",
    "FriedLanderDataset = True\n",
    "\n",
    "print('Multi-Class Segmentation Task Data Preparation!')\n",
    "\n",
    "for patient_index in mask_index:\n",
    "    CreateMasks4MulticlassMSK(\n",
    "        scan_dir, individual_mask_directory, patient_index,\n",
    "        TIBIA_encoding, FEMUR_encoding, FIBULA_encoding, PELVIS_encoding,\n",
    "        multiclass_mask_output_dir, AOIThresholding, FriedLanderDataset,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load scan msk_001 from the nnUNet scans folder and show how many dimensions the volume has.\n",
    "PLB_02 = nib.load('{}/nnUNet Data/scans/msk_001.nii.gz'.format(scans_path))\n",
    "PLB_02_data = PLB_02.get_fdata()\n",
    "\n",
    "PLB_02_data.ndim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(457, 512, 512, 1)\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def superimpose_images(image1, image2, alpha=0.5):\n",
    "    \"\"\"\n",
    "    Blend two images after scaling each one so that its maximum is 1.\n",
    "\n",
    "    Parameters:\n",
    "        image1, image2 (numpy.ndarray): Images of the same (or broadcastable) shape.\n",
    "        alpha (float): Weight given to image1; image2 gets (1 - alpha). Defaults to 0.5.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray: alpha * image1 / max(image1) + (1 - alpha) * image2 / max(image2).\n",
    "        An image whose maximum is 0 (e.g. an empty mask slice) is used unscaled instead of\n",
    "        being divided by zero, which would fill the blend with NaN/inf values.\n",
    "    \"\"\"\n",
    "    def _scale_to_unit_max(image):\n",
    "        # Scale by the maximum, skipping images whose maximum is 0 (0 / 0 -> NaN + RuntimeWarning).\n",
    "        image = np.asarray(image)\n",
    "        peak = np.max(image)\n",
    "        return image / peak if peak != 0 else image\n",
    "\n",
    "    return alpha * _scale_to_unit_max(image1) + (1 - alpha) * _scale_to_unit_max(image2)\n",
    "\n",
    "\n",
    "slice_idx = 500\n",
    "\n",
    "# nii_img_scan = nib.load('D:/MRI - Tairawhiti/nnUNet Data/scans/FEMUR_002.nii.gz')\n",
    "# nii_img_mask = nib.load('D:/MRI - Tairawhiti/nnUNet Data/masks/FEMUR/FEMUR_002.nii.gz')\n",
    "# mask_data = nii_img_mask.get_fdata()\n",
    "# scan_data = nii_img_scan.get_fdata()\n",
    "\n",
    "# image1 = scan_data[slice_idx, :, :, :]\n",
    "# image2 = mask_data[slice_idx, :, :, :]\n",
    "# superimposed_image = superimpose_images(image1, image2)\n",
    "# plt.imshow(superimposed_image, cmap='gray')\n",
    "# plt.axis('off')\n",
    "# plt.show()\n",
    "\n",
    "# Slice actually displayed below; slice_idx above only feeds the commented-out overlay.\n",
    "display_slice = 400\n",
    "# NOTE: despite the *_mask names, this file is read from the scans folder.\n",
    "nii_img_mask = nib.load('D:/MRI - Tairawhiti (User POV)/nnUNet Data/scans/msk_001.nii.gz')\n",
    "mask_data = nii_img_mask.get_fdata()\n",
    "print(mask_data.shape)\n",
    "image3 = mask_data[display_slice, :, :, 0]\n",
    "\n",
    "plt.imshow(image3, cmap='gray')\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FIBULA_020\n"
     ]
    }
   ],
   "source": [
    "def transform_string(input_string):\n",
    "    \"\"\"\n",
    "    Convert a segmentation-mask name into the '<REGION>_<NNN>' dataset name.\n",
    "\n",
    "    The name is split on '_': the first token is the patient number (zero-padded to three\n",
    "    digits) and the second-to-last token is the region (upper-cased), e.g.\n",
    "    '6_RR_fibula_9B' -> 'FIBULA_006' and '20_RR_fibula_100C' -> 'FIBULA_020'.\n",
    "\n",
    "    Returns None when the name has fewer than three '_'-separated tokens (with only two,\n",
    "    the region and the patient number would both be read from the first token) or when\n",
    "    the first token is not an integer.\n",
    "    \"\"\"\n",
    "    parts = input_string.split('_')\n",
    "    \n",
    "    if len(parts) < 3:\n",
    "        return None\n",
    "    \n",
    "    prefix = parts[-2].upper()\n",
    "    \n",
    "    try:\n",
    "        number_part = int(parts[0])\n",
    "    except ValueError:\n",
    "        return None\n",
    "    \n",
    "    return f'{prefix}_{number_part:03d}'\n",
    "\n",
    "# Example: patient 20, fibula mask -> 'FIBULA_020'.\n",
    "input_string = '20_RR_fibula_100C'\n",
    "output_string = transform_string(input_string)\n",
    "print(output_string)  # Output: 'FIBULA_020'\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aligning the Compressed NIfTI Scans & Masks (Coordinate System, Direction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tibia_009\n"
     ]
    },
    {
     "ename": "OrientationError",
     "evalue": "Data array has fewer dimensions than orientation",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mOrientationError\u001b[0m                          Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[5], line 70\u001b[0m\n\u001b[0;32m     68\u001b[0m nii_file_path \u001b[39m=\u001b[39m (\u001b[39m'\u001b[39m\u001b[39mD:/MRI - Tairawhiti/nnUNet Data/scans/\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.nii.gz\u001b[39m\u001b[39m'\u001b[39m)\u001b[39m.\u001b[39mformat(fname)\n\u001b[0;32m     69\u001b[0m nii_img \u001b[39m=\u001b[39m nib\u001b[39m.\u001b[39mload(nii_file_path)\n\u001b[1;32m---> 70\u001b[0m reoriented_img \u001b[39m=\u001b[39m nib\u001b[39m.\u001b[39;49morientations\u001b[39m.\u001b[39;49mapply_orientation(nii_img, target_orientation)\n\u001b[0;32m     71\u001b[0m output_file_path \u001b[39m=\u001b[39m (\u001b[39m'\u001b[39m\u001b[39mD:/MRI - Tairawhiti/nnUNet Data/scans/\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.nii.gz\u001b[39m\u001b[39m'\u001b[39m)\u001b[39m.\u001b[39mformat(fname)\n\u001b[0;32m     72\u001b[0m nib\u001b[39m.\u001b[39msave(modified_nii_img, output_file_path)\n",
      "File \u001b[1;32mc:\\Users\\GGPC\\anaconda3\\envs\\Py39-CNN\\lib\\site-packages\\nibabel\\orientations.py:155\u001b[0m, in \u001b[0;36mapply_orientation\u001b[1;34m(arr, ornt)\u001b[0m\n\u001b[0;32m    153\u001b[0m n \u001b[39m=\u001b[39m ornt\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[0;32m    154\u001b[0m \u001b[39mif\u001b[39;00m t_arr\u001b[39m.\u001b[39mndim \u001b[39m<\u001b[39m n:\n\u001b[1;32m--> 155\u001b[0m     \u001b[39mraise\u001b[39;00m OrientationError(\u001b[39m'\u001b[39m\u001b[39mData array has fewer dimensions than orientation\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[0;32m    156\u001b[0m \u001b[39m# no coordinates can be dropped for applying the orientations\u001b[39;00m\n\u001b[0;32m    157\u001b[0m \u001b[39mif\u001b[39;00m np\u001b[39m.\u001b[39many(np\u001b[39m.\u001b[39misnan(ornt[:, \u001b[39m0\u001b[39m])):\n",
      "\u001b[1;31mOrientationError\u001b[0m: Data array has fewer dimensions than orientation"
     ]
    }
   ],
   "source": [
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fnames = ['Tibia_009', 'Tibia_010', 'Tibia_015']\n",
    "translate_y = False      # copy the mask's origin into the scan's affine\n",
    "flip_direction = False   # copy the mask's direction matrix into the scan's affine\n",
    "slice_idx = 100          # slice (third axis) plotted for the visual check\n",
    "nnunet_dir = 'D:/MRI - Tairawhiti/nnUNet Data'\n",
    "# Axis codes that every scan and mask is reoriented to before they are compared.\n",
    "target_orientation = ('R', 'A', 'S')\n",
    "\n",
    "\n",
    "def reorient_in_place(nii_file_path):\n",
    "    \"\"\"Rewrite a NIfTI file in its closest canonical (RAS+) orientation and return the reloaded image.\n",
    "\n",
    "    ``nib.orientations.apply_orientation`` expects a data array and an (N, 2) orientation\n",
    "    array, not an image and axis codes (passing those raised OrientationError), so the\n",
    "    reorientation is done with ``nib.as_closest_canonical``. Files already in RAS+ are\n",
    "    returned untouched since there is nothing to rewrite.\n",
    "    \"\"\"\n",
    "    nii_img = nib.load(nii_file_path)\n",
    "    if nib.aff2axcodes(nii_img.affine) == target_orientation:\n",
    "        return nii_img\n",
    "    nib.save(nib.as_closest_canonical(nii_img), nii_file_path)\n",
    "    return nib.load(nii_file_path)\n",
    "\n",
    "\n",
    "def plot_slice(nii_file_path, title):\n",
    "    \"\"\"Show slice ``slice_idx`` along the third axis of a NIfTI volume with a colour bar.\"\"\"\n",
    "    nii_data = nib.load(nii_file_path).get_fdata()\n",
    "    plt.imshow(nii_data[:, :, slice_idx], cmap='gray')\n",
    "    plt.title(title)\n",
    "    plt.colorbar()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "for fname in fnames:\n",
    "    print(fname)\n",
    "    mask_path = '{}/masks/{}.nii.gz'.format(nnunet_dir, fname)\n",
    "    scan_path = '{}/scans/{}.nii.gz'.format(nnunet_dir, fname)\n",
    "\n",
    "    mask_affine = nib.load(mask_path).affine\n",
    "    direction_mask = mask_affine[:3, :3]\n",
    "    origin_coordinates_mask = mask_affine[:3, 3]\n",
    "\n",
    "    if translate_y or flip_direction:\n",
    "        # Rewrite the scan with the mask's origin and/or direction matrix in its affine.\n",
    "        nii_img = nib.load(scan_path)\n",
    "        affine_matrix = nii_img.affine.copy()\n",
    "        if translate_y:\n",
    "            affine_matrix[:3, 3] = origin_coordinates_mask\n",
    "        if flip_direction:\n",
    "            affine_matrix[:3, :3] = direction_mask\n",
    "        nib.save(nib.Nifti1Image(nii_img.get_fdata(), affine_matrix), scan_path)\n",
    "\n",
    "    # Save the reoriented images themselves (the old code saved `modified_nii_img`,\n",
    "    # which is undefined when neither flag is set).\n",
    "    scan_img = reorient_in_place(scan_path)\n",
    "    mask_img = reorient_in_place(mask_path)\n",
    "\n",
    "    # Compare the geometry now on disk; allclose tolerates float round-off between the two files.\n",
    "    if np.allclose(scan_img.affine[:3, 3], mask_img.affine[:3, 3]):\n",
    "        print(('Mask and Scan Coordinate Systems Matching! ({})').format(fname))\n",
    "    else:\n",
    "        print('ERROR: Mask and Scan Coordinate Systems Mismatch!')\n",
    "    if np.allclose(scan_img.affine[:3, :3], mask_img.affine[:3, :3]):\n",
    "        print(('Mask and Scan Directions Matching! ({})').format(fname)) \n",
    "    else:\n",
    "        print('ERROR: Mask and Scan Direction Mismatch!')\n",
    "    print('\\n')\n",
    "\n",
    "    plot_slice(scan_path, 'Scan Slice Visualization')\n",
    "    plot_slice(mask_path, 'Mask Slice Visualization')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reversal within each NIfTI complete.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import pydicom\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "\n",
    "fnames = ['3_AutoBindWATER_650_4A', '4_AutoBindWATER_750_5A', '5_AutoBindWATER_1050_8A', '6_AutoBindWATER_650_9B', '7_AutoBindWATER_450_10B']\n",
    "dicom2nifiti = False\n",
    "reverseMasksNifiti = True\n",
    "\n",
    "\n",
    "def dicom_series_to_nifti(dicom_dir):\n",
    "    \"\"\"Stack one DICOM series directory into a NIfTI image with a world-space affine.\n",
    "\n",
    "    Slices are ordered by their position along the slice normal (file names do not\n",
    "    guarantee anatomical order). The affine is built from ImageOrientationPatient,\n",
    "    PixelSpacing and ImagePositionPatient, then converted from DICOM LPS to NIfTI RAS.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if the directory contains no .dcm files.\n",
    "    \"\"\"\n",
    "    dicom_files = [f for f in os.listdir(dicom_dir) if f.endswith('.dcm')]\n",
    "    if not dicom_files:\n",
    "        raise FileNotFoundError(f'No .dcm files found in {dicom_dir}')\n",
    "    dicom_data = [pydicom.dcmread(os.path.join(dicom_dir, f)) for f in dicom_files]\n",
    "\n",
    "    # ImageOrientationPatient = row direction cosine (3 values) + column direction cosine (3 values).\n",
    "    orientation = np.array(dicom_data[0].ImageOrientationPatient, dtype=float)\n",
    "    row_cosine, col_cosine = orientation[:3], orientation[3:]\n",
    "    slice_normal = np.cross(row_cosine, col_cosine)\n",
    "\n",
    "    # Sort slices anatomically by projecting each slice origin onto the slice normal.\n",
    "    dicom_data.sort(key=lambda d: float(np.dot(slice_normal, np.array(d.ImagePositionPatient, dtype=float))))\n",
    "    positions = np.array([d.ImagePositionPatient for d in dicom_data], dtype=float)\n",
    "\n",
    "    if len(dicom_data) > 1:\n",
    "        slice_spacing = float(np.dot(slice_normal, positions[-1] - positions[0])) / (len(dicom_data) - 1)\n",
    "    else:\n",
    "        slice_spacing = float(getattr(dicom_data[0], 'SliceThickness', 1.0) or 1.0)\n",
    "    # PixelSpacing = (spacing between rows, spacing between columns), in mm.\n",
    "    row_spacing, col_spacing = (float(s) for s in dicom_data[0].PixelSpacing)\n",
    "\n",
    "    volume = np.stack([d.pixel_array for d in dicom_data], axis=-1)\n",
    "\n",
    "    # pixel_array is indexed [row, column]: stepping the row index moves along the\n",
    "    # column cosine, stepping the column index moves along the row cosine.\n",
    "    affine_lps = np.eye(4)\n",
    "    affine_lps[:3, 0] = col_cosine * row_spacing\n",
    "    affine_lps[:3, 1] = row_cosine * col_spacing\n",
    "    affine_lps[:3, 2] = slice_normal * slice_spacing\n",
    "    affine_lps[:3, 3] = positions[0]\n",
    "    affine_ras = np.diag([-1.0, -1.0, 1.0, 1.0]) @ affine_lps  # DICOM LPS -> NIfTI RAS\n",
    "\n",
    "    nii_img = nib.Nifti1Image(volume, affine_ras)\n",
    "    nii_img.set_qform(affine_ras, code=1)\n",
    "    nii_img.set_sform(affine_ras, code=1)\n",
    "    return nii_img\n",
    "\n",
    "\n",
    "if dicom2nifiti:\n",
    "    for fname in fnames:\n",
    "        dicom_dir = ('D:/MRI - Tairawhiti/{}').format(fname)\n",
    "        nii_img = dicom_series_to_nifti(dicom_dir)\n",
    "        nii_output_path = ('D:/MRI - Tairawhiti/{}.nii.gz').format(fname)\n",
    "        nib.save(nii_img, nii_output_path)\n",
    "        print((\"Conversion complete for {}!\").format(fname))\n",
    "\n",
    "\n",
    "if reverseMasksNifiti:\n",
    "    nii_dir = 'D:/MRI - Tairawhiti/nnUNet Data/masks'\n",
    "\n",
    "    nii_files = sorted([f for f in os.listdir(nii_dir) if f.endswith('.nii.gz')])\n",
    "\n",
    "    output_dir = 'D:/MRI - Tairawhiti/nnUNet Data/masks/idk'\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    for nii_file in nii_files:\n",
    "        nii_path = os.path.join(nii_dir, nii_file)\n",
    "        nii_img = nib.load(nii_path)\n",
    "\n",
    "        # Reverse the voxel order along the last (slice) axis; the affine is kept unchanged.\n",
    "        nii_data = nii_img.get_fdata()\n",
    "        reversed_data = np.flip(nii_data, axis=-1)\n",
    "\n",
    "        reversed_nii_img = nib.Nifti1Image(reversed_data, nii_img.affine)\n",
    "\n",
    "        reversed_output_path = os.path.join(output_dir, f\"reversed_{nii_file}\")\n",
    "        nib.save(reversed_nii_img, reversed_output_path)\n",
    "\n",
    "    print(\"Reversal within each NIfTI complete.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Superimposed Mask Visualisation "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0074625\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from preprocessing import preprocessing\n",
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def superimpose_images(image1, image2):\n",
    "    image1 = image1 / np.max(image1)\n",
    "    image2 = image2 / np.max(image2)\n",
    "    alpha = 0.5\n",
    "    superimposed_image = alpha * image1 + (1 - alpha) * image2\n",
    "    return superimposed_image\n",
    "\n",
    "# Index of the training slice to overlay (named to avoid shadowing the builtin `slice`).\n",
    "overlay_slice_idx = 200\n",
    "image1 = imgs_train[overlay_slice_idx]\n",
    "image2 = imgs_mask_train[overlay_slice_idx]\n",
    "superimposed_image = superimpose_images(image1, image2)\n",
    "print(np.mean(image2))\n",
    "plt.imshow(superimposed_image, cmap='gray')\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data Augmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "Data Augmentation Starting...\n",
      "------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 840/840 [00:12<00:00, 65.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of Augmentation per Input:  3\n",
      "\n",
      "\n",
      "Shape of Augmented Images:  (2520, 512, 512, 1)\n",
      "Shape of Augmented Masks:  (2520, 512, 512, 1)\n",
      "\n",
      "\n",
      "Shape of Training Image Data (Before):  (840, 512, 512, 1)\n",
      "Shape of Training Image Masks (Before):  (840, 512, 512, 1)\n",
      "\n",
      "\n",
      "Shape of Training Image Data (After):  (3360, 512, 512, 1)\n",
      "Shape of Training Image Masks (After):  (3360, 512, 512, 1)\n",
      "------------------------------\n",
      "Completed Data Augmentation Stage!\n",
      "------------------------------\n"
     ]
    }
   ],
   "source": [
    "from data_augmentation import DataAugmentation\n",
    "\n",
    "print('-'*30)\n",
    "print('Data Augmentation Starting...')\n",
    "print('-'*30)\n",
    "\n",
    "# Each training slice yields `num_augmentations` augmented copies.\n",
    "num_augmentations = 3\n",
    "augmented_images_train, augmented_masks_train = DataAugmentation(imgs_train, imgs_mask_train, num_augmentations)\n",
    "# Add a trailing channel axis so the augmented arrays can be concatenated with the originals.\n",
    "augmented_images_train, augmented_masks_train = (np.expand_dims(arr, axis=-1) for arr in (augmented_images_train, augmented_masks_train))\n",
    "\n",
    "print('Number of Augmentation per Input: ', num_augmentations)\n",
    "print('\\n')\n",
    "print('Shape of Augmented Images: ', augmented_images_train.shape)\n",
    "print('Shape of Augmented Masks: ', augmented_masks_train.shape)\n",
    "print('\\n')\n",
    "print('Shape of Training Image Data (Before): ', imgs_train.shape)\n",
    "print('Shape of Training Image Masks (Before): ', imgs_mask_train.shape)\n",
    "print('\\n')\n",
    "\n",
    "# NOTE: not idempotent -- re-running this cell appends augmented copies of the\n",
    "# already-augmented arrays. Reload imgs_train / imgs_mask_train before re-running.\n",
    "imgs_train = np.concatenate([imgs_train, augmented_images_train])\n",
    "imgs_mask_train = np.concatenate([imgs_mask_train, augmented_masks_train])\n",
    "\n",
    "print('Shape of Training Image Data (After): ', imgs_train.shape)\n",
    "print('Shape of Training Image Masks (After): ', imgs_mask_train.shape)\n",
    "\n",
    "print('-'*30)\n",
    "print('Completed Data Augmentation Stage!')\n",
    "print('-'*30)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Exporting Labelled Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "Exporting Scan & Mask Data...\n",
      "------------------------------\n",
      "------------------------------\n",
      "Completed Exporting Stage!\n",
      "------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Exporting Training Scan and Mask Data\n",
    "from writeout_dataset import WriteOutTextFile, WriteOutImagePNGFiles, ReadInDatasets, ReadInDatasetNPY\n",
    "\n",
    "print('-'*30)\n",
    "print('Exporting Scan & Mask Data...')\n",
    "print('-'*30)\n",
    "\n",
    "# Set to True to (re)write the training arrays to disk.\n",
    "export_training_npy = False\n",
    "\n",
    "# starting_slice = 70\n",
    "# save_directory_data_txt = 'D:/P4P Model Data/txt/Data_Tibia(Collab Sample)'\n",
    "# WriteOutTextFile(imgs_train[:, :, :, 0], save_directory_data_txt, starting_slice)  \n",
    "# save_directory_mask_txt = 'D:/P4P Model Data/txt/Mask_Tibia(Collab Sample)'\n",
    "# WriteOutTextFile(imgs_mask_train[:, :, :, 0], save_directory_mask_txt, starting_slice)   \n",
    "\n",
    "if export_training_npy:\n",
    "    np.save(\"D:/MRI - Tairawhiti/imgs_train_tibia.npy\", imgs_train)\n",
    "    # The mask file must hold the masks, not the scans.\n",
    "    np.save(\"D:/MRI - Tairawhiti/imgs_mask_train_tibia.npy\", imgs_mask_train)\n",
    "\n",
    "# NOTE(review): despite the variable name, this reads the *mask* directory -- confirm which is intended.\n",
    "imgs_train_scan_tibia = ReadInDatasetNPY('D:/MRI - Tairawhiti/Model Input Masks Tibia (Google Colab)')\n",
    "\n",
    "print('-'*30)\n",
    "print('Completed Exporting Stage!')\n",
    "print('-'*30)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2D U-Net Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define U-Net model\n",
    "def unet_model(input_shape):\n",
    "    \"\"\"Build the classic 4-level 2D U-Net with a single-channel sigmoid output.\n",
    "\n",
    "    Encoder stages use 64/128/256/512 filters with a 1024-filter bottleneck;\n",
    "    dropout (0.5) follows the deepest encoder stage and the bottleneck. Each\n",
    "    decoder stage upsamples 2x, applies a 2x2 conv, concatenates the matching\n",
    "    encoder features and applies two 3x3 convs. H and W must be divisible by 16.\n",
    "    \"\"\"\n",
    "    def conv_pair(tensor, filters):\n",
    "        # Two 3x3 same-padded ReLU convolutions.\n",
    "        tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)\n",
    "        return Conv2D(filters, 3, activation='relu', padding='same')(tensor)\n",
    "\n",
    "    inputs = Input(input_shape)\n",
    "\n",
    "    # Encoder\n",
    "    enc1 = conv_pair(inputs, 64)\n",
    "    enc2 = conv_pair(MaxPooling2D(pool_size=(2, 2))(enc1), 128)\n",
    "    enc3 = conv_pair(MaxPooling2D(pool_size=(2, 2))(enc2), 256)\n",
    "    enc4 = Dropout(0.5)(conv_pair(MaxPooling2D(pool_size=(2, 2))(enc3), 512))\n",
    "\n",
    "    # Bottleneck\n",
    "    decoded = Dropout(0.5)(conv_pair(MaxPooling2D(pool_size=(2, 2))(enc4), 1024))\n",
    "\n",
    "    # Decoder: walk back up, joining each stage with its encoder counterpart.\n",
    "    for skip, filters in ((enc4, 512), (enc3, 256), (enc2, 128), (enc1, 64)):\n",
    "        upsampled = UpSampling2D(size=(2, 2))(decoded)\n",
    "        up_conv = Conv2D(filters, 2, activation='relu', padding='same')(upsampled)\n",
    "        decoded = conv_pair(concatenate([skip, up_conv], axis=3), filters)\n",
    "\n",
    "    outputs = Conv2D(1, 1, activation='sigmoid')(decoded)\n",
    "    return Model(inputs=inputs, outputs=outputs)\n",
    "\n",
    "# Dice Coefficient Loss Function\n",
    "def dice_loss(y_true, y_pred):\n",
    "    \"\"\"Soft Dice loss averaged over the batch.\n",
    "\n",
    "    Args:\n",
    "        y_true: ground-truth masks; summed over axes 1-3 (assumed (batch, H, W, C)).\n",
    "        y_pred: predicted probabilities with the same shape as ``y_true``.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor ``1 - mean(dice)``.\n",
    "    \"\"\"\n",
    "    smooth = 1e-5  # keeps the ratio defined when both masks are empty\n",
    "    # Masks are often stored as integers and TF does not auto-cast, so match y_pred's dtype.\n",
    "    y_true = tf.cast(y_true, y_pred.dtype)\n",
    "    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])\n",
    "    union = tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3])\n",
    "    dice_coefficient = (2.0 * intersection + smooth) / (union + smooth)\n",
    "    return 1.0 - tf.reduce_mean(dice_coefficient)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2D Dense UNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import Model\n",
    "from keras.layers import Input, Concatenate, Conv2D, MaxPooling2D, UpSampling2D, Dropout\n",
    "\n",
    "def dense_block(x, filters, dropout_rate):\n",
    "    \"\"\"Two 3x3 ReLU convolutions whose output is concatenated back onto the input.\n",
    "\n",
    "    Dropout is applied to the concatenation only when ``dropout_rate`` is truthy.\n",
    "    \"\"\"\n",
    "    features = x\n",
    "    for _ in range(2):\n",
    "        features = Conv2D(filters, 3, activation='relu', padding='same')(features)\n",
    "    merged = Concatenate(axis=-1)([x, features])\n",
    "    return Dropout(dropout_rate)(merged) if dropout_rate else merged\n",
    "\n",
    "def dense_unet(input_shape, num_classes, filters=16, num_layers=4, dropout_rate=0.2):\n",
    "    \"\"\"Build a 2D U-Net whose encoder/decoder stages are dense blocks.\n",
    "\n",
    "    The filter count doubles at each of the ``num_layers`` pooling steps and halves\n",
    "    again on the way up; skip connections join matching resolutions, so H and W\n",
    "    must be divisible by 2**num_layers. The head has ``num_classes`` sigmoid channels.\n",
    "    \"\"\"\n",
    "    inputs = Input(input_shape)\n",
    "    skips = []\n",
    "    x = inputs\n",
    "\n",
    "    # Encoder: dense block -> keep for skip -> halve resolution.\n",
    "    for depth in range(num_layers):\n",
    "        x = dense_block(x, filters * 2 ** depth, dropout_rate)\n",
    "        skips.append(x)\n",
    "        x = MaxPooling2D(pool_size=(2, 2))(x)\n",
    "\n",
    "    # Bridge at the coarsest resolution.\n",
    "    x = dense_block(x, filters * 2 ** num_layers, dropout_rate)\n",
    "\n",
    "    # Decoder: upsample, join the matching skip, dense block with halved filters.\n",
    "    for depth in reversed(range(num_layers)):\n",
    "        x = UpSampling2D(size=(2, 2))(x)\n",
    "        x = Concatenate(axis=-1)([x, skips[depth]])\n",
    "        x = dense_block(x, filters * 2 ** depth, dropout_rate)\n",
    "\n",
    "    outputs = Conv2D(num_classes, 1, activation='sigmoid')(x)\n",
    "    return Model(inputs=inputs, outputs=outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "H-DenseUNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from keras.models import Model\n",
    "from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate\n",
    "\n",
    "def h_dense_block(x, filters, num_layers):\n",
    "    \"\"\"Dense block: ``num_layers`` 3x3 ReLU convs, each fed the concatenation of all previous features.\n",
    "\n",
    "    Named ``h_dense_block`` (not ``dense_block``) so running this cell does not\n",
    "    overwrite the 2D Dense UNet's ``dense_block(x, filters, dropout_rate)``,\n",
    "    which ``dense_unet`` calls with a different signature.\n",
    "    \"\"\"\n",
    "    for _ in range(num_layers):\n",
    "        conv = Conv2D(filters=filters, kernel_size=(3, 3), activation='relu', padding='same')(x)\n",
    "        x = concatenate([x, conv], axis=-1)\n",
    "    return x\n",
    "\n",
    "def transition_down(x, filters):\n",
    "    \"\"\"1x1 ReLU conv to ``filters`` channels, then 2x2 max pooling (halves H and W).\"\"\"\n",
    "    x = Conv2D(filters=filters, kernel_size=(1, 1), activation='relu', padding='same')(x)\n",
    "    x = MaxPooling2D(pool_size=(2, 2))(x)\n",
    "    return x\n",
    "\n",
    "def transition_up(x, filters):\n",
    "    \"\"\"Stride-2 3x3 transposed conv to ``filters`` channels (doubles H and W).\"\"\"\n",
    "    x = Conv2DTranspose(filters=filters, kernel_size=(3, 3), strides=(2, 2), padding='same')(x)\n",
    "    return x\n",
    "\n",
    "def h_dense_unet(input_shape, num_classes):\n",
    "    \"\"\"Build a 2D dense U-Net with three down/up-sampling stages.\n",
    "\n",
    "    Args:\n",
    "        input_shape: (H, W, C); H and W must be divisible by 8 (three 2x poolings).\n",
    "        num_classes: number of output channels.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled keras Model. A single-channel head uses a sigmoid: a softmax\n",
    "        over one channel is identically 1.0, so the model could never predict background.\n",
    "    \"\"\"\n",
    "    inputs = Input(input_shape)\n",
    "\n",
    "    # Initial Convolution Block\n",
    "    conv1 = Conv2D(filters=64, kernel_size=(3, 3), activation='relu', padding='same')(inputs)\n",
    "\n",
    "    # Downsample path\n",
    "    down1 = h_dense_block(conv1, filters=64, num_layers=4)\n",
    "    pool1 = transition_down(down1, filters=128)\n",
    "\n",
    "    down2 = h_dense_block(pool1, filters=128, num_layers=4)\n",
    "    pool2 = transition_down(down2, filters=256)\n",
    "\n",
    "    down3 = h_dense_block(pool2, filters=256, num_layers=4)\n",
    "    pool3 = transition_down(down3, filters=512)\n",
    "\n",
    "    down4 = h_dense_block(pool3, filters=512, num_layers=4)\n",
    "\n",
    "    # Upsample path (each stage joins the encoder output at the same resolution)\n",
    "    up4 = transition_up(down4, filters=256)\n",
    "    up4 = concatenate([up4, down3], axis=-1)\n",
    "    up4 = h_dense_block(up4, filters=256, num_layers=4)\n",
    "\n",
    "    up3 = transition_up(up4, filters=128)\n",
    "    up3 = concatenate([up3, down2], axis=-1)\n",
    "    up3 = h_dense_block(up3, filters=128, num_layers=4)\n",
    "\n",
    "    up2 = transition_up(up3, filters=64)\n",
    "    up2 = concatenate([up2, down1], axis=-1)\n",
    "    up2 = h_dense_block(up2, filters=64, num_layers=4)\n",
    "\n",
    "    # Output\n",
    "    output_activation = 'sigmoid' if num_classes == 1 else 'softmax'\n",
    "    outputs = Conv2D(filters=num_classes, kernel_size=(1, 1), activation=output_activation)(up2)\n",
    "\n",
    "    model = Model(inputs=inputs, outputs=outputs)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "nnUNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Importing Preprocessed Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "Importing Scan & Mask Data...\n",
      "------------------------------\n",
      "imgs_train size:  (214, 256, 256, 1)\n",
      "imgs_mask_train size:  (214, 256, 256, 1)\n",
      "------------------------------\n",
      "Importing Exporting Stage Completed!\n",
      "------------------------------\n"
     ]
    }
   ],
   "source": [
    "from writeout_dataset import ReadInDatasets, ReadInDatasetNPY\n",
    "\n",
    "print('-'*30)\n",
    "print('Importing Scan & Mask Data...')\n",
    "print('-'*30)\n",
    "\n",
    "# Google Colab paths:\n",
    "# imgs_train = ReadInDatasetNPY('/content/drive/MyDrive/Colab Notebooks/Model Input Scans Tibia (Subset)')\n",
    "# imgs_mask_train = ReadInDatasetNPY('/content/drive/MyDrive/Colab Notebooks/Model Input Masks Tibia (Subset)')\n",
    "\n",
    "imgs_train = ReadInDatasetNPY('D:/MRI - Tairawhiti/Model Input Scans Tibia (Subset)')\n",
    "imgs_mask_train = ReadInDatasetNPY('D:/MRI - Tairawhiti/Model Input Masks Tibia (Subset)')\n",
    "\n",
    "print('imgs_train size: ', imgs_train.shape)\n",
    "print('imgs_mask_train size: ', imgs_mask_train.shape)\n",
    "# Every scan slice needs a mask slice of identical shape for training.\n",
    "assert imgs_train.shape == imgs_mask_train.shape, (imgs_train.shape, imgs_mask_train.shape)\n",
    "\n",
    "print('-'*30)\n",
    "print('Importing Stage Completed!')\n",
    "print('-'*30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training UNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------\n",
      "Training UNet Model...\n",
      "------------------------------\n",
      "Training Image Input Shape:  (168, 400, 400, 1)\n",
      "Training Mask Input Shape:  (168, 400, 400, 1)\n",
      "Validation Image Input Shape:  (42, 400, 400, 1)\n",
      "Validation Mask Input Shape:  (42, 400, 400, 1)\n",
      "------------------------------\n",
      "Creating and compiling model...\n",
      "------------------------------\n",
      "Epoch 1/10\n"
     ]
    }
   ],
   "source": [
    "print('-'*30)\n",
    "print('Training UNet Model...')\n",
    "print('-'*30)\n",
    "\n",
    "images_train, images_val, labels_train, labels_val = train_test_split(imgs_train, imgs_mask_train, test_size=0.2, random_state=0)\n",
    "print('Training Image Input Shape: ', images_train.shape)\n",
    "print('Training Mask Input Shape: ', labels_train.shape)\n",
    "print('Validation Image Input Shape: ', images_val.shape)\n",
    "print('Validation Mask Input Shape: ', labels_val.shape)\n",
    "\n",
    "\n",
    "print('-'*30)\n",
    "print('Creating and compiling model...')\n",
    "print('-'*30)\n",
    "\n",
    "# my_callbacks = [\n",
    "#     tf.keras.callbacks.EarlyStopping(patience=PATIENCE), # early stopping\n",
    "#     tf.keras.callbacks.ModelCheckpoint(filepath=MODEL_FNAME_PATTERN, save_best_only=True) # save the best based on validation\n",
    "# ]\n",
    "\n",
    "# input_shape = (256,256,1)\n",
    "# model = unet_model(input_shape)\n",
    "# model.compile(optimizer='adam', loss='binary_crossentropy')\n",
    "# # model.compile(optimizer='adam', loss=dice_loss)\n",
    "# checkpoint = ModelCheckpoint('model(loss=binary,batch_size=4,epochs=20,train_size=500,aoi=tibia).h5', monitor='val_loss', save_best_only=True, mode='min')\n",
    "# model.fit(x=images_train, y=labels_train, batch_size=4, epochs=20, validation_data=(images_val, labels_val), callbacks=[checkpoint])\n",
    "# model.summary()\n",
    "\n",
    "# Dense 2D UNet\n",
    "# input_shape = (512,512,1)\n",
    "# num_classes = 1\n",
    "# model = dense_unet(input_shape, num_classes)\n",
    "# model.compile(optimizer='adam', loss='binary_crossentropy')\n",
    "# checkpoint = ModelCheckpoint('model(Baseline + 2D Dense Layers 125 Epoch).h5', monitor='val_loss', save_best_only=True, mode='min')\n",
    "# model.fit(x=images_train, y=labels_train, batch_size=1, epochs=125, validation_data=(images_val, labels_val), callbacks=[checkpoint])\n",
    "# model.summary()\n",
    "\n",
    "\n",
    "# H-DenseUNet\n",
    "# Take the input shape from the data so the model always matches the loaded slices\n",
    "# (for h_dense_unet, H and W must be divisible by 8).\n",
    "input_shape = images_train.shape[1:]\n",
    "num_classes = 1\n",
    "batch_size = 2\n",
    "num_epochs = 10\n",
    "model = h_dense_unet(input_shape, num_classes)\n",
    "model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
    "# Persist the best validation weights, as the alternative configurations above do,\n",
    "# so the trained model can be reloaded later with load_model.\n",
    "checkpoint = ModelCheckpoint('H-DenseUNet.h5', monitor='val_loss', save_best_only=True, mode='min')\n",
    "model.fit(images_train, labels_train, validation_data=(images_val, labels_val), batch_size=batch_size, epochs=num_epochs, callbacks=[checkpoint])\n",
    "model.summary()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prediction Mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "imgs_test size (Scans):  (214, 400, 400, 1)\n",
      "imgs_test_masks size (Masks):  (214, 400, 400, 1)\n",
      "------------------------------\n",
      "Prediction Made Using Weights From Model: D:/MRI - Tairawhiti/Trained Models/model(2D-UNet, 3 Patients, Tibia, 400x400x1, 60 epoches).h5\n",
      "------------------------------\n",
      "WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x0000023044484EE0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n",
      "1/1 [==============================] - 3s 3s/step\n",
      "Testing Image Input Shape:  (400, 400, 1)\n",
      "\n",
      "\n",
      "Dice Similarity Coefficient (DSC) Metric Value for Specified Slice:  [0.15360412]\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# model = 'model(Baseline 125 Epoch).h5'\n",
    "# model = 'D:/MRI - Tairawhiti/Trained Models/model(2D UNet, 9 Patients, Tibia).h5'\n",
    "# model = 'D:/MRI - Tairawhiti/Trained Models/H-DenseUNet.h5'\n",
    "model = 'D:/MRI - Tairawhiti/Trained Models/model(2D-UNet, 3 Patients, Tibia, 400x400x1, 60 epoches).h5'\n",
    "pred_img_idx = 50\n",
    "MaskStack2D = False\n",
    "# Sigmoid probability at or above which a pixel is labelled foreground.\n",
    "pred_threshold = 0.5\n",
    "\n",
    "from writeout_dataset import ReadInDatasets, ReadInDatasetNPY\n",
    "# imgs_test = ReadInDatasets('/content/drive/MyDrive/Colab Notebooks/Test_Data_Tibia(Collab Sample)', 0, 10)\n",
    "# imgs_test_masks = ReadInDatasets('/content/drive/MyDrive/Colab Notebooks/Test_Masks_Tibia(Collab Sample)', 0, 10)\n",
    "imgs_test = ReadInDatasetNPY('D:/MRI - Tairawhiti/Test_Data_Tibia(Collab Sample)')\n",
    "imgs_test_masks = ReadInDatasetNPY('D:/MRI - Tairawhiti/Test_Masks_Tibia(Collab Sample)')\n",
    "\n",
    "print('imgs_test size (Scans): ', imgs_test.shape)\n",
    "print('imgs_test_masks size (Masks): ', imgs_test_masks.shape)\n",
    "\n",
    "def superimpose_images(image1, image2, alpha=0.5):\n",
    "    \"\"\"Blend two images (``alpha`` = weight of image1) after scaling each to a peak of 1.\n",
    "\n",
    "    An all-zero image (e.g. an empty prediction) is kept as zeros rather than divided by 0.\n",
    "    \"\"\"\n",
    "    def _normalise(image):\n",
    "        image = np.asarray(image, dtype=float)\n",
    "        peak = np.max(image)\n",
    "        return image / peak if peak != 0 else image\n",
    "\n",
    "    return alpha * _normalise(image1) + (1 - alpha) * _normalise(image2)\n",
    "\n",
    "def dice_coefficient(y_true, y_pred):\n",
    "    \"\"\"Per-sample Dice score for batches shaped (batch, H, W, C).\"\"\"\n",
    "    intersection = np.sum(y_true * y_pred, axis=(1, 2, 3))\n",
    "    union = np.sum(y_true, axis=(1, 2, 3)) + np.sum(y_pred, axis=(1, 2, 3))\n",
    "    dice = (2.0 * intersection) / (union + 1e-7)  # Adding a small epsilon to avoid division by zero\n",
    "    return dice\n",
    "\n",
    "print('-'*30)\n",
    "print(f'Prediction Made Using Weights From Model: {model}')\n",
    "print('-'*30)\n",
    "\n",
    "# Prediction\n",
    "best_model = load_model(model)\n",
    "prediction = best_model.predict(np.reshape(imgs_test[pred_img_idx], (1,imgs_test.shape[1],imgs_test.shape[2],1)))\n",
    "print('Testing Image Input Shape: ',imgs_test[pred_img_idx].shape)\n",
    "print('\\n')\n",
    "\n",
    "# Evaluation\n",
    "# Threshold the sigmoid output (rather than treating any p >= 0.0005 as foreground,\n",
    "# which grossly over-segments).\n",
    "binary_pred = np.where(prediction >= pred_threshold, 1, 0)\n",
    "DSC = dice_coefficient((np.reshape(imgs_test_masks[pred_img_idx], (1,imgs_test_masks.shape[1],imgs_test_masks.shape[2],1))), binary_pred)\n",
    "print('Dice Similarity Coefficient (DSC) Metric Value for Specified Slice: ', DSC)\n",
    "\n",
    "if MaskStack2D:\n",
    "    DSC_stack = []\n",
    "    pred_stack = []\n",
    "    for i in tqdm(range(len(imgs_test_masks)), desc=\"Evaluating DSC on Patient Scan Stack\"):\n",
    "        prediction_patient = best_model.predict(np.reshape(imgs_test[i], (1,imgs_test.shape[1],imgs_test.shape[2],1)))\n",
    "        binary_pred_patient = np.where(prediction_patient >= pred_threshold, 1, 0)\n",
    "        pred_stack.append(binary_pred_patient)\n",
    "        DSC_patient = dice_coefficient((np.reshape(imgs_test_masks[i], (1,imgs_test_masks.shape[1],imgs_test_masks.shape[2],1))), binary_pred_patient)\n",
    "        DSC_stack.append(DSC_patient)\n",
    "    print('Average Dice Similarity Coefficient (DSC) Metric Value for Patient Scan Stack: ', np.mean(DSC_stack))\n",
    "    # NOTE(review): this file stores the per-slice DSC values, not the predicted masks (pred_stack).\n",
    "    np.save('ManualSegStack.npy', DSC_stack)\n",
    "    print('\\n')\n",
    "\n",
    "# Visualisations\n",
    "cmap_binary = 'white'\n",
    "cmap_segmask = plt.cm.colors.ListedColormap(['black', cmap_binary])\n",
    "bounds = [0, 0.5, 1]\n",
    "norm = plt.cm.colors.BoundaryNorm(bounds, cmap_segmask.N)\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(binary_pred[0, :, :, 0], cmap=cmap_segmask, norm=norm)\n",
    "ax.axis('on')\n",
    "plt.title('Model Outputted Segmentation Mask')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(imgs_test_masks[pred_img_idx, :, :, 0], cmap='gray')\n",
    "ax.axis('on')\n",
    "plt.title('Groundtruth Mask')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(imgs_test[pred_img_idx, :, :, 0], cmap='gray')\n",
    "ax.axis('on')\n",
    "plt.title('Raw Test Scan')\n",
    "plt.show()\n",
    "\n",
    "cmap_binary = 'YlOrBr'\n",
    "superimposed_image = superimpose_images(imgs_test_masks[pred_img_idx, :, :, 0], binary_pred[0, :, :, 0])\n",
    "# Define the cropping ranges\n",
    "x_start, x_end = 200, 400\n",
    "y_start, y_end = 200, 400\n",
    "cropped_image = superimposed_image[y_start:y_end, x_start:x_end]\n",
    "fig, ax = plt.subplots()\n",
    "# plt.get_cmap replaces matplotlib.cm.get_cmap, which newer matplotlib removed.\n",
    "cmap_superimposed = plt.get_cmap(cmap_binary)\n",
    "ax.imshow(cropped_image, cmap=cmap_superimposed, vmin=0, vmax=1)\n",
    "ax.axis('on')\n",
    "plt.title('Superimposed Prediction Mask (Orange) & Groundtruth Mask (Black)')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "superimposed_image = superimpose_images(imgs_test[pred_img_idx, :, :, 0], imgs_test_masks[pred_img_idx, :, :, 0])\n",
    "ax.imshow(superimposed_image, cmap='gray')\n",
    "ax.axis('on')\n",
    "plt.title('Superimposed Raw Test Scan & Groundtruth Mask (Manually Segmented)')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "superimposed_image = superimpose_images(imgs_test[pred_img_idx, :, :, 0], binary_pred[0, :, :, 0])\n",
    "ax.imshow(superimposed_image, cmap='gray')\n",
    "ax.axis('on')\n",
    "plt.title('Superimposed Raw Test Scan & Prediction Mask (Automatically Segmented)')\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluation Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dice Similarity Coefficient (DSC) Metric Value:  [4.0989216e-06]\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def dice_coefficient(y_true, y_pred):\n",
    "    intersection = np.sum(y_true * y_pred, axis=(1, 2, 3))\n",
    "    union = np.sum(y_true, axis=(1, 2, 3)) + np.sum(y_pred, axis=(1, 2, 3))\n",
    "    dice = (2.0 * intersection) / (union + 1e-7)  # Adding a small epsilon to avoid division by zero\n",
    "    return dice\n",
    "\n",
    "def assd(y_true, y_pred, spacing):\n",
    "    surface_distances = surface_distance(y_true, y_pred, spacing)\n",
    "    avg_surface_distance = np.mean(surface_distances)\n",
    "    return avg_surface_distance\n",
    "\n",
    "def surface_distance(y_true, y_pred, spacing):\n",
    "    \"\"\"Nearest-neighbour distances between the surfaces of two masks, both ways.\n",
    "\n",
    "    For every surface point of ``y_true`` the distance to the closest surface\n",
    "    point of ``y_pred`` is computed, and vice versa; the two sets are returned\n",
    "    concatenated. Their mean is the average symmetric surface distance and\n",
    "    their maximum is the symmetric Hausdorff distance.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    y_true, y_pred : array-like\n",
    "        Masks whose surface points are extracted with ``find_surface_points``.\n",
    "    spacing : float or sequence of float\n",
    "        Voxel spacing forwarded to ``find_surface_points``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        1-D array with one distance per surface point of either mask.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If either mask yields no surface points.\n",
    "    \"\"\"\n",
    "    from scipy.spatial import cKDTree\n",
    "\n",
    "    true_surface = find_surface_points(y_true, spacing)\n",
    "    pred_surface = find_surface_points(y_pred, spacing)\n",
    "\n",
    "    if true_surface.shape[0] == 0 or pred_surface.shape[0] == 0:\n",
    "        raise ValueError(\"One or both surface point arrays are empty.\")\n",
    "\n",
    "    # directed_hausdorff returns only the single largest distance (a scalar, which\n",
    "    # np.concatenate rejects); an average needs the distance from *every* surface\n",
    "    # point to the opposite surface, which a k-d tree query gives directly.\n",
    "    distances_true_to_pred, _ = cKDTree(pred_surface).query(true_surface)\n",
    "    distances_pred_to_true, _ = cKDTree(true_surface).query(pred_surface)\n",
    "    return np.concatenate([distances_true_to_pred, distances_pred_to_true])\n",
    "\n",
    "def find_surface_points(mask, spacing):\n",
    "    \"\"\"Physical coordinates of the boundary voxels of a mask.\n",
    "\n",
    "    A voxel is a surface voxel when it is foreground (non-zero) and at least one\n",
    "    of its face-connected neighbours is background; voxels on the array border\n",
    "    count as touching background. Works for any dimensionality. Length-1 axes\n",
    "    (e.g. the batch/channel axes of a ``(1, H, W, 1)`` array) are dropped first\n",
    "    so they do not turn every foreground voxel into a border voxel.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    mask : array-like\n",
    "        Mask whose non-zero entries are foreground.\n",
    "    spacing : float or sequence of float\n",
    "        Physical voxel size: a scalar, or one value per axis that remains\n",
    "        after length-1 axes are squeezed out.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        ``(K, D)`` float array of surface coordinates (index * spacing);\n",
    "        ``K`` is 0 for an empty mask.\n",
    "    \"\"\"\n",
    "    from scipy import ndimage\n",
    "\n",
    "    foreground = np.squeeze(np.asarray(mask) != 0)\n",
    "    if foreground.ndim == 0:\n",
    "        foreground = foreground.reshape(1)\n",
    "\n",
    "    # Foreground minus its face-connected erosion = the one-voxel-thick inner\n",
    "    # boundary, along every axis (not just the first one).\n",
    "    structure = ndimage.generate_binary_structure(foreground.ndim, 1)\n",
    "    boundary = foreground & ~ndimage.binary_erosion(foreground, structure=structure)\n",
    "\n",
    "    return np.argwhere(boundary) * np.asarray(spacing, dtype=np.float64)\n",
    "\n",
    "def volume_error(y_true, y_pred):\n",
    "    \"\"\"Relative absolute volume error ``|V_pred - V_true| / V_true``.\n",
    "\n",
    "    Volumes are the sums of all entries of each mask (voxel counts for binary\n",
    "    masks), taken over the whole arrays.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        The relative error. When the reference mask is empty the ratio is\n",
    "        undefined: 0.0 is returned if the prediction is empty as well, and\n",
    "        ``inf`` otherwise.\n",
    "    \"\"\"\n",
    "    # Sum in float64: sums of unsigned masks (e.g. uint8) stay unsigned, so\n",
    "    # ``true - pred`` would wrap around to a huge value whenever pred > true.\n",
    "    true_volume = float(np.sum(y_true, dtype=np.float64))\n",
    "    pred_volume = float(np.sum(y_pred, dtype=np.float64))\n",
    "    if true_volume == 0:\n",
    "        return 0.0 if pred_volume == 0 else float(\"inf\")\n",
    "    return abs(true_volume - pred_volume) / true_volume\n",
    "\n",
    "# Reshape the binarized prediction into a one-sample batch with the same\n",
    "# per-sample shape as the reference masks (derived from the masks rather than\n",
    "# hard-coding the image size) so the two arrays line up in dice_coefficient.\n",
    "predictions_single_scan_binarized = np.reshape(\n",
    "    predictions_single_scan_binarized, (1,) + np.shape(testing_masks_processed)[1:]\n",
    ")\n",
    "DSC = dice_coefficient(testing_masks_processed, predictions_single_scan_binarized)\n",
    "print('Dice Similarity Coefficient (DSC) Metric Value: ', DSC)\n",
    "print('\\n')\n",
    "\n",
    "# NOTE(review): the saved run of this cell reported a DSC of ~4e-6, i.e. almost\n",
    "# no overlap -- worth checking the binarization threshold and that the mask and\n",
    "# the prediction refer to the same scan and orientation.\n",
    "\n",
    "# The metrics below use the same binarized prediction as the DSC above.\n",
    "# VError = volume_error(testing_masks_processed, predictions_single_scan_binarized)\n",
    "# print('Volume Error (VError) Metric Value: ', VError)\n",
    "# print('\\n')\n",
    "\n",
    "# spacing = 1\n",
    "# ASSD = assd(testing_masks_processed, predictions_single_scan_binarized, spacing)\n",
    "# print('Average Symmetric Surface Distance (ASSD) Metric Value: ', ASSD)\n",
    "# print('\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}