[2ed9d9]: / Question 4.ipynb

Download this file

1281 lines (1280 with data), 195.8 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "70258a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d9aa10f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d6f4d17b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def walk_through_dir(dir_path):\n",
    "    \"\"\"\n",
    "    Walks through dir_path returning its contents.\n",
    "    Args:\n",
    "        dir_path (str or pathlib.Path): target directory\n",
    "    \n",
    "    Returns:\n",
    "        A print out of:\n",
    "            number of subdirectories in dir_path\n",
    "            number of images (files) in each subdirectory\n",
    "            name of each subdirectory\n",
    "    \"\"\"\n",
    "    for dirpath,dirnames,filenames in os.walk(dir_path):\n",
    "        print(f\"There are {len(dirnames)} directories and {len(filenames)} images in {dirpath}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3b90146b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import zipfile\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "80dedd07",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = Path(r\"C:\\Users\\mohit\\Downloads\\DIP Assignment 2\\Data_Q4\")\n",
    "image_path = data_path / \"CXR_png\"\n",
    "masks_path = data_path / \"masks\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ffa0dce5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 0 directories and 801 images in C:\\Users\\mohit\\Downloads\\DIP Assignment 2\\Data_Q4\\CXR_png\n"
     ]
    }
   ],
   "source": [
    "walk_through_dir(image_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "30a03ec4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(WindowsPath('C:/Users/mohit/Downloads/DIP Assignment 2/Data_Q4/CXR_png/train'),\n",
       " WindowsPath('C:/Users/mohit/Downloads/DIP Assignment 2/Data_Q4/CXR_png/test'))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#setup train and testing paths\n",
    "train_dir = image_path / \"train\"\n",
    "test_dir = image_path / \"test\"\n",
    "\n",
    "train_dir, test_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ebb14896",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6434f0cf",
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "list index out of range",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[22], line 8\u001b[0m\n\u001b[0;32m      5\u001b[0m image_path_list \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(image_path\u001b[38;5;241m.\u001b[39mglob(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m*/*.png\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[0;32m      7\u001b[0m \u001b[38;5;66;03m#Get random image path\u001b[39;00m\n\u001b[1;32m----> 8\u001b[0m random_image_path \u001b[38;5;241m=\u001b[39m \u001b[43mrandom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mchoice\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimage_path_list\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     10\u001b[0m \u001b[38;5;66;03m#Get image class from path name (name of directory)\u001b[39;00m\n\u001b[0;32m     11\u001b[0m image_class \u001b[38;5;241m=\u001b[39m random_image_path\u001b[38;5;241m.\u001b[39mparent\u001b[38;5;241m.\u001b[39mstem\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\tf-new\\lib\\random.py:346\u001b[0m, in \u001b[0;36mRandom.choice\u001b[1;34m(self, seq)\u001b[0m\n\u001b[0;32m    344\u001b[0m \u001b[38;5;124;03m\"\"\"Choose a random element from a non-empty sequence.\"\"\"\u001b[39;00m\n\u001b[0;32m    345\u001b[0m \u001b[38;5;66;03m# raises IndexError if seq is empty\u001b[39;00m\n\u001b[1;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mseq\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_randbelow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mseq\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n",
      "\u001b[1;31mIndexError\u001b[0m: list index out of range"
     ]
    }
   ],
   "source": [
    "#set seed\n",
    "random.seed(42)\n",
    "\n",
    "#Get all image paths (* means 'any combination')\n",
    "image_path_list = list(image_path.glob(\"*/*.png\"))\n",
    "\n",
    "#Get random image path\n",
    "random_image_path = random.choice(image_path_list)\n",
    "\n",
    "#Get image class from path name (name of directory)\n",
    "image_class = random_image_path.parent.stem\n",
    "\n",
    "#Open image\n",
    "img = Image.open(random_image_path)\n",
    "\n",
    "#Print metadata\n",
    "print(f\"Random Image Path: {random_image_path}\")\n",
    "print(f\"Image class: {image_class}\")\n",
    "print(f\"Image height: {img.height}\")\n",
    "print(f\"Image width: {img.width}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "333ee09f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mohit\\anaconda3\\envs\\tf-new\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "#Transforming data into tensors, then into torch.utils.data.Dataset\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets,transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "00ff772c",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_transform = transforms.Compose([transforms.ToPILImage(),\n",
    "    #Resize the images to 572x572\n",
    "    transforms.Resize(size = (128,128)),\n",
    "    #turn the imahe into torch.Tensor\n",
    "    transforms.ToTensor() #this also converts all pixel values from 0 to 255 to be between 0.0 and 1.0\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c887a0fb",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'image_path_list' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[11], line 33\u001b[0m\n\u001b[0;32m     29\u001b[0m             ax[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39maxis(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124moff\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m     31\u001b[0m             fig\u001b[38;5;241m.\u001b[39msuptitle(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mClass: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mimage_path\u001b[38;5;241m.\u001b[39mparent\u001b[38;5;241m.\u001b[39mstem\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m , fontsize \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m16\u001b[39m)\n\u001b[1;32m---> 33\u001b[0m plot_transformed_images(\u001b[43mimage_path_list\u001b[49m,transform \u001b[38;5;241m=\u001b[39m data_transform, n\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m)\n",
      "\u001b[1;31mNameError\u001b[0m: name 'image_path_list' is not defined"
     ]
    }
   ],
   "source": [
    "def plot_transformed_images(image_paths, transform,n=3,seed=42):\n",
    "    \"\"\"Plots a series of random images from image_paths.\n",
    "    \n",
    "    Will open n image from image_paths, transform them with transform and\n",
    "    plot them side by side\n",
    "    \n",
    "    Args:\n",
    "        image_path (list): List of target image paths.\n",
    "        transform (PyTorch Transforms): Transforms to apply to images.\n",
    "        n (int,optional): Number of images to plot. Defaults to 3.\n",
    "        seed (int, optional): Random seed for the random generator. Defaults to 42\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    random_image_paths = random.sample(image_paths,k=n)\n",
    "    for image_path in random_image_paths:\n",
    "        with Image.open(image_path) as f:\n",
    "            fig, ax = plt.subplots(1,2)\n",
    "            ax[0].imshow(f)\n",
    "            ax[0].set_title(f\"Original \\nSize: {f.size}\")\n",
    "            ax[0].axis('off')\n",
    "            \n",
    "            #transform and plot the image\n",
    "            #note: permute changes shape of image to suit matplotlib\n",
    "            #Pytorch default is [C,H,W] but Matplotlib is [H,W,C]\n",
    "            tensor_image = transforms.ToTensor()(f)\n",
    "            transformed_image = transform(tensor_image).permute(1,2,0)\n",
    "            ax[1].imshow(transformed_image)\n",
    "            ax[1].set_title(f\"Transformed \\nSize: {transformed_image.shape}\")\n",
    "            ax[1].axis('off')\n",
    "            \n",
    "            fig.suptitle(f'Class: {image_path.parent.stem}' , fontsize = 16)\n",
    "            \n",
    "plot_transformed_images(image_path_list,transform = data_transform, n=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8649a13c",
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "Couldn't find any class folder in C:\\Users\\mohit\\Downloads\\DIP Assignment 2\\Q4_Data\\CXR_png\\train.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[17], line 3\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;66;03m#Using ImageFolder to create datasets\u001b[39;00m\n\u001b[0;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorchvision\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m datasets\n\u001b[1;32m----> 3\u001b[0m train_data \u001b[38;5;241m=\u001b[39m \u001b[43mdatasets\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mImageFolder\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtrain_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m      4\u001b[0m \u001b[43m                                 \u001b[49m\u001b[43mtransform\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mdata_transform\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m      5\u001b[0m \u001b[43m                                 \u001b[49m\u001b[43mtarget_transform\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m      6\u001b[0m test_data \u001b[38;5;241m=\u001b[39m datasets\u001b[38;5;241m.\u001b[39mImageFolder(root \u001b[38;5;241m=\u001b[39m test_dir,\n\u001b[0;32m      7\u001b[0m                                 transform \u001b[38;5;241m=\u001b[39m data_transform)\n\u001b[0;32m      8\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTrain data:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mtrain_data\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mTest data:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mtest_data\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\tf-new\\lib\\site-packages\\torchvision\\datasets\\folder.py:310\u001b[0m, in \u001b[0;36mImageFolder.__init__\u001b[1;34m(self, root, transform, target_transform, loader, is_valid_file)\u001b[0m\n\u001b[0;32m    302\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\n\u001b[0;32m    303\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m    304\u001b[0m     root: \u001b[38;5;28mstr\u001b[39m,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    308\u001b[0m     is_valid_file: Optional[Callable[[\u001b[38;5;28mstr\u001b[39m], \u001b[38;5;28mbool\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m    309\u001b[0m ):\n\u001b[1;32m--> 310\u001b[0m     \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[0;32m    311\u001b[0m \u001b[43m        \u001b[49m\u001b[43mroot\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    312\u001b[0m \u001b[43m        \u001b[49m\u001b[43mloader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    313\u001b[0m \u001b[43m        \u001b[49m\u001b[43mIMG_EXTENSIONS\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mis_valid_file\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m    314\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtransform\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtransform\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    315\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtarget_transform\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtarget_transform\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    316\u001b[0m \u001b[43m        \u001b[49m\u001b[43mis_valid_file\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_valid_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    317\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    318\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimgs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msamples\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\tf-new\\lib\\site-packages\\torchvision\\datasets\\folder.py:145\u001b[0m, in \u001b[0;36mDatasetFolder.__init__\u001b[1;34m(self, root, loader, extensions, transform, target_transform, is_valid_file)\u001b[0m\n\u001b[0;32m    135\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\n\u001b[0;32m    136\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m    137\u001b[0m     root: \u001b[38;5;28mstr\u001b[39m,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    142\u001b[0m     is_valid_file: Optional[Callable[[\u001b[38;5;28mstr\u001b[39m], \u001b[38;5;28mbool\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m    143\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    144\u001b[0m     \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(root, transform\u001b[38;5;241m=\u001b[39mtransform, target_transform\u001b[38;5;241m=\u001b[39mtarget_transform)\n\u001b[1;32m--> 145\u001b[0m     classes, class_to_idx \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfind_classes\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mroot\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    146\u001b[0m     samples \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmake_dataset(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mroot, class_to_idx, extensions, is_valid_file)\n\u001b[0;32m    148\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloader \u001b[38;5;241m=\u001b[39m loader\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\tf-new\\lib\\site-packages\\torchvision\\datasets\\folder.py:219\u001b[0m, in \u001b[0;36mDatasetFolder.find_classes\u001b[1;34m(self, directory)\u001b[0m\n\u001b[0;32m    192\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfind_classes\u001b[39m(\u001b[38;5;28mself\u001b[39m, directory: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[List[\u001b[38;5;28mstr\u001b[39m], Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mint\u001b[39m]]:\n\u001b[0;32m    193\u001b[0m     \u001b[38;5;124;03m\"\"\"Find the class folders in a dataset structured as follows::\u001b[39;00m\n\u001b[0;32m    194\u001b[0m \n\u001b[0;32m    195\u001b[0m \u001b[38;5;124;03m        directory/\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    217\u001b[0m \u001b[38;5;124;03m        (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.\u001b[39;00m\n\u001b[0;32m    218\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m--> 219\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfind_classes\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdirectory\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\tf-new\\lib\\site-packages\\torchvision\\datasets\\folder.py:43\u001b[0m, in \u001b[0;36mfind_classes\u001b[1;34m(directory)\u001b[0m\n\u001b[0;32m     41\u001b[0m classes \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msorted\u001b[39m(entry\u001b[38;5;241m.\u001b[39mname \u001b[38;5;28;01mfor\u001b[39;00m entry \u001b[38;5;129;01min\u001b[39;00m os\u001b[38;5;241m.\u001b[39mscandir(directory) \u001b[38;5;28;01mif\u001b[39;00m entry\u001b[38;5;241m.\u001b[39mis_dir())\n\u001b[0;32m     42\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m classes:\n\u001b[1;32m---> 43\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find any class folder in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdirectory\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     45\u001b[0m class_to_idx \u001b[38;5;241m=\u001b[39m {cls_name: i \u001b[38;5;28;01mfor\u001b[39;00m i, cls_name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(classes)}\n\u001b[0;32m     46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m classes, class_to_idx\n",
      "\u001b[1;31mFileNotFoundError\u001b[0m: Couldn't find any class folder in C:\\Users\\mohit\\Downloads\\DIP Assignment 2\\Q4_Data\\CXR_png\\train."
     ]
    }
   ],
   "source": [
    "#Using ImageFolder to create datasets\n",
    "from torchvision import datasets\n",
    "train_data = datasets.ImageFolder(root = train_dir,\n",
    "                                 transform = data_transform,\n",
    "                                 target_transform = None)\n",
    "test_data = datasets.ImageFolder(root = test_dir,\n",
    "                                transform = data_transform)\n",
    "print(f'Train data:\\n{train_data}\\nTest data:\\n{test_data}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "fc1cdf5b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['xray']"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data.classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8dfb8194",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'xray': 0}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data.class_to_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "039dc47c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#defining U-Net parameters\n",
    "NUM_CHANNELS = 1\n",
    "NUM_CLASSES = 1\n",
    "NUM_LEVELS = 3\n",
    "\n",
    "#initialize learning rate, number of epochs to train for and batch size\n",
    "INIT_LR = 0.001\n",
    "NUM_EPOCHS = 15\n",
    "BATCH_SIZE = 64\n",
    "\n",
    "#define the input image dimensions\n",
    "INPUT_IMAGE_WIDTH = 128\n",
    "INPUT_IMAGE_HEIGHT = 128\n",
    "\n",
    "#threshold to filter weak predictions\n",
    "THRESHOLD = 0.5\n",
    "\n",
    "#define path to the base output directory\n",
    "BASE_OUTPUT = \"output\"\n",
    "\n",
    "MODEL_PATH = os.path.join(BASE_OUTPUT,\"unet_xray.pth\")\n",
    "PLOT_PATH = os.path.sep.join([BASE_OUTPUT,\"plot.png\"])\n",
    "TEST_PATHS = os.path.sep.join([BASE_OUTPUT,\"test_paths.txt\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "92babcff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "class SegmentationDataset(Dataset):\n",
    "    def __init__(self,imagePaths,maskPaths,transforms):\n",
    "        self.imagePaths = imagePaths\n",
    "        self.maskPaths = maskPaths\n",
    "        self.transforms = transforms\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.imagePaths)\n",
    "    def __getitem__(self,idx):\n",
    "        imagePath = self.imagePaths[idx]\n",
    "        \n",
    "        image = cv2.imread(imagePath)\n",
    "        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)\n",
    "        mask = cv2.imread(self.maskPaths[idx],0)\n",
    "        \n",
    "        if self.transforms is not None:\n",
    "            image = self.transforms(image)\n",
    "            mask = self.transforms(mask)\n",
    "        \n",
    "        return (image,mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "40f1860b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import ConvTranspose2d\n",
    "from torch.nn import Conv2d\n",
    "from torch.nn import MaxPool2d\n",
    "from torch.nn import Module\n",
    "from torch.nn import ModuleList\n",
    "from torch.nn import ReLU\n",
    "from torchvision.transforms import CenterCrop\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "a7fe0b54",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "1d998ca5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block(Module):\n",
    "    def __init__(self,inChannels,outChannels):\n",
    "        super().__init__()\n",
    "        self.conv1 = Conv2d(inChannels,outChannels,3)\n",
    "        self.relu = ReLU()\n",
    "        self.conv2 = Conv2d(outChannels,outChannels,3)\n",
    "    \n",
    "    def forward(self,x):\n",
    "        return self.conv2(self.relu(self.conv1(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "339b67ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(Module):\n",
    "    def __init__(self,channels = (3,16,32,64)):\n",
    "        super().__init__()\n",
    "        self.encBlocks = ModuleList([Block(\n",
    "        channels[i],channels[i+1]) for i in range(len(channels)-1)])\n",
    "        self.pool = MaxPool2d(2)\n",
    "    \n",
    "    def forward(self,x):\n",
    "        blockOutputs = []\n",
    "        for block in self.encBlocks:\n",
    "            x = block(x)\n",
    "            blockOutputs.append(x)\n",
    "            x = self.pool(x)\n",
    "        return blockOutputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "088c79b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(Module):\n",
    "\tdef __init__(self, channels=(64, 32, 16)):\n",
    "\t\tsuper().__init__()\n",
    "\t\t# initialize the number of channels, upsampler blocks, and\n",
    "\t\t# decoder blocks\n",
    "\t\tself.channels = channels\n",
    "\t\tself.upconvs = ModuleList([ConvTranspose2d(channels[i], channels[i + 1], 2, 2) for i in range(len(channels) - 1)])\n",
    "\t\tself.dec_blocks = ModuleList([Block(channels[i], channels[i + 1]) for i in range(len(channels) - 1)])\n",
    "        \n",
    "\tdef forward(self, x, encFeatures):\n",
    "\t\t# loop through the number of channels\n",
    "\t\tfor i in range(len(self.channels) - 1):\n",
    "\t\t\t# pass the inputs through the upsampler blocks\n",
    "\t\t\tx = self.upconvs[i](x)\n",
    "\t\t\t# crop the current features from the encoder blocks,\n",
    "\t\t\t# concatenate them with the current upsampled features,\n",
    "\t\t\t# and pass the concatenated output through the current\n",
    "\t\t\t# decoder block\n",
    "\t\t\tencFeat = self.crop(encFeatures[i], x)\n",
    "\t\t\tx = torch.cat([x, encFeat], dim=1)\n",
    "\t\t\tx = self.dec_blocks[i](x)\n",
    "\t\t# return the final decoder output\n",
    "\t\treturn x\n",
    "    \n",
    "\tdef crop(self, encFeatures, x):\n",
    "\t\t# grab the dimensions of the inputs, and crop the encoder\n",
    "\t\t# features to match the dimensions\n",
    "\t\t(_, _, H, W) = x.shape\n",
    "\t\tencFeatures = CenterCrop([H, W])(encFeatures)\n",
    "\t\t# return the cropped features\n",
    "\t\treturn encFeatures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "61e3b314",
   "metadata": {},
   "outputs": [],
   "source": [
    "class UNet(Module):\n",
    "    def __init__(self,encChannels = (3,16,32,64),\n",
    "    decChannels = (64,32,16),\n",
    "    nbClasses =1,retainDim = True,\n",
    "    outSize = (INPUT_IMAGE_HEIGHT,INPUT_IMAGE_WIDTH)):\n",
    "        super().__init__()\n",
    "        self.encoder = Encoder(encChannels)\n",
    "        self.decoder = Decoder(decChannels)\n",
    "        \n",
    "        self.head = Conv2d(decChannels[-1],nbClasses,1)\n",
    "        self.retainDim = retainDim\n",
    "        self.outSize = outSize\n",
    "        \n",
    "    def forward(self,x):\n",
    "        encFeatures = self.encoder(x)\n",
    "        decFeatures = self.decoder(encFeatures[::-1][0],\n",
    "                                  encFeatures[::-1][1:])\n",
    "        map = self.head(decFeatures)\n",
    "        \n",
    "        if self.retainDim:\n",
    "            map = F.interpolate(map,self.outSize)\n",
    "            \n",
    "        return map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "4aa75c19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import BCEWithLogitsLoss\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torchvision import transforms\n",
    "from imutils import paths\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import time\n",
    "import os\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "96d60fea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] saving testing image paths\n"
     ]
    }
   ],
   "source": [
    "imagePaths = sorted(list(paths.list_images(image_path)))\n",
    "maskPaths = sorted(list(paths.list_images(masks_path)))\n",
    "\n",
    "split = train_test_split(imagePaths,maskPaths,test_size = 0.25,random_state =42)\n",
    "\n",
    "(trainImages,testImages) = split[:2]\n",
    "(trainMasks,testMasks) = split[2:]\n",
    "\n",
    "print(\"[INFO] saving testing image paths\")\n",
    "f = open(test_dir,\"w\")\n",
    "f.write(\"\\n\".join(testImages))\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "3ca38b1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainDS = SegmentationDataset(imagePaths = trainImages, \n",
    "                              maskPaths = trainMasks, \n",
    "                              transforms = data_transform)\n",
    "testDS = SegmentationDataset(imagePaths = testImages,\n",
    "                              maskPaths = testMasks,\n",
    "                              transforms = data_transform)\n",
    "\n",
    "trainLoader = DataLoader(trainDS, shuffle = True, \n",
    "                        batch_size = BATCH_SIZE)\n",
    "testLoader = DataLoader(testDS, shuffle = False, batch_size = BATCH_SIZE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "178a8935",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(trainDS[0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "3712b76e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 128, 128])"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainDS[0][0].size()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb098ca5",
   "metadata": {},
   "source": [
    "trainDS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "07a9f37d",
   "metadata": {},
   "outputs": [],
   "source": [
    "unet = UNet()\n",
    "lossfn = BCEWithLogitsLoss()\n",
    "optimizer  = Adam(unet.parameters(),lr = INIT_LR)\n",
    "\n",
    "trainSteps = len(trainDS) // BATCH_SIZE\n",
    "testSteps = len(testDS) // BATCH_SIZE\n",
    "\n",
    "H = {\"train_loss\" :[], \"test_loss\" :[]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "5c9f168a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|█████▍                                                                            | 1/15 [02:29<34:56, 149.75s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 1/15\n",
      "Train loss: 0.734736,Test loss: 0.7416\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 13%|██████████▉                                                                       | 2/15 [05:03<32:57, 152.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 2/15\n",
      "Train loss: 0.606034,Test loss: 0.6996\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 20%|████████████████▍                                                                 | 3/15 [07:37<30:38, 153.19s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 3/15\n",
      "Train loss: 0.593045,Test loss: 0.6943\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 27%|█████████████████████▊                                                            | 4/15 [10:13<28:14, 154.01s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 4/15\n",
      "Train loss: 0.584120,Test loss: 0.6829\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 33%|███████████████████████████▎                                                      | 5/15 [12:46<25:37, 153.71s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 5/15\n",
      "Train loss: 0.577546,Test loss: 0.6734\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 40%|████████████████████████████████▊                                                 | 6/15 [15:20<23:03, 153.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 6/15\n",
      "Train loss: 0.564767,Test loss: 0.6624\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 47%|██████████████████████████████████████▎                                           | 7/15 [17:55<20:32, 154.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 7/15\n",
      "Train loss: 0.558556,Test loss: 0.6485\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 53%|███████████████████████████████████████████▋                                      | 8/15 [20:32<18:06, 155.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 8/15\n",
      "Train loss: 0.545131,Test loss: 0.6258\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 60%|█████████████████████████████████████████████████▏                                | 9/15 [23:11<15:38, 156.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 9/15\n",
      "Train loss: 0.531916,Test loss: 0.6104\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 67%|██████████████████████████████████████████████████████                           | 10/15 [25:49<13:04, 156.92s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 10/15\n",
      "Train loss: 0.532720,Test loss: 0.6120\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 73%|███████████████████████████████████████████████████████████▍                     | 11/15 [28:27<10:28, 157.19s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 11/15\n",
      "Train loss: 0.511205,Test loss: 0.5941\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 80%|████████████████████████████████████████████████████████████████▊                | 12/15 [30:59<07:46, 155.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 12/15\n",
      "Train loss: 0.502504,Test loss: 0.5841\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 87%|██████████████████████████████████████████████████████████████████████▏          | 13/15 [33:32<05:09, 154.74s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 13/15\n",
      "Train loss: 0.479238,Test loss: 0.5476\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 93%|███████████████████████████████████████████████████████████████████████████▌     | 14/15 [36:05<02:34, 154.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 14/15\n",
      "Train loss: 0.468646,Test loss: 0.5330\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 15/15 [38:44<00:00, 154.98s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH: 15/15\n",
      "Train loss: 0.458341,Test loss: 0.5264\n",
      "[INFO] total time taken to train the model : {:.2f}s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'NoneType' object has no attribute 'format'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[41], line 35\u001b[0m\n\u001b[0;32m     32\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTrain loss: \u001b[39m\u001b[38;5;132;01m{:,.6f}\u001b[39;00m\u001b[38;5;124m,Test loss: \u001b[39m\u001b[38;5;132;01m{:.4f}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(avgTrainLoss,avgTestLoss))\n\u001b[0;32m     34\u001b[0m endTime \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m---> 35\u001b[0m \u001b[38;5;28;43mprint\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m[INFO] total time taken to train the model : \u001b[39;49m\u001b[38;5;132;43;01m{:.2f}\u001b[39;49;00m\u001b[38;5;124;43ms\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mformat\u001b[49m(endTime \u001b[38;5;241m-\u001b[39m startTime)\n",
      "\u001b[1;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'format'"
     ]
    }
   ],
   "source": [
    "startTime = time.time()\n",
    "for e in tqdm(range(NUM_EPOCHS)):\n",
    "    unet.train()\n",
    "    totalTrainLoss = 0\n",
    "    totalTestLoss = 0\n",
    "    \n",
    "    for (i,(x,y)) in enumerate(trainLoader):\n",
    "        pred = unet(x)\n",
    "        loss = lossfn(pred,y)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        totalTrainLoss += loss\n",
    "        \n",
    "    with torch.inference_mode():\n",
    "        unet.eval()\n",
    "        \n",
    "        for(x,y) in testLoader:\n",
    "            pred = unet(x)\n",
    "            \n",
    "            totalTestLoss += lossfn(pred,y)\n",
    "        \n",
    "    avgTrainLoss = totalTrainLoss / trainSteps\n",
    "    avgTestLoss = totalTestLoss / testSteps\n",
    "    \n",
    "    H[\"train_loss\"].append(avgTrainLoss.detach().numpy())\n",
    "    H[\"test_loss\"].append(avgTestLoss.detach().numpy())\n",
    "    \n",
    "    print(\"EPOCH: {}/{}\".format(e+1,NUM_EPOCHS))\n",
    "    print(\"Train loss: {:,.6f},Test loss: {:.4f}\".format(avgTrainLoss,avgTestLoss))\n",
    "    \n",
    "endTime = time.time()\n",
    "print(\"[INFO] total time taken to train the model : {:.2f}s\").format(endTime - startTime)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "1c85af3b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "38.74571882088979"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(endTime - startTime)/60 #minutes taken to train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "a649efe9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x24905b7c790>"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.style.use(\"ggplot\")\n",
    "plt.figure()\n",
    "plt.plot(H[\"train_loss\"],label=\"train_loss\")\n",
    "plt.plot(H[\"test_loss\"],label = \"test_loss\")\n",
    "plt.title(\"Training Loss on Dataset\")\n",
    "plt.xlabel(\"Epoch #\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.legend(loc=\"lower left\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "88eb3a94",
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'output\\\\unet_xray.pth'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[44], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[43munet\u001b[49m\u001b[43m,\u001b[49m\u001b[43mMODEL_PATH\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\tf-new\\lib\\site-packages\\torch\\serialization.py:376\u001b[0m, in \u001b[0;36msave\u001b[1;34m(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)\u001b[0m\n\u001b[0;32m    340\u001b[0m \u001b[38;5;124;03m\"\"\"save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)\u001b[39;00m\n\u001b[0;32m    341\u001b[0m \n\u001b[0;32m    342\u001b[0m \u001b[38;5;124;03mSaves an object to a disk file.\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    372\u001b[0m \u001b[38;5;124;03m    >>> torch.save(x, buffer)\u001b[39;00m\n\u001b[0;32m    373\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m    374\u001b[0m _check_dill_version(pickle_module)\n\u001b[1;32m--> 376\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_file_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mwb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_file:\n\u001b[0;32m    377\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m _use_new_zipfile_serialization:\n\u001b[0;32m    378\u001b[0m         \u001b[38;5;28;01mwith\u001b[39;00m _open_zipfile_writer(opened_file) \u001b[38;5;28;01mas\u001b[39;00m opened_zipfile:\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\tf-new\\lib\\site-packages\\torch\\serialization.py:230\u001b[0m, in \u001b[0;36m_open_file_like\u001b[1;34m(name_or_buffer, mode)\u001b[0m\n\u001b[0;32m    228\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[0;32m    229\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[1;32m--> 230\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_open_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    231\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m    232\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m mode:\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\tf-new\\lib\\site-packages\\torch\\serialization.py:211\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[1;34m(self, name, mode)\u001b[0m\n\u001b[0;32m    210\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, mode):\n\u001b[1;32m--> 211\u001b[0m     \u001b[38;5;28msuper\u001b[39m(_open_file, \u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m)\n",
      "\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'output\\\\unet_xray.pth'"
     ]
    }
   ],
   "source": [
    "torch.save(unet,MODEL_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0351e554",
   "metadata": {},
   "source": [
    "Making Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "61c476b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "b8d75951",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_plot(origImage,predMask):\n",
    "    figure, ax = plt.subplots(nrows = 1, ncols = 2, figsize = (10,10))\n",
    "    \n",
    "    ax[0].imshow(origImage, cmap='gray')\n",
    "    #ax[1].imshow(origMask)\n",
    "    ax[1].imshow(predMask,cmap='gray')\n",
    "    \n",
    "    ax[0].set_title(\"Image\")\n",
    "    #ax[1].set_title(\"Original Mask\")\n",
    "    ax[1].set_title(\"Predicted Mask\")\n",
    "    \n",
    "    figure.tight_layout()\n",
    "    figure.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "bb5b8a84",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(128, 128) (128, 128)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mohit\\AppData\\Local\\Temp\\ipykernel_37004\\1921240287.py:13: UserWarning: Matplotlib is currently using module://matplotlib_inline.backend_inline, which is a non-GUI backend, so cannot show the figure.\n",
      "  figure.show()\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 1000x1000 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "unet.eval()\n",
    "    \n",
    "with torch.inference_mode():\n",
    "    image = cv2.imread(r\"C:\\Users\\mohit\\Downloads\\DIP Assignment 2\\Data_Q4\\CXR_png\\CHNCXR_0035_0.png\")\n",
    "    image_real = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)\n",
    "    image_real = cv2.resize(image_real,(128,128))\n",
    "    image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)\n",
    "    image = image.astype(\"float32\")/255.0\n",
    "    image = cv2.resize(image,(128,128))\n",
    "        \n",
    "    orig = image.copy()\n",
    "        \n",
    "       # filename = imagePath.split(os.path.sep)[-1]\n",
    "        #groundTruthPath = os.path.join(masks_path,filename)\n",
    "        \n",
    "        #gtMask = cv2.imread(groundTruthPath,0)\n",
    "        #gtMask = cv2.resize(gtMask,(INPUT_IMAGE_HEIGHT,INPUT_IMAGE_WIDTH))\n",
    "        \n",
    "    image = np.transpose(image,(2,0,1)) #to get it to [C,H,W] format\n",
    "    image = np.expand_dims(image,0) #model inputs 4 dimensional tensor\n",
    "    image = torch.tensor(image)\n",
    "        \n",
    "    predMask = unet(image).squeeze() #to remove extra dimension\n",
    "    predMask = torch.sigmoid(predMask)\n",
    "    predMask = predMask.numpy()\n",
    "        \n",
    "    predMask = (predMask>THRESHOLD) *255\n",
    "    predMask = predMask.astype(np.uint8)\n",
    "        \n",
    "#     prepare_plot(orig,predMask)\n",
    "    print(image_real.shape, predMask.shape)\n",
    "    for i in range(len(predMask)):\n",
    "        for j in range(len(predMask[0])):\n",
    "            if predMask[i][j] == 255:\n",
    "                image_real[i][j] = 255\n",
    "\n",
    "    prepare_plot(image_real,predMask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "0088d784",
   "metadata": {},
   "outputs": [],
   "source": [
    "image2 = cv2.imread(r\"C:\\Users\\mohit\\Downloads\\DIP Assignment 2\\Data_Q4\\CXR_png\\CHNCXR_0035_0.png\")\n",
    "image_real = cv2.cvtColor(image2,cv2.COLOR_BGR2GRAY)\n",
    "image_real = cv2.resize(image_real,(128,128))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "65d46401",
   "metadata": {},
   "outputs": [],
   "source": [
    "ROI = np.multiply(predMask,image_real)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "id": "c60edfc7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 0, 0, ..., 0, 0, 0],\n",
       "       [0, 0, 0, ..., 0, 0, 0],\n",
       "       [0, 0, 0, ..., 0, 0, 0],\n",
       "       ...,\n",
       "       [0, 0, 0, ..., 0, 0, 0],\n",
       "       [0, 0, 0, ..., 0, 0, 0],\n",
       "       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)"
      ]
     },
     "execution_count": 121,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ROI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "id": "7839d01a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-0.5, 127.5, 127.5, -0.5)"
      ]
     },
     "execution_count": 122,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(ROI, cmap='gray')\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50bd688",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}