2341 lines (2341 with data), 349.5 kB
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"toc_visible": true,
"authorship_tag": "ABX9TyMkI8ahXOwRr6x9BhyBLJgu",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/stonerrb/MisaHub/blob/main/MisaHubResNet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# **Classifying Bleeding and Non-Bleeding Using Resnet**"
],
"metadata": {
"id": "4PkpgiAcFLLx"
}
},
{
"cell_type": "markdown",
"source": [
"# **Importing and Splitting**\n"
],
"metadata": {
"id": "r89LWZP2g9GJ"
}
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 338
},
"id": "VoJ-HvHoeo0Q",
"outputId": "fb332a65-3913-47bb-bc31-16e5a91c0cd8"
},
"execution_count": null,
"outputs": [
{
"output_type": "error",
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-5-d5df0069828e>\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgoogle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolab\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdrive\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mdrive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmount\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/content/drive'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/google/colab/drive.py\u001b[0m in \u001b[0;36mmount\u001b[0;34m(mountpoint, force_remount, timeout_ms, readonly)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmount\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmountpoint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mforce_remount\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout_ms\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m120000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreadonly\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;34m\"\"\"Mount your Google Drive at the specified mountpoint path.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m return _mount(\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0mmountpoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mforce_remount\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mforce_remount\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/google/colab/drive.py\u001b[0m in \u001b[0;36m_mount\u001b[0;34m(mountpoint, force_remount, timeout_ms, ephemeral, readonly)\u001b[0m\n\u001b[1;32m 130\u001b[0m )\n\u001b[1;32m 131\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mephemeral\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m _message.blocking_request(\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0;34m'request_auth'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequest\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m'authType'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'dfs_ephemeral'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout_sec\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m )\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/google/colab/_message.py\u001b[0m in \u001b[0;36mblocking_request\u001b[0;34m(request_type, request, timeout_sec, parent)\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0mrequest_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequest\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpect_reply\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 175\u001b[0m )\n\u001b[0;32m--> 176\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mread_reply_from_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrequest_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout_sec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/google/colab/_message.py\u001b[0m in \u001b[0;36mread_reply_from_input\u001b[0;34m(message_id, timeout_sec)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0mreply\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_read_next_input_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreply\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_NOT_READY\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreply\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msleep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.025\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m if (\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
{
"cell_type": "code",
"source": [
"zip_path = '/content/drive/MyDrive/WCEBleedGen.zip'\n",
"!unzip $zip_path -d MISAHUB"
],
"metadata": {
"id": "f4wAqXaHepk7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"import shutil\n",
"\n",
"images_folder = '/content/MISAHUB/WCEBleedGen/bleeding/Images'\n",
"annotations_folder = '/content/MISAHUB/WCEBleedGen/bleeding/Bounding boxes/YOLO_TXT'\n",
"output_folder = '/content/MISAHUB/train/'\n",
"\n",
"# Create the 'train' folder if it doesn't exist\n",
"if not os.path.exists(output_folder):\n",
" os.makedirs(output_folder)\n",
"\n",
"# Create 'images' and 'labels' folders inside 'train'\n",
"images_output_folder = os.path.join(output_folder, 'images')\n",
"labels_output_folder = os.path.join(output_folder, 'labels')\n",
"\n",
"os.makedirs(images_output_folder, exist_ok=True)\n",
"os.makedirs(labels_output_folder, exist_ok=True)\n",
"\n",
"# Get a list of files in both folders\n",
"image_files = os.listdir(images_folder)\n",
"annotation_files = os.listdir(annotations_folder)\n",
"\n",
"# Ensure only files with the same base name are considered\n",
"image_files = [file for file in image_files if file.endswith('.png')]\n",
"annotation_files = [file for file in annotation_files if file.endswith('.txt')]\n",
"\n",
"# Sort the files to ensure consistent numbering\n",
"image_files.sort()\n",
"annotation_files.sort()\n",
"\n",
"# Rename and move the files to the 'train' folder\n",
"for i, (image_file, annotation_file) in enumerate(zip(image_files, annotation_files), start=1):\n",
" # Define the new names for the image and annotation\n",
" new_image_name = f\"image_{i}.png\"\n",
" new_annotation_name = f\"image_{i}.txt\"\n",
"\n",
" # Rename and move the image file\n",
" old_image_path = os.path.join(images_folder, image_file)\n",
" new_image_path = os.path.join(images_output_folder, new_image_name)\n",
" os.rename(old_image_path, new_image_path)\n",
"\n",
" # Rename and move the annotation file\n",
" old_annotation_path = os.path.join(annotations_folder, annotation_file)\n",
" new_annotation_path = os.path.join(labels_output_folder, new_annotation_name)\n",
" os.rename(old_annotation_path, new_annotation_path)\n",
"\n",
" print(f\"Renamed and moved: {image_file} to {new_image_name}\")\n",
" print(f\"Renamed and moved: {annotation_file} to {new_annotation_name}\")\n",
"\n",
"print(\"All files have been renamed and moved to the 'train' folder.\")\n"
],
"metadata": {
"id": "Wc8Sl_fUfP1y"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"import shutil\n",
"\n",
"images_folder = '/content/MISAHUB/WCEBleedGen/non-bleeding/images'\n",
"output_folder = '/content/MISAHUB/train/'\n",
"\n",
"# Create the 'train' folder if it doesn't exist\n",
"if not os.path.exists(output_folder):\n",
" os.makedirs(output_folder)\n",
"\n",
"# Create 'images' and 'labels' folders inside 'train'\n",
"images_output_folder = os.path.join(output_folder, 'images')\n",
"labels_output_folder = os.path.join(output_folder, 'labels')\n",
"\n",
"os.makedirs(images_output_folder, exist_ok=True)\n",
"os.makedirs(labels_output_folder, exist_ok=True)\n",
"\n",
"# Get a list of files in the 'images' folder\n",
"image_files = os.listdir(images_folder)\n",
"image_files = [file for file in image_files if file.endswith('.png')]\n",
"image_files.sort()\n",
"\n",
"# Rename and move the files to the 'train' folder\n",
"for i, image_file in enumerate(image_files, start=1310):\n",
" # Define the new names for the image and annotation\n",
" new_image_name = f\"images_{i}.png\"\n",
" new_annotation_name = f\"images_{i}.txt\"\n",
"\n",
" # Rename and move the image file\n",
" old_image_path = os.path.join(images_folder, image_file)\n",
" new_image_path = os.path.join(images_output_folder, new_image_name)\n",
" os.rename(old_image_path, new_image_path)\n",
"\n",
" # Create an empty label file in the 'labels' folder\n",
" new_annotation_path = os.path.join(labels_output_folder, new_annotation_name)\n",
" with open(new_annotation_path, 'w') as empty_file:\n",
" pass\n",
"\n",
" print(f\"Renamed and moved: {image_file} to {new_image_name}\")\n",
" print(f\"Created empty label file: {new_annotation_name}\")\n",
"\n",
"print(\"All files have been renamed and moved to the 'train' folder.\")"
],
"metadata": {
"id": "ukgpuDXAVy3j"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"! ls MISAHUB/newTrain/bleeding | wc -l"
],
"metadata": {
"id": "UavS7_rSVe87"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"import shutil\n",
"\n",
"# Define the paths to the 'train' directory and its subdirectories\n",
"train_dir = 'MISAHUB/train' # Update this with the actual path to your 'train' directory\n",
"images_dir = os.path.join(train_dir, 'images')\n",
"labels_dir = os.path.join(train_dir, 'labels')\n",
"train_dir = 'MISAHUB/newTrain'\n",
"# Create output directories for the two classes\n",
"bleeding_dir = os.path.join(train_dir, 'bleeding')\n",
"non_bleeding_dir = os.path.join(train_dir, 'non_bleeding')\n",
"\n",
"# Create the 'bleeding' and 'non_bleeding' directories if they don't exist\n",
"os.makedirs(bleeding_dir, exist_ok=True)\n",
"os.makedirs(non_bleeding_dir, exist_ok=True)\n",
"\n",
"# Iterate through the files in the 'images' directory\n",
"for image_filename in os.listdir(images_dir):\n",
" # Form the corresponding label file path\n",
" label_filename = os.path.splitext(image_filename)[0] + '.txt'\n",
" label_filepath = os.path.join(labels_dir, label_filename)\n",
"\n",
" # Check if the label file is empty (indicating non-bleeding)\n",
" if os.path.exists(label_filepath) and os.path.getsize(label_filepath) == 0:\n",
" # Move the image to the 'non_bleeding' directory\n",
" shutil.move(os.path.join(images_dir, image_filename), os.path.join(non_bleeding_dir, image_filename))\n",
" else:\n",
" # Move the image to the 'bleeding' directory\n",
" shutil.move(os.path.join(images_dir, image_filename), os.path.join(bleeding_dir, image_filename))\n",
"\n",
"print(\"Dataset organized into 'bleeding' and 'non_bleeding' classes.\")\n"
],
"metadata": {
"id": "mm2rQtsDWwP_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"import random\n",
"import shutil\n",
"\n",
"# Define the paths to the 'bleeding' and 'non_bleeding' directories\n",
"bleeding_dir = '/content/MISAHUB/newTrain/bleeding' # Update with the actual path to your 'bleeding' directory\n",
"non_bleeding_dir = '/content/MISAHUB/newTrain/non_bleeding' # Update with the actual path to your 'non-bleeding' directory\n",
"\n",
"# Define the path to the 'test' directory\n",
"test_dir = '/content/MISAHUB/test' # Update with the path where you want to create the 'test' directory\n",
"\n",
"# Create the 'test' directory if it doesn't exist\n",
"os.makedirs(test_dir, exist_ok=True)\n",
"\n",
"# Define the number of files to randomly select from each directory\n",
"num_files_to_select = 20\n",
"\n",
"# Randomly select 20 files from the 'bleeding' directory\n",
"bleeding_files = os.listdir(bleeding_dir)\n",
"selected_bleeding_files = random.sample(bleeding_files, num_files_to_select)\n",
"\n",
"# Randomly select 20 files from the 'non_bleeding' directory\n",
"non_bleeding_files = os.listdir(non_bleeding_dir)\n",
"selected_non_bleeding_files = random.sample(non_bleeding_files, num_files_to_select)\n",
"\n",
"# Copy the selected 'bleeding' files to the 'test' directory with new names\n",
"for filename in selected_bleeding_files:\n",
" source_filepath = os.path.join(bleeding_dir, filename)\n",
" new_filename = f'bleeding_{filename}'\n",
" dest_filepath = os.path.join(test_dir, new_filename)\n",
" shutil.copy(source_filepath, dest_filepath)\n",
" # Remove the file from the 'bleeding' directory\n",
" os.remove(source_filepath)\n",
"\n",
"# Copy the selected 'non_bleeding' files to the 'test' directory with new names\n",
"for filename in selected_non_bleeding_files:\n",
" source_filepath = os.path.join(non_bleeding_dir, filename)\n",
" new_filename = f'non_bleeding_{filename}'\n",
" dest_filepath = os.path.join(test_dir, new_filename)\n",
" shutil.copy(source_filepath, dest_filepath)\n",
" # Remove the file from the 'non_bleeding' directory\n",
" os.remove(source_filepath)\n",
"\n",
"print(f\"Selected and copied {num_files_to_select} files from 'bleeding' and 'non-bleeding' to the 'test' directory, and removed them from the original directories.\")\n"
],
"metadata": {
"id": "rq7U71RIWwRI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import PIL\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"from tensorflow.python.keras.layers import Dense, Flatten, Dropout\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.optimizers import Adam"
],
"metadata": {
"id": "0gSmVvwUfp2u"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from pathlib import Path\n",
"dataset_dir = 'MISAHUB/newTrain'\n",
"dataset_dir = Path(dataset_dir)\n",
"bleeding = list(dataset_dir.glob('bleeding/*'))\n",
"print(bleeding[0])\n",
"image = PIL.Image.open(str(bleeding[0]))\n",
"width, height = image.size\n",
"print(width, height)"
],
"metadata": {
"id": "M6N5_xdqqz3V"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from tensorflow.keras.preprocessing import image_dataset_from_directory\n",
"\n",
"# Define the path to the 'train' directory\n",
"train_dir = 'MISAHUB/newTrain'\n",
"img_height,img_width=224,224\n",
"batch_size=32\n",
"train_dataset = image_dataset_from_directory(\n",
" train_dir,\n",
" labels='inferred',\n",
" label_mode='binary',\n",
" batch_size=32, # Adjust batch size as needed\n",
" image_size=(224, 224), # Specify the desired image size\n",
" shuffle=True,\n",
" seed=123,\n",
" validation_split=0.2, # Split a portion of data for validation\n",
" subset='training' # Use 'training' subset for training data\n",
")\n",
"\n",
"# Create the validation dataset with the same validation split\n",
"validation_dataset = image_dataset_from_directory(\n",
" train_dir,\n",
" labels='inferred',\n",
" label_mode='binary',\n",
" batch_size=32, # Adjust batch size as needed\n",
" image_size=(224, 224), # Specify the desired image size\n",
" shuffle=True,\n",
" seed=123,\n",
" validation_split=0.2, # Use the same validation split as for training\n",
" subset='validation' # Use 'validation' subset for validation data\n",
")"
],
"metadata": {
"id": "IitcVWIWqrzv"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class_names = train_dataset.class_names\n",
"print(class_names)"
],
"metadata": {
"id": "hz3H4ykAsZLv"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Classification ResNet**"
],
"metadata": {
"id": "AewhITJMgvWs"
}
},
{
"cell_type": "code",
"source": [
"from tensorflow.keras.regularizers import l2\n",
"resnet_model = Sequential()\n",
"\n",
"pretrained_model= tf.keras.applications.ResNet50(include_top=False,\n",
" input_shape=(224,224,3),\n",
" pooling='max',classes=2,\n",
" weights='imagenet')\n",
"# for layer in pretrained_model.layers:\n",
"# layer.trainable=False\n",
"# Unfreeze specific layers\n",
"for layer in pretrained_model.layers:\n",
" if layer.name in ['block4_conv1', 'block4_conv2', 'block4_conv3']:\n",
" layer.trainable = True\n",
" else:\n",
" layer.trainable = False\n",
"\n",
"# Compile the model after unfreezing layers\n",
"# resnet_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
"\n",
"\n",
"resnet_model.add(pretrained_model)\n",
"resnet_model.add(Flatten())\n",
"resnet_model.add(Dense(512, activation='relu', kernel_regularizer=l2(0.02)))\n",
"resnet_model.add(Dropout(0.5))\n",
"resnet_model.add(Dense(1, activation='sigmoid',kernel_regularizer=l2(0.02)))"
],
"metadata": {
"id": "PiWWXwTa0hAK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"resnet_model.summary()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ocxlltpa09LS",
"outputId": "9254550c-38da-4539-d760-37d1f2ba62ef"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"sequential_4\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" resnet50 (Functional) (None, 2048) 23587712 \n",
" \n",
" module_wrapper_9 (ModuleWr (None, 2048) 0 \n",
" apper) \n",
" \n",
" module_wrapper_10 (ModuleW (None, 512) 1049088 \n",
" rapper) \n",
" \n",
" module_wrapper_11 (ModuleW (None, 512) 0 \n",
" rapper) \n",
" \n",
" module_wrapper_12 (ModuleW (None, 1) 513 \n",
" rapper) \n",
" \n",
"=================================================================\n",
"Total params: 24637313 (93.98 MB)\n",
"Trainable params: 1049601 (4.00 MB)\n",
"Non-trainable params: 23587712 (89.98 MB)\n",
"_________________________________________________________________\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"resnet_model.compile(optimizer=Adam(learning_rate=0.001),loss='binary_crossentropy',metrics=['accuracy'])"
],
"metadata": {
"id": "EKycvVcE1A7M"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"epochs=5\n",
"history = resnet_model.fit(\n",
" train_dataset,\n",
" validation_data=validation_dataset,\n",
" epochs=epochs\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8ZU2ADPF1Oic",
"outputId": "96b6e10e-c108-4e99-9529-6cc0c91b136a"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/5\n",
"65/65 [==============================] - 14s 152ms/step - loss: 3.4739 - accuracy: 0.7551 - val_loss: 0.1862 - val_accuracy: 0.9204\n",
"Epoch 2/5\n",
"65/65 [==============================] - 9s 138ms/step - loss: 0.1884 - accuracy: 0.9180 - val_loss: 0.1198 - val_accuracy: 0.9515\n",
"Epoch 3/5\n",
"65/65 [==============================] - 9s 130ms/step - loss: 0.1345 - accuracy: 0.9467 - val_loss: 0.0959 - val_accuracy: 0.9689\n",
"Epoch 4/5\n",
"65/65 [==============================] - 9s 133ms/step - loss: 0.0903 - accuracy: 0.9656 - val_loss: 0.0777 - val_accuracy: 0.9689\n",
"Epoch 5/5\n",
"65/65 [==============================] - 9s 138ms/step - loss: 0.0717 - accuracy: 0.9762 - val_loss: 0.1072 - val_accuracy: 0.9534\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"\n",
"# Define the path to the 'test' directory\n",
"test_dir = 'MISAHUB/test' # Update with the actual path to your 'test' directory\n",
"\n",
"# Get a list of all image files in the 'test' directory\n",
"test_image_files = [os.path.join(test_dir, filename) for filename in os.listdir(test_dir)]\n",
"\n",
"# Iterate through the 'test' images\n",
"for image_file in test_image_files:\n",
" # Load and preprocess the image\n",
" img = tf.keras.preprocessing.image.load_img(image_file, target_size=(224, 224))\n",
" img = tf.keras.preprocessing.image.img_to_array(img)\n",
" img = tf.expand_dims(img, axis=0)\n",
"\n",
" # Make predictions using the model\n",
" predictions = resnet_model.predict(img)\n",
"\n",
" # Get the class label (assuming binary classification with 0 and 1)\n",
" class_label = \"bleeding\" if predictions[0][0] > 0.5 else \"non-bleeding\"\n",
"\n",
" # Print the file name and predicted class\n",
" print(f\"File: {os.path.basename(image_file)}, Predicted Class: {class_label}\")\n"
],
"metadata": {
"id": "2oOoHdvTboRC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "yYm_a_cHOqc1"
}
},
{
"cell_type": "markdown",
"source": [
"# VGG"
],
"metadata": {
"id": "vQ6Cl2SgOqd1"
}
},
{
"cell_type": "code",
"source": [
"vgg = Sequential()\n",
"\n",
"pretrained = tf.keras.applications.vgg16.VGG16(\n",
" include_top=False,\n",
" weights='imagenet',\n",
" # input_tensor=None,\n",
" input_shape=(224,224,3),\n",
" pooling='max',\n",
" classes=2,\n",
" classifier_activation='softmax'\n",
")"
],
"metadata": {
"id": "FXbtBxYLOsHS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Evaluating Model**"
],
"metadata": {
"id": "eSg6Sj0Y3RyR"
}
},
{
"cell_type": "code",
"source": [
"fig1 = plt.gcf()\n",
"plt.plot(history.history['accuracy'])\n",
"plt.plot(history.history['val_accuracy'])\n",
"plt.axis(ymin=0.4,ymax=1)\n",
"plt.grid()\n",
"plt.title('Model Accuracy')\n",
"plt.ylabel('Accuracy')\n",
"plt.xlabel('Epochs')\n",
"plt.legend(['train', 'validation'])\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 472
},
"id": "NwjguSD73RMv",
"outputId": "b8dff6ab-9ee4-4205-b7b3-a6cfdd3ad9fa"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# **Making Prediction**"
],
"metadata": {
"id": "z2O8aVxV3aUp"
}
},
{
"cell_type": "code",
"source": [
"import cv2\n",
"\n",
"dataset_dir = Path(dataset_dir)\n",
"nonbleeding = list(dataset_dir.glob('non-bleeding/*'))\n",
"for x in range(50):\n",
" image=cv2.imread(str(nonbleeding[x]))\n",
" image_resized= cv2.resize(image, (img_height,img_width))\n",
" image=np.expand_dims(image_resized,axis=0)\n",
" pred=resnet_model.predict(image)\n",
" print(pred)\n",
"# print(image.shape)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-Ar39gRs3ZyE",
"outputId": "98e02392-e4c4-49bf-c4f1-14554dd029ca"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1/1 [==============================] - 0s 25ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 23ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 23ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 27ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 25ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 25ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 24ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 37ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 23ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 27ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 24ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 23ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 25ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 24ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 24ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 25ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 23ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 22ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 29ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 24ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 24ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 23ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 25ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 28ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 22ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 24ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 27ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 22ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 28ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 28ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 30ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 24ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 25ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 28ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 34ms/step\n",
"[[0.]]\n",
"1/1 [==============================] - 0s 27ms/step\n",
"[[0.]]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"pred=resnet_model.predict(image)\n",
"print(pred)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "l9d4t_tV3eOn",
"outputId": "b829d154-ecdb-4902-fa96-c86c3b898005"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1/1 [==============================] - 0s 24ms/step\n",
"[[0.]]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"output_class=class_names[np.argmax(pred)]\n",
"print(\"The predicted class is\", output_class)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LpWEmqdG3fP-",
"outputId": "09c6c4a1-0091-47c4-bbc3-b34aadc94e8c"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The predicted class is bleeding\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# **Playground**"
],
"metadata": {
"id": "5gMQ7yIDg3Kz"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"import xml.etree.ElementTree as ET\n",
"\n",
"# Directory containing XML files\n",
"xml_folder = \"/content/MISAHUB/WCEBleedGen/bleeding/Bounding boxes/XML\"\n",
"\n",
"# Directory to save TXT files\n",
"txt_folder = \"/content/MISAHUB/WCEBleedGen/bleeding/Bounding boxes/TXT\"\n",
"\n",
"# Create the TXT folder if it doesn't exist\n",
"os.makedirs(txt_folder, exist_ok=True)\n",
"\n",
"# Iterate through XML files in the folder\n",
"for xml_filename in os.listdir(xml_folder):\n",
" if xml_filename.endswith(\".xml\"):\n",
" # Construct the full paths for XML and TXT files\n",
" xml_path = os.path.join(xml_folder, xml_filename)\n",
" txt_filename = os.path.splitext(xml_filename)[0] + \".txt\"\n",
" txt_path = os.path.join(txt_folder, txt_filename)\n",
"\n",
" # Parse the XML file\n",
" tree = ET.parse(xml_path)\n",
" root = tree.getroot()\n",
"\n",
" # Open the TXT file for writing\n",
" with open(txt_path, \"w\") as txt_file:\n",
" # Loop through object elements in the XML\n",
" for obj in root.findall('object'):\n",
" # Extract bounding box coordinates\n",
" xmin = int(obj.find('bndbox/xmin').text)\n",
" ymin = int(obj.find('bndbox/ymin').text)\n",
" xmax = int(obj.find('bndbox/xmax').text)\n",
" ymax = int(obj.find('bndbox/ymax').text)\n",
"\n",
" # Write the coordinates to the TXT file\n",
" txt_file.write(f\"{xmin} {ymin} {xmax} {ymax}\\n\")\n",
"\n",
" print(f\"Converted {xml_filename} to {txt_filename}\")\n",
"\n",
"print(\"Conversion complete.\")\n"
],
"metadata": {
"id": "nDFeGPb1YM3q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"import re\n",
"import shutil\n",
"\n",
"images_folder = '/content/MISAHUB/WCEBleedGen/bleeding/Images'\n",
"annotations_folder = '/content/MISAHUB/WCEBleedGen/bleeding/Bounding boxes/TXT'\n",
"output_folder = '/content/MISAHUB/train/'\n",
"\n",
"# Create 'images' and 'labels' inside 'train' (makedirs also creates 'train' itself)\n",
"images_output_folder = os.path.join(output_folder, 'images')\n",
"labels_output_folder = os.path.join(output_folder, 'labels')\n",
"os.makedirs(images_output_folder, exist_ok=True)\n",
"os.makedirs(labels_output_folder, exist_ok=True)\n",
"\n",
"\n",
"def _natural_key(name):\n",
"    \"\"\"Sort key comparing digit runs numerically, so 'img- (2)' sorts before 'img- (10)'.\"\"\"\n",
"    return [int(part) if part.isdigit() else part for part in re.split(r'(\\d+)', name)]\n",
"\n",
"\n",
"# Index both folders by base name so every image is paired with *its own* annotation,\n",
"# instead of relying on two independently sorted lists lining up by position.\n",
"images_by_base = {os.path.splitext(name)[0]: name\n",
"                  for name in os.listdir(images_folder) if name.endswith('.png')}\n",
"annotations_by_base = {os.path.splitext(name)[0]: name\n",
"                       for name in os.listdir(annotations_folder) if name.endswith('.txt')}\n",
"\n",
"paired_bases = sorted(images_by_base.keys() & annotations_by_base.keys(), key=_natural_key)\n",
"unpaired_bases = sorted(images_by_base.keys() ^ annotations_by_base.keys(), key=_natural_key)\n",
"if unpaired_bases:\n",
"    print(f\"Skipping {len(unpaired_bases)} file(s) without a matching image/annotation, e.g. {unpaired_bases[:5]}\")\n",
"\n",
"# Copy (rather than move) under sequential names: the original dataset folders are still\n",
"# read by later cells, and copying keeps this cell safe to re-run.\n",
"for i, base_name in enumerate(paired_bases, start=1):\n",
"    new_image_name = f\"image_{i}.png\"\n",
"    new_annotation_name = f\"image_{i}.txt\"\n",
"\n",
"    shutil.copy2(os.path.join(images_folder, images_by_base[base_name]),\n",
"                 os.path.join(images_output_folder, new_image_name))\n",
"    shutil.copy2(os.path.join(annotations_folder, annotations_by_base[base_name]),\n",
"                 os.path.join(labels_output_folder, new_annotation_name))\n",
"\n",
"print(f\"Copied {len(paired_bases)} image/annotation pair(s) to the 'train' folder.\")\n"
],
"metadata": {
"id": "WsiKxYWre04r"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"import shutil\n",
"\n",
"images_folder = '/content/MISAHUB/WCEBleedGen/non-bleeding/images'\n",
"output_folder = '/content/MISAHUB/train/'\n",
"\n",
"# Destination layout: train/images and train/labels\n",
"images_output_folder = os.path.join(output_folder, 'images')\n",
"labels_output_folder = os.path.join(output_folder, 'labels')\n",
"\n",
"# Make sure 'train' and both sub-folders exist before anything is moved\n",
"if not os.path.exists(output_folder):\n",
"    os.makedirs(output_folder)\n",
"os.makedirs(images_output_folder, exist_ok=True)\n",
"os.makedirs(labels_output_folder, exist_ok=True)\n",
"\n",
"# PNG files of the non-bleeding class, in lexicographic order\n",
"image_files = sorted(name for name in os.listdir(images_folder) if name.endswith('.png'))\n",
"\n",
"# Numbering starts at 1310; each image is moved and gets an empty label file (no boxes)\n",
"for i, image_file in enumerate(image_files, start=1310):\n",
"    new_image_name = f\"images_{i}.png\"\n",
"    new_annotation_name = f\"images_{i}.txt\"\n",
"\n",
"    # Move the image under its new name\n",
"    old_image_path = os.path.join(images_folder, image_file)\n",
"    new_image_path = os.path.join(images_output_folder, new_image_name)\n",
"    os.rename(old_image_path, new_image_path)\n",
"\n",
"    # Opening in 'w' mode and closing immediately leaves a zero-byte label file\n",
"    new_annotation_path = os.path.join(labels_output_folder, new_annotation_name)\n",
"    open(new_annotation_path, 'w').close()\n",
"\n",
"    print(f\"Renamed and moved: {image_file} to {new_image_name}\")\n",
"    print(f\"Created empty label file: {new_annotation_name}\")\n",
"\n",
"print(\"All files have been renamed and moved to the 'train' folder.\")\n"
],
"metadata": {
"id": "ua4rRD38gAEI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"import shutil\n",
"import random\n",
"\n",
"# Source images and labels produced by the renaming cells\n",
"source_images_dir = '/content/MISAHUB/train/images'\n",
"source_labels_dir = '/content/MISAHUB/train/labels'\n",
"\n",
"train_ratio = 0.6\n",
"test_ratio = 0.15  # informational: the test split receives whatever train and val leave over\n",
"val_ratio=0.25\n",
"assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9, \"split ratios must sum to 1\"\n",
"\n",
"# Fixed seed so the same files land in the same split on every run\n",
"SPLIT_SEED = 42\n",
"\n",
"# Target directories (relative to the current working directory)\n",
"train_dir = 'train_f'\n",
"val_dir = 'val_f'\n",
"test_dir = 'test_f'\n",
"\n",
"# Sub-folders created inside each split directory\n",
"subdirectories = ['images', 'labels']\n",
"\n",
"# Start from empty split folders. ignore_errors makes this work on a fresh runtime,\n",
"# where the folders do not exist yet (plain `rm -r` reports an error in that case).\n",
"for directory in [train_dir, val_dir, test_dir]:\n",
"    shutil.rmtree(directory, ignore_errors=True)\n",
"    for subdirectory in subdirectories:\n",
"        os.makedirs(os.path.join(directory, subdirectory), exist_ok=True)\n",
"\n",
"# Sort before shuffling: os.listdir order is arbitrary, so the seed alone would not be enough\n",
"image_files = sorted(file for file in os.listdir(source_images_dir) if file.endswith('.png'))\n",
"random.Random(SPLIT_SEED).shuffle(image_files)\n",
"\n",
"# Sizes of the train and validation splits; the remainder becomes the test split\n",
"total_files = len(image_files)\n",
"num_train = int(total_files * train_ratio)\n",
"num_val = int(total_files * val_ratio)\n",
"\n",
"train_files = image_files[:num_train]\n",
"val_files = image_files[num_train:num_train + num_val]\n",
"test_files = image_files[num_train + num_val:]\n",
"\n",
"\n",
"def _copy_split(files, target_dir):\n",
"    \"\"\"Copy each image and its same-named .txt label into target_dir/images and target_dir/labels.\"\"\"\n",
"    for file in files:\n",
"        label_file = os.path.splitext(file)[0] + '.txt'\n",
"        shutil.copy(os.path.join(source_labels_dir, label_file), os.path.join(target_dir, 'labels', label_file))\n",
"        shutil.copy(os.path.join(source_images_dir, file), os.path.join(target_dir, 'images', file))\n",
"\n",
"\n",
"_copy_split(train_files, train_dir)\n",
"_copy_split(val_files, val_dir)\n",
"_copy_split(test_files, test_dir)\n",
"\n",
"print(f\"train: {len(train_files)}, val: {len(val_files)}, test: {len(test_files)}\")\n",
"print(\"Splitting complete.\")\n"
],
"metadata": {
"id": "2DGCW0-xgExp"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from torchvision import models\n",
"import torch"
],
"metadata": {
"id": "GS8g5pqJgQP8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"torch.cuda.is_available()"
],
"metadata": {
"id": "S772wWxHif9X",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c8c2f915-50d8-4ecf-ab92-82f7b001f5d5"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
],
"metadata": {
"id": "8t7u-bkOiiEO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"from PIL import Image\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms as transforms\n",
"\n",
"\n",
"class CustomDataset(Dataset):\n",
"    \"\"\"Image dataset yielding ``(image, class_label, bounding_box)`` triples.\n",
"\n",
"    Every ``<name>.png`` in ``image_dir`` is paired with ``<name>.txt`` in ``label_dir``.\n",
"    A label file holds ``x_min y_min x_max y_max`` lines; an empty file means the image has no box.\n",
"\n",
"    Args:\n",
"        image_dir: folder containing the PNG images.\n",
"        label_dir: folder containing the matching .txt label files.\n",
"        transform: optional callable applied to the RGB PIL image.\n",
"\n",
"    Each item is:\n",
"        image: the (transformed) image.\n",
"        class_label: int64 scalar tensor, 1 if the label file contains a box, else 0.\n",
"        bounding_box: float32 tensor of 4 values taken from the first label line (zeros if none).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, image_dir, label_dir, transform=None):\n",
"        self.image_dir = image_dir\n",
"        self.label_dir = label_dir\n",
"        self.transform = transform\n",
"\n",
"        # Sorted for a stable index order. Label names are derived from the image names so the\n",
"        # two lists cannot drift apart (two separate os.listdir calls give no ordering guarantee).\n",
"        self.image_files = sorted(name for name in os.listdir(self.image_dir) if name.lower().endswith('.png'))\n",
"        self.label_files = [os.path.splitext(name)[0] + '.txt' for name in self.image_files]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.image_files)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        # Force 3 channels: the Normalize transform below is defined for exactly 3\n",
"        img_name = os.path.join(self.image_dir, self.image_files[idx])\n",
"        image = Image.open(img_name).convert('RGB')\n",
"\n",
"        # Only the first box is used because the model has a single 4-value regression head\n",
"        label_name = os.path.join(self.label_dir, self.label_files[idx])\n",
"        with open(label_name, 'r') as label_file:\n",
"            first_line = label_file.readline().strip()\n",
"\n",
"        if first_line:\n",
"            values = first_line.split()\n",
"            if len(values) != 4:\n",
"                raise ValueError(f'{label_name}: expected 4 box values, got {len(values)}')\n",
"            # NOTE(review): coordinates are kept in the label file's pixel units; this matches the\n",
"            # 224x224 resize only if the source images are already 224x224 - confirm.\n",
"            bounding_box = torch.tensor([float(value) for value in values], dtype=torch.float32)\n",
"            class_label = torch.tensor(1, dtype=torch.long)\n",
"        else:\n",
"            bounding_box = torch.zeros(4, dtype=torch.float32)\n",
"            class_label = torch.tensor(0, dtype=torch.long)\n",
"\n",
"        if self.transform:\n",
"            image = self.transform(image)\n",
"\n",
"        return image, class_label, bounding_box\n",
"\n",
"\n",
"# Image preprocessing\n",
"transform = transforms.Compose([\n",
"    transforms.Resize((224, 224)),  # Resize to 224x224\n",
"    transforms.ToTensor(),  # Convert to a PyTorch tensor\n",
"    transforms.Normalize(\n",
"        mean=[0.485, 0.456, 0.406],  # Mean values for ImageNet data\n",
"        std=[0.229, 0.224, 0.225]  # Standard deviations for ImageNet data\n",
"    )\n",
"])\n",
"\n",
"# Paths to the training images and labels\n",
"image_directory = \"/content/train_f/images\"\n",
"label_directory = \"/content/train_f/labels\"\n",
"\n",
"# Dataset instance\n",
"custom_dataset = CustomDataset(image_directory, label_directory, transform=transform)\n",
"\n",
"# Data loader; default collation stacks the three items of each sample into batch tensors\n",
"batch_size = 32  # Adjust the batch size as needed\n",
"train_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True, num_workers=2)\n"
],
"metadata": {
"id": "gDTaaSguiKLn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.models as models\n",
"\n",
"# Load ResNet-18 with ImageNet weights. The `weights` enum replaces the deprecated\n",
"# `pretrained=True` and selects the same IMAGENET1K_V1 checkpoint.\n",
"resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n",
"\n",
"# Drop the original fully connected layer, keeping everything up to global average pooling\n",
"modules = list(resnet.children())[:-1]\n",
"resnet = nn.Sequential(*modules)\n",
"\n",
"# Classification head for bleeding vs. non-bleeding.\n",
"# It outputs raw logits: nn.CrossEntropyLoss applies log-softmax itself, so a Softmax layer\n",
"# here would be applied twice during training. For probabilities at inference time, call\n",
"# torch.softmax(logits, dim=1) on the output.\n",
"classification_head = nn.Sequential(\n",
"    nn.Flatten(),\n",
"    nn.Linear(512, 2)  # 2 output units (bleeding vs. non-bleeding)\n",
")\n",
"\n",
"# Bounding box regression head\n",
"bounding_box_head = nn.Sequential(\n",
"    nn.Flatten(),\n",
"    nn.Linear(512, 4)  # 4 output units for the bounding box coordinates\n",
")\n",
"\n",
"\n",
"class ResNetModified(nn.Module):\n",
"    \"\"\"Shared feature extractor feeding a classification head and a bounding box head.\n",
"\n",
"    Args:\n",
"        resnet: module mapping an image batch to features.\n",
"        classification_head: module mapping features to class logits.\n",
"        bounding_box_head: module mapping features to 4 box values.\n",
"\n",
"    ``forward`` returns the tuple ``(classification_output, bounding_box_output)``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, resnet, classification_head, bounding_box_head):\n",
"        super(ResNetModified, self).__init__()\n",
"        self.resnet = resnet\n",
"        self.classification_head = classification_head\n",
"        self.bounding_box_head = bounding_box_head\n",
"\n",
"    def forward(self, x):\n",
"        # Both heads read the same features, so the backbone runs once per batch\n",
"        features = self.resnet(x)\n",
"        classification_output = self.classification_head(features)\n",
"        bounding_box_output = self.bounding_box_head(features)\n",
"        return classification_output, bounding_box_output\n",
"\n",
"\n",
"# Instantiate the modified model\n",
"modified_resnet = ResNetModified(resnet, classification_head, bounding_box_head)\n"
],
"metadata": {
"id": "ly_ZxMpJl--P",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "35f3c3a2-6bac-481b-e4c9-4f9f9071d82d"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n",
"Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
"100%|██████████| 44.7M/44.7M [00:00<00:00, 77.0MB/s]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"import json\n",
"import torch\n",
"from torchvision.transforms.functional import to_tensor\n",
"\n",
"# Path to the labels folder\n",
"labels_folder = \"/content/train_f/labels\"\n",
"\n",
"# Path to the corresponding image folder\n",
"image_folder = \"/content/train_f/images\"\n",
"\n",
"# Collected boxes, one [x_min, y_min, x_max, y_max] row per label line\n",
"bounding_box_targets = []\n",
"\n",
"# Sorted so the row order is the same on every run (os.listdir order is arbitrary)\n",
"label_files = sorted(name for name in os.listdir(labels_folder) if name.endswith('.txt'))\n",
"\n",
"for label_file in label_files:\n",
"    label_file_path = os.path.join(labels_folder, label_file)\n",
"\n",
"    # Label files are plain text with one 'x_min y_min x_max y_max' line per box.\n",
"    # Empty files (images without a box) contribute no rows.\n",
"    with open(label_file_path, 'r') as f:\n",
"        label_data = f.readlines()\n",
"\n",
"    for line in label_data:\n",
"        fields = line.split()\n",
"        if not fields:\n",
"            continue  # tolerate blank lines instead of failing on the unpack below\n",
"        x_min, y_min, x_max, y_max = (int(field) for field in fields)\n",
"        bounding_box_targets.append([x_min, y_min, x_max, y_max])\n",
"\n",
"# reshape keeps the (N, 4) layout even when no boxes were found\n",
"bounding_box_targets = torch.tensor(bounding_box_targets, dtype=torch.float32).reshape(-1, 4)\n"
],
"metadata": {
"id": "5lcS0ZnLHFF_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader\n",
"\n",
"# Loss functions for classification and bounding box regression\n",
"classification_criterion = nn.CrossEntropyLoss()\n",
"bounding_box_criterion = nn.MSELoss()\n",
"\n",
"# The model and every batch must live on the same device\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"modified_resnet.to(device)\n",
"\n",
"# Optimizer\n",
"optimizer = optim.SGD(modified_resnet.parameters(), lr=0.001, momentum=0.9)\n",
"\n",
"# Set the model to training mode\n",
"modified_resnet.train()\n",
"\n",
"# Number of training epochs\n",
"num_epochs = 10  # You can adjust this number based on your dataset and training performance\n",
"\n",
"for epoch in range(num_epochs):\n",
"    running_loss = 0.0\n",
"    num_batches = 0\n",
"\n",
"    for i, data in enumerate(train_loader, 0):\n",
"        # Each batch must be (images, class indices, box targets). A local name is used for\n",
"        # the box targets so the notebook-level `bounding_box_targets` tensor is not overwritten.\n",
"        inputs, labels, batch_box_targets = data\n",
"        inputs = inputs.to(device)\n",
"        labels = labels.to(device)\n",
"        batch_box_targets = batch_box_targets.to(device)\n",
"\n",
"        # Zero the parameter gradients\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Forward pass\n",
"        classification_output, bounding_box_output = modified_resnet(inputs)\n",
"\n",
"        # Total loss is the sum of the classification and bounding box losses\n",
"        classification_loss = classification_criterion(classification_output, labels)\n",
"        bounding_box_loss = bounding_box_criterion(bounding_box_output, batch_box_targets)\n",
"        total_loss = classification_loss + bounding_box_loss\n",
"\n",
"        # Backpropagation and weight update\n",
"        total_loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        running_loss += total_loss.item()\n",
"        num_batches += 1\n",
"\n",
"    # Report once per epoch: a fixed 'every 100 batches' interval never fires when an epoch\n",
"    # has fewer than 100 batches\n",
"    if num_batches:\n",
"        print(f\"[Epoch {epoch + 1}] Average loss: {running_loss / num_batches:.3f}\")\n",
"\n",
"print(\"Finished Training\")\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 245
},
"id": "8EI4THBHidyZ",
"outputId": "5cc6e12d-4e60-497d-9b9c-0a61f6ed3a1a"
},
"execution_count": null,
"outputs": [
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-19-9bb691bff76c>\u001b[0m in \u001b[0;36m<cell line: 20>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;31m# Iterate over the mini-batches\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbounding_box_targets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;31m# You'll need to structure your dataset accordingly\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;31m# Zero the parameter gradients\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: not enough values to unpack (expected 3, got 2)"
]
}
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"import torch\n",
"import xml.etree.ElementTree as ET\n",
"\n",
"# Path to a single XML annotation file (the variable name is kept for later cells)\n",
"labels_folder = \"/content/MISAHUB/WCEBleedGen/bleeding/Bounding boxes/XML/img- (1176).xml\"\n",
"\n",
"# Collected boxes, one [x_min, y_min, x_max, y_max] row per <object>\n",
"bounding_box_targets = []\n",
"\n",
"# The file is XML (annotation/object/bndbox), so it has to be parsed as XML; splitting its\n",
"# lines on whitespace fails on the very first line ('<annotation>').\n",
"root = ET.parse(labels_folder).getroot()\n",
"\n",
"for obj in root.findall('object'):\n",
"    x_min = int(obj.find('bndbox/xmin').text)\n",
"    y_min = int(obj.find('bndbox/ymin').text)\n",
"    x_max = int(obj.find('bndbox/xmax').text)\n",
"    y_max = int(obj.find('bndbox/ymax').text)\n",
"\n",
"    bounding_box_targets.append([x_min, y_min, x_max, y_max])\n",
"\n",
"# reshape keeps the (N, 4) layout even when the file contains no <object>\n",
"bounding_box_targets = torch.tensor(bounding_box_targets, dtype=torch.float32).reshape(-1, 4)\n",
"\n",
"# Print the resulting bounding box targets tensor\n",
"print(bounding_box_targets)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 282
},
"id": "ju7fsfnySmo4",
"outputId": "f11d0cd7-34b2-4138-c4fb-ca564ee169a1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['<annotation>\\n', '\\t<folder>XML</folder>\\n', '\\t<filename>img- (1176).png</filename>\\n', '\\t<path>D:\\\\WCE\\\\dataset final\\\\everythingYOLO\\\\img- (1176).png</path>\\n', '\\t<source>\\n', '\\t\\t<database>Unknown</database>\\n', '\\t</source>\\n', '\\t<size>\\n', '\\t\\t<width>224</width>\\n', '\\t\\t<height>224</height>\\n', '\\t\\t<depth>3</depth>\\n', '\\t</size>\\n', '\\t<segmented>0</segmented>\\n', '\\t<object>\\n', '\\t\\t<name>bleeding</name>\\n', '\\t\\t<pose>Unspecified</pose>\\n', '\\t\\t<truncated>0</truncated>\\n', '\\t\\t<difficult>0</difficult>\\n', '\\t\\t<bndbox>\\n', '\\t\\t\\t<xmin>114</xmin>\\n', '\\t\\t\\t<ymin>0</ymin>\\n', '\\t\\t\\t<xmax>202</xmax>\\n', '\\t\\t\\t<ymax>99</ymax>\\n', '\\t\\t</bndbox>\\n', '\\t</object>\\n', '</annotation>\\n']\n"
]
},
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-22-15d1a5ca5d09>\u001b[0m in \u001b[0;36m<cell line: 19>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mline\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlines\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;31m# Split the line into individual coordinates as strings\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0mx_min_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_min_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_max_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_max_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mline\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;31m# Convert the coordinates to integers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: not enough values to unpack (expected 4, got 1)"
]
}
]
},
{
"cell_type": "code",
"source": [
"from PIL import Image, ImageDraw\n",
"import xml.etree.ElementTree as ET\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Annotation to visualise and the image it is drawn on\n",
"xml_file_path = \"/content/MISAHUB/WCEBleedGen/bleeding/Bounding boxes/XML/img- (1176).xml\"\n",
"tree = ET.parse(xml_file_path)\n",
"root = tree.getroot()\n",
"\n",
"# The <path> stored in the XML is not used; the image is read from the train folder instead\n",
"image_path = '/content/MISAHUB/train/images/image_1176.png'\n",
"image_width = int(root.find('size/width').text)\n",
"image_height = int(root.find('size/height').text)\n",
"\n",
"# Open the image and get a drawing context for it\n",
"image = Image.open(image_path)\n",
"draw = ImageDraw.Draw(image)\n",
"\n",
"# One red rectangle per <object> element\n",
"for obj in root.findall('object'):\n",
"    xmin, ymin, xmax, ymax = (\n",
"        int(obj.find(f'bndbox/{tag}').text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')\n",
"    )\n",
"    print(xmin, xmax, ymin, ymax)\n",
"    draw.rectangle([xmin, ymin, xmax, ymax], outline='red', width=2)\n",
"\n",
"# Show the annotated image without axis ticks\n",
"plt.imshow(image)\n",
"plt.axis('off')\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"id": "0a12XbibTz2b",
"outputId": "0357ca0d-9bdc-4307-b001-e1fe98ab49dc"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"114 202 0 99\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"import xml.etree.ElementTree as ET\n",
"\n",
"# Input folder with XML annotations and output folder for the TXT files\n",
"xml_folder = \"/content/MISAHUB/WCEBleedGen/bleeding/Bounding boxes/XML\"\n",
"txt_folder = \"/content/MISAHUB/WCEBleedGen/bleeding/Bounding boxes/TXT\"\n",
"\n",
"# Make sure the output folder exists\n",
"os.makedirs(txt_folder, exist_ok=True)\n",
"\n",
"for xml_filename in os.listdir(xml_folder):\n",
"    # Anything that is not an XML annotation is ignored\n",
"    if not xml_filename.endswith(\".xml\"):\n",
"        continue\n",
"\n",
"    # Same base name, .txt extension, inside the TXT folder\n",
"    xml_path = os.path.join(xml_folder, xml_filename)\n",
"    txt_filename = os.path.splitext(xml_filename)[0] + \".txt\"\n",
"    txt_path = os.path.join(txt_folder, txt_filename)\n",
"\n",
"    tree = ET.parse(xml_path)\n",
"    root = tree.getroot()\n",
"\n",
"    # One 'xmin ymin xmax ymax' line per <object>\n",
"    with open(txt_path, \"w\") as txt_file:\n",
"        for obj in root.findall('object'):\n",
"            coords = [int(obj.find(f'bndbox/{tag}').text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]\n",
"            txt_file.write(\" \".join(str(value) for value in coords) + \"\\n\")\n",
"\n",
"    print(f\"Converted {xml_filename} to {txt_filename}\")\n",
"\n",
"print(\"Conversion complete.\")\n"
],
"metadata": {
"id": "KoqxZP0hVlm8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Detecting Bleeding Edges in Correctly Classified Images Using Mask R-CNN**"
],
"metadata": {
"id": "ofEh9VUI4Whn"
}
},
{
"cell_type": "markdown",
"source": [
"# **Load Dataset**"
],
"metadata": {
"id": "l-1hSEBY7Lfz"
}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import os\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torchvision\n",
"from torchvision import transforms as T\n",
"from torchvision.models.detection.faster_rcnn import FastRCNNPredictor\n",
"from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor"
],
"metadata": {
"id": "Mar3Hlyd4fpU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"images = sorted(os.listdir(\"/content/MISAHUB/WCEBleedGen/bleeding/Images\"))\n",
"masks = sorted(os.listdir(\"/content/MISAHUB/WCEBleedGen/bleeding/Annotations\"))"
],
"metadata": {
"id": "nvzuvpWA735-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"img = Image.open(\"/content/MISAHUB/WCEBleedGen/bleeding/Images/\" + images[0]).convert(\"RGB\")\n",
"mask = Image.open(\"/content/MISAHUB/WCEBleedGen/bleeding/Annotations/\" + masks[0])"
],
"metadata": {
"id": "qC3691af9Ccm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"img"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 241
},
"id": "aXRJmpTc9Cby",
"outputId": "42061b28-1748-4538-f559-d4e899108336"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<PIL.Image.Image image mode=RGB size=224x224>"
],
"image/png": "\n"
},
"metadata": {},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"source": [
"mask"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 241
},
"id": "tKr1DPrv97_0",
"outputId": "6dd8acb7-ffcf-4624-b9f2-c0c68cffb481"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<PIL.PngImagePlugin.PngImageFile image mode=1 size=224x224>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOAAAADgAQAAAAAyVldMAAAAMUlEQVR4nO3KsQ0AIBADsYiG/Tdis7AB9UvY7V0CAADMsdu2J1mvSxRFURT/jQAwwwUvcwZBlbFXyQAAAABJRU5ErkJggg==\n"
},
"metadata": {},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"source": [
"np.unique(mask)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RBJndlyG-IDE",
"outputId": "cabf2d4a-aae6-4069-d430-fca081e35392"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([False, True])"
]
},
"metadata": {},
"execution_count": 15
}
]
},
{
"cell_type": "code",
"source": [
"class CustDat(torch.utils.data.Dataset):\n",
"    \"\"\"Dataset of (image tensor, target dict) pairs built from images and their masks.\n",
"\n",
"    Args:\n",
"        images: sequence of image file names.\n",
"        masks: sequence of mask file names; masks[i] must belong to images[i].\n",
"        img_dir: folder containing the image files.\n",
"        mask_dir: folder containing the mask files.\n",
"\n",
"    Each item is ``(image, target)`` where ``target`` holds:\n",
"        boxes: float32 tensor of shape (num_objs, 4) with [xmin, ymin, xmax, ymax] rows.\n",
"        labels: int64 tensor of ones, one per object.\n",
"        masks: uint8 tensor of shape (num_objs, H, W), one binary mask per object.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self , images , masks ,\n",
"                 img_dir = \"/content/MISAHUB/WCEBleedGen/bleeding/Images/\" ,\n",
"                 mask_dir = \"/content/MISAHUB/WCEBleedGen/bleeding/Annotations/\"):\n",
"        self.imgs = images\n",
"        self.masks = masks\n",
"        self.img_dir = img_dir\n",
"        self.mask_dir = mask_dir\n",
"\n",
"    def __getitem__(self , idx):\n",
"        img = Image.open(os.path.join(self.img_dir , self.imgs[idx])).convert(\"RGB\")\n",
"        # NOTE(review): assumes a single-channel (H, W) mask - confirm for all annotation files\n",
"        mask = np.array(Image.open(os.path.join(self.mask_dir , self.masks[idx])))\n",
"\n",
"        # Every distinct non-zero value is one object; 0 / False is background.\n",
"        # Filtering by value (not by position) keeps the object when a mask has no background.\n",
"        obj_ids = np.unique(mask)\n",
"        obj_ids = obj_ids[obj_ids != 0]\n",
"        num_objs = len(obj_ids)\n",
"\n",
"        # One binary mask per object, selected by its actual id rather than by 1, 2, 3, ...\n",
"        masks = mask[None , : , :] == obj_ids[: , None , None]\n",
"\n",
"        # Tight box around the pixels of each object\n",
"        boxes = []\n",
"        for i in range(num_objs):\n",
"            pos = np.where(masks[i])\n",
"            xmin = np.min(pos[1])\n",
"            xmax = np.max(pos[1])\n",
"            ymin = np.min(pos[0])\n",
"            ymax = np.max(pos[0])\n",
"            boxes.append([xmin , ymin , xmax , ymax])\n",
"\n",
"        # reshape keeps the (N, 4) layout for images without any object\n",
"        boxes = torch.as_tensor(boxes , dtype = torch.float32).reshape(-1 , 4)\n",
"        labels = torch.ones((num_objs,) , dtype = torch.int64)\n",
"        masks = torch.as_tensor(masks , dtype = torch.uint8)\n",
"\n",
"        target = {}\n",
"        target[\"boxes\"] = boxes\n",
"        target[\"labels\"] = labels\n",
"        target[\"masks\"] = masks\n",
"        return T.ToTensor()(img) , target\n",
"\n",
"    def __len__(self):\n",
"        return len(self.imgs)"
],
"metadata": {
"id": "zt0I70ZlDlkg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model = torchvision.models.detection.maskrcnn_resnet50_fpn()\n",
"in_features = model.roi_heads.box_predictor.cls_score.in_features\n",
"model.roi_heads.box_predictor = FastRCNNPredictor(in_features , 2)\n",
"in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels\n",
"hidden_layer = 256\n",
"model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask , hidden_layer , 2)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Tuguet74D4WG",
"outputId": "f2cf517a-1d8b-4c5a-cb60-abfd174fdeb2"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Downloading: \"https://download.pytorch.org/models/resnet50-0676ba61.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth\n",
"100%|██████████| 97.8M/97.8M [00:02<00:00, 51.1MB/s]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"transform = T.ToTensor()"
],
"metadata": {
"id": "VAAT8VL_D8IH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def custom_collate(data):\n",
" return data"
],
"metadata": {
"id": "aWDkpiaqD-5r"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Splitting"
],
"metadata": {
"id": "X3VaRrCcFYkG"
}
},
{
"cell_type": "code",
"source": [
"images = sorted(os.listdir(\"/content/MISAHUB/WCEBleedGen/bleeding/Images\"))\n",
"masks = sorted(os.listdir(\"/content/MISAHUB/WCEBleedGen/bleeding/Annotations\"))\n",
"# images[i] and masks[i] are paired by position, so the two lists must at least be the same length\n",
"assert len(images) == len(masks), f\"{len(images)} images but {len(masks)} masks\"\n",
"\n",
"# 90% of the images go to training, rounded up to an even count\n",
"# (presumably so that batch_size = 2 never yields a batch of one - confirm)\n",
"num = int(0.9 * len(images))\n",
"num = num if num % 2 == 0 else num + 1\n",
"\n",
"# Seeded generator so the same images land in train/val on every run\n",
"split_rng = np.random.default_rng(42)\n",
"train_imgs_inds = split_rng.choice(len(images) , num , replace = False)\n",
"val_imgs_inds = np.setdiff1d(range(len(images)) , train_imgs_inds)\n",
"\n",
"# The same indices are applied to images and masks to keep the pairs together\n",
"train_imgs = np.array(images)[train_imgs_inds]\n",
"val_imgs = np.array(images)[val_imgs_inds]\n",
"train_masks = np.array(masks)[train_imgs_inds]\n",
"val_masks = np.array(masks)[val_imgs_inds]\n"
],
"metadata": {
"id": "C_VzCMkhECEO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Both loaders share the same settings; memory is pinned only when a GPU is present\n",
"loader_kwargs = dict(batch_size = 2 ,\n",
"                     shuffle = True ,\n",
"                     collate_fn = custom_collate ,\n",
"                     num_workers = 1 ,\n",
"                     pin_memory = torch.cuda.is_available())\n",
"train_dl = torch.utils.data.DataLoader(CustDat(train_imgs , train_masks) , **loader_kwargs)\n",
"val_dl = torch.utils.data.DataLoader(CustDat(val_imgs , val_masks) , **loader_kwargs)"
],
"metadata": {
"id": "xNkkXFm_EKqt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"device"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5udKwEr9ENdG",
"outputId": "cb32c909-56ac-4fe5-8878-0c0f4ccec80d"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"metadata": {},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"source": [
"model.to(device)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VpOAXJ5REPgs",
"outputId": "3fc951f6-a0fa-4896-cc1a-95dc0b411686"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"MaskRCNN(\n",
" (transform): GeneralizedRCNNTransform(\n",
" Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
" Resize(min_size=(800,), max_size=1333, mode='bilinear')\n",
" )\n",
" (backbone): BackboneWithFPN(\n",
" (body): IntermediateLayerGetter(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): FrozenBatchNorm2d(64, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(64, eps=1e-05)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(64, eps=1e-05)\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): FrozenBatchNorm2d(256, eps=1e-05)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(64, eps=1e-05)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(64, eps=1e-05)\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(64, eps=1e-05)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(64, eps=1e-05)\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(128, eps=1e-05)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(128, eps=1e-05)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(512, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): FrozenBatchNorm2d(512, eps=1e-05)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(128, eps=1e-05)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(128, eps=1e-05)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(512, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(128, eps=1e-05)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(128, eps=1e-05)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(512, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(128, eps=1e-05)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(128, eps=1e-05)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(512, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(1024, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): FrozenBatchNorm2d(1024, eps=1e-05)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(1024, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(1024, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(1024, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (4): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(1024, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (5): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(256, eps=1e-05)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(1024, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(512, eps=1e-05)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(512, eps=1e-05)\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(2048, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): FrozenBatchNorm2d(2048, eps=1e-05)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(512, eps=1e-05)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(512, eps=1e-05)\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(2048, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): FrozenBatchNorm2d(512, eps=1e-05)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): FrozenBatchNorm2d(512, eps=1e-05)\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): FrozenBatchNorm2d(2048, eps=1e-05)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" )\n",
" (fpn): FeaturePyramidNetwork(\n",
" (inner_blocks): ModuleList(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" (2): Conv2dNormActivation(\n",
" (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" )\n",
" (layer_blocks): ModuleList(\n",
" (0-3): 4 x Conv2dNormActivation(\n",
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" )\n",
" )\n",
" (extra_blocks): LastLevelMaxPool()\n",
" )\n",
" )\n",
" (rpn): RegionProposalNetwork(\n",
" (anchor_generator): AnchorGenerator()\n",
" (head): RPNHead(\n",
" (conv): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))\n",
" (bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" )\n",
" (roi_heads): RoIHeads(\n",
" (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)\n",
" (box_head): TwoMLPHead(\n",
" (fc6): Linear(in_features=12544, out_features=1024, bias=True)\n",
" (fc7): Linear(in_features=1024, out_features=1024, bias=True)\n",
" )\n",
" (box_predictor): FastRCNNPredictor(\n",
" (cls_score): Linear(in_features=1024, out_features=2, bias=True)\n",
" (bbox_pred): Linear(in_features=1024, out_features=8, bias=True)\n",
" )\n",
" (mask_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(14, 14), sampling_ratio=2)\n",
" (mask_head): MaskRCNNHeads(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (2): Conv2dNormActivation(\n",
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (mask_predictor): MaskRCNNPredictor(\n",
" (conv5_mask): ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))\n",
" (relu): ReLU(inplace=True)\n",
" (mask_fcn_logits): Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" )\n",
")"
]
},
"metadata": {},
"execution_count": 23
}
]
},
{
"cell_type": "code",
"source": [
"params = [p for p in model.parameters() if p.requires_grad]"
],
"metadata": {
"id": "qObrhmmqERNz"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)"
],
"metadata": {
"id": "IyzITiOhEVQz"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Training"
],
"metadata": {
"id": "HQZ3bg9GFb0p"
}
},
{
"cell_type": "code",
"source": [
"all_train_losses = []\n",
"all_val_losses = []\n",
"flag = False\n",
"for epoch in range(10):\n",
" train_epoch_loss = 0\n",
" val_epoch_loss = 0\n",
" model.train()\n",
" for i , dt in enumerate(train_dl):\n",
" imgs = [dt[0][0].to(device) , dt[1][0].to(device)]\n",
" targ = [dt[0][1] , dt[1][1]]\n",
" targets = [{k: v.to(device) for k, v in t.items()} for t in targ]\n",
" loss = model(imgs , targets)\n",
" if not flag:\n",
" print(loss)\n",
" flag = True\n",
" losses = sum([l for l in loss.values()])\n",
" train_epoch_loss += losses.cpu().detach().numpy()\n",
" optimizer.zero_grad()\n",
" losses.backward()\n",
" optimizer.step()\n",
" all_train_losses.append(train_epoch_loss)\n",
" with torch.no_grad():\n",
" for j , dt in enumerate(val_dl):\n",
" imgs = [dt[0][0].to(device) , dt[1][0].to(device)]\n",
" targ = [dt[0][1] , dt[1][1]]\n",
" targets = [{k: v.to(device) for k, v in t.items()} for t in targ]\n",
" loss = model(imgs , targets)\n",
" losses = sum([l for l in loss.values()])\n",
" val_epoch_loss += losses.cpu().detach().numpy()\n",
" all_val_losses.append(val_epoch_loss)\n",
" print(epoch , \" \" , train_epoch_loss , \" \" , val_epoch_loss)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 722
},
"id": "yZuIwZdxEXDf",
"outputId": "148bcda7-cdb0-4bbc-c30f-542ae54dd98d"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'loss_classifier': tensor(0.6094, device='cuda:0', grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.0303, device='cuda:0', grad_fn=<DivBackward0>), 'loss_mask': tensor(2.0514, device='cuda:0',\n",
" grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_objectness': tensor(0.6962, device='cuda:0',\n",
" grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(0.0181, device='cuda:0', grad_fn=<DivBackward0>)}\n"
]
},
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-26-20da855f622a>\u001b[0m in \u001b[0;36m<cell line: 4>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mval_epoch_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mdt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mimgs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mdt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mdt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mtarg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mdt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mdt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 631\u001b[0m \u001b[0;31m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 632\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[call-arg]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 633\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 634\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 635\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1343\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1344\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_task_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1345\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1346\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1347\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_try_put_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_process_data\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 1369\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_try_put_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1370\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mExceptionWrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1371\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1372\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1373\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/_utils.py\u001b[0m in \u001b[0;36mreraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 642\u001b[0m \u001b[0;31m# instantiate since we don't know how to\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 643\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 644\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mexception\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 645\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 646\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Caught ValueError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py\", line 308, in _worker_loop\n data = fetcher.fetch(index)\n File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py\", line 51, in fetch\n data = [self.dataset[idx] for idx in possibly_batched_index]\n File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py\", line 51, in <listcomp>\n data = [self.dataset[idx] for idx in possibly_batched_index]\n File \"<ipython-input-16-e3a037f55269>\", line 19, in __getitem__\n xmin = np.min(pos[1])\n File \"<__array_function__ internals>\", line 180, in amin\n File \"/usr/local/lib/python3.10/dist-packages/numpy/core/fromnumeric.py\", line 2918, in amin\n return _wrapreduction(a, np.minimum, 'min', axis, None, out,\n File \"/usr/local/lib/python3.10/dist-packages/numpy/core/fromnumeric.py\", line 86, in _wrapreduction\n return ufunc.reduce(obj, axis, dtype, out, **passkwargs)\nValueError: zero-size array to reduction operation minimum which has no identity\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Evaluation"
],
"metadata": {
"id": "nKPEqMZJFepe"
}
},
{
"cell_type": "code",
"source": [
"plt.plot(all_train_losses)"
],
"metadata": {
"id": "Cs2JrrbUEcKr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"plt.plot(all_val_losses)"
],
"metadata": {
"id": "_7hcU_KnE7Ed"
},
"execution_count": null,
"outputs": []
}
]
}