--- a +++ b/Evalution.ipynb @@ -0,0 +1,520 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iXo6Nj7M5GK8" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "import os\n", + "import glob\n", + "import time\n", + "import pickle\n", + "import sys\n", + "sys.path.append('/content/drive/MyDrive/Batoul_Code/')\n", + "sys.path.append('/content/drive/MyDrive/Batoul_Code/src')\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from pathlib import Patha\n", + "from PIL import Image\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "from data import LungDataset, blend, Pad, Crop, Resize\n", + "from data2 import LungDataset2, blend, Pad, Crop, Resize\n", + "\n", + "from OurModel import CxlNet\n", + "\n", + "from metrics import jaccard, dice,get_accuracy, get_sensitivity, get_specificity" + ] + }, + { + "cell_type": "code", + "source": [ + "in_channels=1\n", + "out_channels=2\n", + "batch_norm=True\n", + "upscale_mode=\"bilinear\"\n", + "image_size=512\n", + "def selectModel():\n", + " return CxlNet(\n", + " in_channels=in_channels,\n", + " out_channels=out_channels,\n", + " batch_norm=batch_norm,\n", + " upscale_mode=upscale_mode,\n", + " image_size=image_size)" + ], + "metadata": { + "id": "BGn0Pjmb5gRA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "dataset_name=\"dataset\"\n", + "dataset_types={\"dataset\":\"png\",\"CT\":\"jpg\"}\n", + "dataset_type=dataset_types[dataset_name]\n", + "print(dataset_type)\n", + "image_size=512\n", + "split_file = \"/content/drive/MyDrive/Batoul_Code/splits.pk\"\n", + "list_data_file = \"/content/drive/MyDrive/Batoul_Code/list_data.pk\"\n", + "version=\"UNet\"\n", + "approach=\"contour\"\n", + "model = selectModel()\n", + "\n", + "base_path=\"/content/drive/MyDrive/Batoul_Code/\"\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "device\n", + "\n", + "\n", + "\n", + "\n", + "data_folder = Path(base_path+\"input\", base_path+\"input/\"+dataset_name)\n", + "origins_folder = data_folder / \"images\"\n", + "masks_folder = data_folder / \"masks\"\n", + "masks_contour_folder = data_folder / \"masks_contour\"\n", + "masks_folder =masks_contour_folder\n", + "models_folder = Path(base_path+\"models\")\n", + "images_folder = Path(base_path+\"images\")\n" + ], + "metadata": { + "id": "5Sp6Ga1y50He" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "models_folder = Path(base_path+\"models\")\n", + "model_name = \"unet-6v.pt\"\n", + "model_name=\"ournet_\"+version+\".pt\"\n", + "print(model_name)\n", + "model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n", + "model.to(device)\n", + "model.eval()\n", + "\n", + "\n", + "test_loss = 0.0\n", + "test_jaccard = 0.0\n", + "test_dice = 0.0\n", + "test_accuracy=0.0\n", + "test_sensitivity=0.0\n", + "test_specificity=0.0\n", + "batch_size = 4\n", + "\n", + "if os.path.isfile(list_data_file):\n", + " with open(list_data_file, \"rb\") as f:\n", + " list_data = pickle.load(f)\n", + " origins_list=list_data[0]\n", + " masks_list=list_data[1]\n", + "else:\n", + " origins_list = [f.stem for f in origins_folder.glob(f\"*.{dataset_type}\")]\n", + " masks_list = [f.stem for f in masks_folder.glob(f\"*.{dataset_type}\")]\n", + " with open(list_data_file, \"wb\") as f:\n", + " pickle.dump([origins_list,masks_list], f)\n", + "\n", + "\n", + "#origins_list = [f.stem for f in origins_folder.glob(\"*.png\")]\n", + "#masks_list = [f.stem for f in masks_folder.glob(\"*.png\")]\n", + "\n", + "\n", + "origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n", + "\n", + "\n", + "\n", + "if os.path.isfile(split_file):\n", + " with open(split_file, \"rb\") as f:\n", + " splits = pickle.load(f)\n", + "else:\n", + " splits = {}\n", + " splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n", + " splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n", + " with open(split_file, \"wb\") as f:\n", + " pickle.dump(splits, f)\n", + "\n", + "val_test_transforms = torchvision.transforms.Compose([\n", + " Resize((image_size, image_size)),\n", + "])\n", + "\n", + "if dataset_name!=\"dataset\":\n", + " train_transforms = torchvision.transforms.Compose([\n", + " Pad(200),\n", + " Crop(300),\n", + " val_test_transforms,\n", + " ])\n", + " datasets = {x: LungDataset2(\n", + " splits[x],\n", + " origins_folder,\n", + " masks_folder,\n", + " train_transforms if x == \"train\" else val_test_transforms,\n", + " dataset_type=dataset_type\n", + " ) for x in [\"train\", \"test\", \"val\"]}\n", + "else:\n", + " train_transforms = torchvision.transforms.Compose([\n", + " Pad(200),\n", + " Crop(300),\n", + " val_test_transforms,])\n", + "\n", + " datasets = {x: LungDataset(\n", + " splits[x],\n", + " origins_folder,\n", + " masks_folder,\n", + " train_transforms if x == \"train\" else val_test_transforms,\n", + " dataset_type=dataset_type\n", + " ) for x in [\"train\", \"test\", \"val\"]}\n", + "\n", + "num_samples = 9\n", + "phase = \"test\"\n", + "print(len(datasets[phase]))\n", + "\n", + "dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=batch_size) for x in [\"train\", \"test\", \"val\"]}\n", + "\n", + "for origins, masks in dataloaders[\"test\"]:\n", + " num = origins.size(0)\n", + "\n", + " origins = origins.to(device)\n", + " masks = masks.to(device)\n", + "\n", + " with torch.no_grad():\n", + " outs = model(origins)\n", + " softmax = torch.nn.functional.log_softmax(outs, dim=1)\n", + " test_loss += torch.nn.functional.nll_loss(softmax, masks).item() * num\n", + " outs = torch.argmax(softmax, dim=1)\n", + " outs = outs.float()\n", + " masks = masks.float()\n", + " test_jaccard += jaccard(masks, outs).item() * num\n", + " test_dice += dice(masks, outs).item() * num\n", + " test_accuracy += get_accuracy(masks, outs) * num\n", + " test_sensitivity += get_sensitivity(masks, outs) * num\n", + " test_specificity += get_specificity(masks, outs) * num\n", + " print(\".\", end=\"\")\n", + "\n", + "test_loss = test_loss / len(datasets[\"test\"])\n", + "test_jaccard = test_jaccard / len(datasets[\"test\"])\n", + "test_dice = test_dice / len(datasets[\"test\"])\n", + "test_accuracy = test_accuracy / len(datasets[\"test\"])\n", + "print()\n", + "print(f\"avg test loss: {test_loss}\")\n", + "print(f\"avg test jaccard: {test_jaccard}\")\n", + "print(f\"avg test dice: {test_dice}\")\n", + "print(f\"avg test accuracy: {test_accuracy}\")\n", + "print(f\"avg test sensitivity: {test_sensitivity}\")\n", + "print(f\"avg test specificity: {test_specificity}\")\n", + "\n", + "\n", + "\n", + "subset = torch.utils.data.Subset(\n", + " datasets[phase],\n", + " np.random.randint(0, len(datasets[phase]), num_samples)\n", + ")\n", + "random_samples_loader = torch.utils.data.DataLoader(subset, batch_size=2)\n", + "plt.figure(figsize=(20, 25))\n", + "\n", + "for idx, (origin, mask) in enumerate(random_samples_loader):\n", + " plt.subplot((num_samples // 3) + 1, 3, idx + 1)\n", + "\n", + " origin = origin.to(device)\n", + " mask = mask.to(device)\n", + "\n", + " with torch.no_grad():\n", + " out = model(origin)\n", + " softmax = torch.nn.functional.log_softmax(out, dim=1)\n", + " out = torch.argmax(softmax, dim=1)\n", + "\n", + " jaccard_score = jaccard(mask.float(), out.float()).item()\n", + " dice_score = dice(mask.float(), out.float()).item()\n", + "\n", + " origin = origin[0].to(\"cpu\")\n", + " out = out[0].to(\"cpu\")\n", + " mask = mask[0].to(\"cpu\")\n", + " #plt.imshow(np.array(blend(origin, mask, out)))\n", + " plt.imshow(np.array(blend(origin, out, out)))\n", + " plt.title(f\"jaccard: {jaccard_score:.4f}, dice: {dice_score:.4f}\")\n", + " print(\".\", end=\"\")\n", + "\n", + "plt.savefig(images_folder / \"obtained-results.png\", bbox_inches='tight')\n", + "plt.show()\n", + "print()\n", + "print(\"red area - predict\")\n", + "print(\"green area - ground truth\")\n", + "print(\"yellow area - intersection\")\n", + "\n", + "\n", + "model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n", + "model.to(device)\n", + "model.eval()\n", + "\n", + "device\n", + "\n", + "#%%\n", + "\n", + "origin_filename = base_path+ f\"input/{dataset_name}/images/ID00015637202177877247924_110.jpg\"\n", + "#origin_filename=base_path + \"external_samples/1.jpg\"\n", + "\n", + "origin = Image.open(origin_filename).convert(\"P\")\n", + "origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))\n", + "origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n", + "\n", + "with torch.no_grad():\n", + " origin = torch.stack([origin])\n", + " origin = origin.to(device)\n", + " out = model(origin)\n", + " softmax = torch.nn.functional.log_softmax(out, dim=1)\n", + " out = torch.argmax(softmax, dim=1)\n", + "\n", + " origin = origin[0].to(\"cpu\")\n", + " out = out[0].to(\"cpu\")\n", + "\n", + "\n", + "plt.figure(figsize=(20, 10))\n", + "\n", + "pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n", + "\n", + "plt.subplot(1, 2, 1)\n", + "plt.title(\"origin image\")\n", + "plt.imshow(np.array(pil_origin))\n", + "plt.show()\n", + "plt.subplot(1, 2, 2)\n", + "plt.title(\"blended origin + predict\")\n", + "plt.imshow(np.array(blend(origin, out)))\n", + "plt.show()\n" + ], + "metadata": { + "id": "jFpE8AwG5-Nm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "models_folder = Path(base_path+\"models\")\n", + "\n", + "\n", + "\n", + "test_loss = 0.0\n", + "test_jaccard = 0.0\n", + "test_dice = 0.0\n", + "\n", + "batch_size = 4\n", + "\n", + "if os.path.isfile(list_data_file):\n", + " with open(list_data_file, \"rb\") as f:\n", + " list_data = pickle.load(f)\n", + " origins_list=list_data[0]\n", + " masks_list=list_data[1]\n", + "else:\n", + " origins_list = [f.stem for f in origins_folder.glob(f\"*.{dataset_type}\")]\n", + " masks_list = [f.stem for f in masks_folder.glob(f\"*.{dataset_type}\")]\n", + " with open(list_data_file, \"wb\") as f:\n", + " pickle.dump([origins_list,masks_list], f)\n", + "\n", + "\n", + "#origins_list = [f.stem for f in origins_folder.glob(\"*.png\")]\n", + "#masks_list = [f.stem for f in masks_folder.glob(\"*.png\")]\n", + "\n", + "\n", + "origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n", + "\n", + "\n", + "\n", + "if os.path.isfile(split_file):\n", + " with open(split_file, \"rb\") as f:\n", + " splits = pickle.load(f)\n", + "else:\n", + " splits = {}\n", + " splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n", + " splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n", + " with open(split_file, \"wb\") as f:\n", + " pickle.dump(splits, f)\n", + "\n", + "val_test_transforms = torchvision.transforms.Compose([\n", + " Resize((image_size, image_size)),\n", + "])\n", + "\n", + "if dataset_name!=\"dataset\":\n", + " train_transforms = torchvision.transforms.Compose([\n", + " #Pad(200),\n", + " #Crop(300),\n", + " #val_test_transforms,\n", + " ])\n", + " datasets = {x: LungDataset2(\n", + " splits[x],\n", + " origins_folder,\n", + " masks_folder,\n", + " train_transforms if x == \"train\" else val_test_transforms,\n", + " dataset_type=dataset_type\n", + " ) for x in [\"train\", \"test\", \"val\"]}\n", + "else:\n", + " train_transforms = torchvision.transforms.Compose([\n", + " Pad(200),\n", + " Crop(300),\n", + " val_test_transforms,])\n", + "\n", + " datasets = {x: LungDataset(\n", + " splits[x],\n", + " origins_folder,\n", + " masks_folder,\n", + " train_transforms if x == \"train\" else val_test_transforms,\n", + " dataset_type=dataset_type\n", + " ) for x in [\"train\", \"test\", \"val\"]}\n", + "\n", + "\n", + "def mask_to_class_rgb1(mask):\n", + " #print('----mask->rgb----')\n", + " mask = torch.from_numpy(np.array(mask))\n", + " mask = torch.squeeze(mask) # remove 1\n", + "\n", + " class_mask = mask\n", + "\n", + " class_mask = class_mask.permute(2, 0, 1).contiguous()\n", + " h, w = class_mask.shape[1], class_mask.shape[2]\n", + " mask_out = torch.zeros((h, w))\n", + "\n", + " threshold=200\n", + " for i in range(0,3):\n", + " class_mask[i][class_mask[i] < threshold] = 0\n", + "\n", + " for i in range(2, 3):\n", + " mask_out[class_mask[i] >= threshold]=1\n", + " return mask_out\n", + "\n", + "\n", + "def mask_to_class_rgb(mask):\n", + " #print('----mask->rgb----')\n", + " mask = torch.from_numpy(np.array(mask))\n", + " mask = torch.squeeze(mask) # remove 1\n", + "\n", + " class_mask = mask\n", + "\n", + " class_mask = class_mask.permute(2, 0, 1).contiguous()\n", + " h, w = class_mask.shape[1], class_mask.shape[2]\n", + " mask_out = torch.zeros((h, w))\n", + "\n", + " threshold=200\n", + " for i in range(0,3):\n", + " class_mask[i][class_mask[i] < threshold] = 0\n", + "\n", + " for i in range(2, 3):\n", + " mask_out[class_mask[i] >= threshold]=1\n", + " return mask_out\n", + "\n", + "def getitem2(path):\n", + " mask = Image.open(path)\n", + " mask = mask_to_class_rgb(mask)\n", + " mask=mask.long()\n", + " #mask = (torch.tensor(mask) > 128).long()\n", + " return mask\n", + "\n", + "def getitem1(path):\n", + " mask = Image.open(path)\n", + " mask = mask.resize((image_size,image_size))\n", + " mask = np.array(mask)\n", + " mask = (torch.tensor(mask) > 128).long()\n", + " return mask\n", + "\n", + "\n", + "idx=1\n", + "phase = \"test\"\n", + "fig = plt.figure(figsize=(20, 10))\n", + "input=0\n", + "if dataset_name!=\"dataset\":\n", + " samples=[\"ID00015637202177877247924_110.jpg\",\n", + " \"ID00009637202177434476278_173.jpg\",\n", + " \"ID00009637202177434476278_316.jpg\",\n", + " \"ID00009637202177434476278_204.jpg\",]\n", + " masks = [mask_name.replace(\"_\", \"_mask_\").replace(\"images\", \"masks\") for mask_name in samples]\n", + "\n", + "else:\n", + " samples=[\"CHNCXR_0060_0.png\",\n", + " \"CHNCXR_0074_0.png\",\n", + " \"CHNCXR_0129_0.png\",\n", + " \"CHNCXR_0167_0.png\",]\n", + " masks = [mask_name.replace(\"_0.png\", \"_0_mask.png\").replace(\"images\", \"masks\") for mask_name in samples]\n", + "\n", + "\n", + "samples=[base_path + f\"input/{dataset_name}/images/\"+ sample_name for sample_name in samples]\n", + "masks=[base_path + f\"input/{dataset_name}/masks/\"+ mask_name for mask_name in masks]\n", + "models=[\"ResNetDUCHDC\",\"OueNetNew3\",\"NestedUNet\",\"ResNetDUC\",\"FCN_GCN\",\"SegNet2\",\"UNet\"]\n", + "\n", + "for input in range(0,len(samples)) :\n", + " for m in range(0,len(models)):\n", + " origin_filename = samples[input]\n", + " origin = Image.open(origin_filename).convert(\"P\")\n", + " origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))\n", + " origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n", + " if dataset_name!=\"dataset\":\n", + " mask= getitem2(masks[input])\n", + " else:\n", + " mask= getitem1(masks[input])\n", + " version=models[m]\n", + " if dataset_name!=\"dataset\":\n", + " version=version+\"_\"+dataset_name\n", + " model = selectModel(models[m])\n", + " model_name=\"ournet_\"+version+\".pt\"\n", + " model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n", + " model.to(device)\n", + " with torch.no_grad():\n", + " origin = torch.stack([origin])\n", + " origin = origin.to(device)\n", + " out = model(origin)\n", + " softmax = torch.nn.functional.log_softmax(out, dim=1)\n", + " out = torch.argmax(softmax, dim=1)\n", + "\n", + " origin = origin[0].to(\"cpu\")\n", + " out = out[0].to(\"cpu\")\n", + "\n", + " pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n", + " plt.subplots_adjust(hspace=0)\n", + " if m==0:\n", + " ax=fig.add_subplot(len(samples), len(models)+2,idx)\n", + " ax.set_axis_off()\n", + " #plt.title(\"origin image\")\n", + " plt.imshow(np.array(pil_origin))\n", + " idx=idx+1\n", + " ax=fig.add_subplot(len(samples), len(models)+2,idx)\n", + " ax.set_axis_off()\n", + " plt.imshow(np.array(blend(origin, mask,amount=0.4)))\n", + " idx=idx+1\n", + " ax=fig.add_subplot(len(samples), len(models)+2,idx)\n", + " ax.set_axis_off()\n", + " #plt.title(\"blended origin + predict\")\n", + " plt.imshow(np.array(blend(origin, out,amount=0.5)))\n", + " #plt.savefig(images_folder / f\"results/{version} {input}\", bbox_inches='tight')\n", + " idx=idx+1\n", + "\n", + "plt.show()\n" + ], + "metadata": { + "id": "ksO_-uVR6V4F" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file