520 lines (520 with data), 20.9 kB
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iXo6Nj7M5GK8"
},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"import os\n",
"import glob\n",
"import time\n",
"import pickle\n",
"import sys\n",
"sys.path.append('/content/drive/MyDrive/Batoul_Code/')\n",
"sys.path.append('/content/drive/MyDrive/Batoul_Code/src')\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from data import LungDataset, blend, Pad, Crop, Resize\n",
"from data2 import LungDataset2, blend, Pad, Crop, Resize\n",
"\n",
"from OurModel import CxlNet\n",
"\n",
"from metrics import jaccard, dice,get_accuracy, get_sensitivity, get_specificity"
]
},
{
"cell_type": "code",
"source": [
"in_channels=1\n",
"out_channels=2\n",
"batch_norm=True\n",
"upscale_mode=\"bilinear\"\n",
"image_size=512\n",
"def selectModel():\n",
" return CxlNet(\n",
" in_channels=in_channels,\n",
" out_channels=out_channels,\n",
" batch_norm=batch_norm,\n",
" upscale_mode=upscale_mode,\n",
" image_size=image_size)"
],
"metadata": {
"id": "BGn0Pjmb5gRA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dataset_name=\"dataset\"\n",
"dataset_types={\"dataset\":\"png\",\"CT\":\"jpg\"}\n",
"dataset_type=dataset_types[dataset_name]\n",
"print(dataset_type)\n",
"image_size=512\n",
"split_file = \"/content/drive/MyDrive/Batoul_Code/splits.pk\"\n",
"list_data_file = \"/content/drive/MyDrive/Batoul_Code/list_data.pk\"\n",
"version=\"UNet\"\n",
"approach=\"contour\"\n",
"model = selectModel()\n",
"\n",
"base_path=\"/content/drive/MyDrive/Batoul_Code/\"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"device\n",
"\n",
"\n",
"\n",
"\n",
"data_folder = Path(base_path+\"input\", base_path+\"input/\"+dataset_name)\n",
"origins_folder = data_folder / \"images\"\n",
"masks_folder = data_folder / \"masks\"\n",
"masks_contour_folder = data_folder / \"masks_contour\"\n",
"masks_folder =masks_contour_folder\n",
"models_folder = Path(base_path+\"models\")\n",
"images_folder = Path(base_path+\"images\")\n"
],
"metadata": {
"id": "5Sp6Ga1y50He"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"models_folder = Path(base_path+\"models\")\n",
"model_name = \"unet-6v.pt\"\n",
"model_name=\"ournet_\"+version+\".pt\"\n",
"print(model_name)\n",
"model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
"model.to(device)\n",
"model.eval()\n",
"\n",
"\n",
"test_loss = 0.0\n",
"test_jaccard = 0.0\n",
"test_dice = 0.0\n",
"test_accuracy=0.0\n",
"test_sensitivity=0.0\n",
"test_specificity=0.0\n",
"batch_size = 4\n",
"\n",
"if os.path.isfile(list_data_file):\n",
" with open(list_data_file, \"rb\") as f:\n",
" list_data = pickle.load(f)\n",
" origins_list=list_data[0]\n",
" masks_list=list_data[1]\n",
"else:\n",
" origins_list = [f.stem for f in origins_folder.glob(f\"*.{dataset_type}\")]\n",
" masks_list = [f.stem for f in masks_folder.glob(f\"*.{dataset_type}\")]\n",
" with open(list_data_file, \"wb\") as f:\n",
" pickle.dump([origins_list,masks_list], f)\n",
"\n",
"\n",
"#origins_list = [f.stem for f in origins_folder.glob(\"*.png\")]\n",
"#masks_list = [f.stem for f in masks_folder.glob(\"*.png\")]\n",
"\n",
"\n",
"origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n",
"\n",
"\n",
"\n",
"if os.path.isfile(split_file):\n",
" with open(split_file, \"rb\") as f:\n",
" splits = pickle.load(f)\n",
"else:\n",
" splits = {}\n",
" splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n",
" splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n",
" with open(split_file, \"wb\") as f:\n",
" pickle.dump(splits, f)\n",
"\n",
"val_test_transforms = torchvision.transforms.Compose([\n",
" Resize((image_size, image_size)),\n",
"])\n",
"\n",
"if dataset_name!=\"dataset\":\n",
" train_transforms = torchvision.transforms.Compose([\n",
" Pad(200),\n",
" Crop(300),\n",
" val_test_transforms,\n",
" ])\n",
" datasets = {x: LungDataset2(\n",
" splits[x],\n",
" origins_folder,\n",
" masks_folder,\n",
" train_transforms if x == \"train\" else val_test_transforms,\n",
" dataset_type=dataset_type\n",
" ) for x in [\"train\", \"test\", \"val\"]}\n",
"else:\n",
" train_transforms = torchvision.transforms.Compose([\n",
" Pad(200),\n",
" Crop(300),\n",
" val_test_transforms,])\n",
"\n",
" datasets = {x: LungDataset(\n",
" splits[x],\n",
" origins_folder,\n",
" masks_folder,\n",
" train_transforms if x == \"train\" else val_test_transforms,\n",
" dataset_type=dataset_type\n",
" ) for x in [\"train\", \"test\", \"val\"]}\n",
"\n",
"num_samples = 9\n",
"phase = \"test\"\n",
"print(len(datasets[phase]))\n",
"\n",
"dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=batch_size) for x in [\"train\", \"test\", \"val\"]}\n",
"\n",
"for origins, masks in dataloaders[\"test\"]:\n",
" num = origins.size(0)\n",
"\n",
" origins = origins.to(device)\n",
" masks = masks.to(device)\n",
"\n",
" with torch.no_grad():\n",
" outs = model(origins)\n",
" softmax = torch.nn.functional.log_softmax(outs, dim=1)\n",
" test_loss += torch.nn.functional.nll_loss(softmax, masks).item() * num\n",
" outs = torch.argmax(softmax, dim=1)\n",
" outs = outs.float()\n",
" masks = masks.float()\n",
" test_jaccard += jaccard(masks, outs).item() * num\n",
" test_dice += dice(masks, outs).item() * num\n",
" test_accuracy += get_accuracy(masks, outs) * num\n",
" test_sensitivity += get_sensitivity(masks, outs) * num\n",
" test_specificity += get_specificity(masks, outs) * num\n",
" print(\".\", end=\"\")\n",
"\n",
"test_loss = test_loss / len(datasets[\"test\"])\n",
"test_jaccard = test_jaccard / len(datasets[\"test\"])\n",
"test_dice = test_dice / len(datasets[\"test\"])\n",
"test_accuracy = test_accuracy / len(datasets[\"test\"])\n",
"test_sensitivity = test_sensitivity / len(datasets[\"test\"])\n",
"test_specificity = test_specificity / len(datasets[\"test\"])\n",
"print()\n",
"print(f\"avg test loss: {test_loss}\")\n",
"print(f\"avg test jaccard: {test_jaccard}\")\n",
"print(f\"avg test dice: {test_dice}\")\n",
"print(f\"avg test accuracy: {test_accuracy}\")\n",
"print(f\"avg test sensitivity: {test_sensitivity}\")\n",
"print(f\"avg test specificity: {test_specificity}\")\n",
"\n",
"\n",
"\n",
"subset = torch.utils.data.Subset(\n",
" datasets[phase],\n",
" np.random.randint(0, len(datasets[phase]), num_samples)\n",
")\n",
"random_samples_loader = torch.utils.data.DataLoader(subset, batch_size=2)\n",
"plt.figure(figsize=(20, 25))\n",
"\n",
"for idx, (origin, mask) in enumerate(random_samples_loader):\n",
" plt.subplot((num_samples // 3) + 1, 3, idx + 1)\n",
"\n",
" origin = origin.to(device)\n",
" mask = mask.to(device)\n",
"\n",
" with torch.no_grad():\n",
" out = model(origin)\n",
" softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
" out = torch.argmax(softmax, dim=1)\n",
"\n",
" jaccard_score = jaccard(mask.float(), out.float()).item()\n",
" dice_score = dice(mask.float(), out.float()).item()\n",
"\n",
" origin = origin[0].to(\"cpu\")\n",
" out = out[0].to(\"cpu\")\n",
" mask = mask[0].to(\"cpu\")\n",
" #plt.imshow(np.array(blend(origin, mask, out)))\n",
"    plt.imshow(np.array(blend(origin, mask, out)))\n",
" plt.title(f\"jaccard: {jaccard_score:.4f}, dice: {dice_score:.4f}\")\n",
" print(\".\", end=\"\")\n",
"\n",
"plt.savefig(images_folder / \"obtained-results.png\", bbox_inches='tight')\n",
"plt.show()\n",
"print()\n",
"print(\"red area - predict\")\n",
"print(\"green area - ground truth\")\n",
"print(\"yellow area - intersection\")\n",
"\n",
"\n",
"model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
"model.to(device)\n",
"model.eval()\n",
"\n",
"device\n",
"\n",
"#%%\n",
"\n",
"origin_filename = base_path+ f\"input/{dataset_name}/images/ID00015637202177877247924_110.jpg\"\n",
"#origin_filename=base_path + \"external_samples/1.jpg\"\n",
"\n",
"origin = Image.open(origin_filename).convert(\"P\")\n",
"origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))\n",
"origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n",
"\n",
"with torch.no_grad():\n",
" origin = torch.stack([origin])\n",
" origin = origin.to(device)\n",
" out = model(origin)\n",
" softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
" out = torch.argmax(softmax, dim=1)\n",
"\n",
" origin = origin[0].to(\"cpu\")\n",
" out = out[0].to(\"cpu\")\n",
"\n",
"\n",
"plt.figure(figsize=(20, 10))\n",
"\n",
"pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n",
"\n",
"plt.subplot(1, 2, 1)\n",
"plt.title(\"origin image\")\n",
"plt.imshow(np.array(pil_origin))\n",
"plt.show()\n",
"plt.subplot(1, 2, 2)\n",
"plt.title(\"blended origin + predict\")\n",
"plt.imshow(np.array(blend(origin, out)))\n",
"plt.show()\n"
],
"metadata": {
"id": "jFpE8AwG5-Nm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"models_folder = Path(base_path+\"models\")\n",
"\n",
"\n",
"\n",
"test_loss = 0.0\n",
"test_jaccard = 0.0\n",
"test_dice = 0.0\n",
"\n",
"batch_size = 4\n",
"\n",
"if os.path.isfile(list_data_file):\n",
" with open(list_data_file, \"rb\") as f:\n",
" list_data = pickle.load(f)\n",
" origins_list=list_data[0]\n",
" masks_list=list_data[1]\n",
"else:\n",
" origins_list = [f.stem for f in origins_folder.glob(f\"*.{dataset_type}\")]\n",
" masks_list = [f.stem for f in masks_folder.glob(f\"*.{dataset_type}\")]\n",
" with open(list_data_file, \"wb\") as f:\n",
" pickle.dump([origins_list,masks_list], f)\n",
"\n",
"\n",
"#origins_list = [f.stem for f in origins_folder.glob(\"*.png\")]\n",
"#masks_list = [f.stem for f in masks_folder.glob(\"*.png\")]\n",
"\n",
"\n",
"origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n",
"\n",
"\n",
"\n",
"if os.path.isfile(split_file):\n",
" with open(split_file, \"rb\") as f:\n",
" splits = pickle.load(f)\n",
"else:\n",
" splits = {}\n",
" splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n",
" splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n",
" with open(split_file, \"wb\") as f:\n",
" pickle.dump(splits, f)\n",
"\n",
"val_test_transforms = torchvision.transforms.Compose([\n",
" Resize((image_size, image_size)),\n",
"])\n",
"\n",
"if dataset_name!=\"dataset\":\n",
" train_transforms = torchvision.transforms.Compose([\n",
" #Pad(200),\n",
" #Crop(300),\n",
" #val_test_transforms,\n",
" ])\n",
" datasets = {x: LungDataset2(\n",
" splits[x],\n",
" origins_folder,\n",
" masks_folder,\n",
" train_transforms if x == \"train\" else val_test_transforms,\n",
" dataset_type=dataset_type\n",
" ) for x in [\"train\", \"test\", \"val\"]}\n",
"else:\n",
" train_transforms = torchvision.transforms.Compose([\n",
" Pad(200),\n",
" Crop(300),\n",
" val_test_transforms,])\n",
"\n",
" datasets = {x: LungDataset(\n",
" splits[x],\n",
" origins_folder,\n",
" masks_folder,\n",
" train_transforms if x == \"train\" else val_test_transforms,\n",
" dataset_type=dataset_type\n",
" ) for x in [\"train\", \"test\", \"val\"]}\n",
"\n",
"\n",
"def mask_to_class_rgb1(mask):\n",
" #print('----mask->rgb----')\n",
" mask = torch.from_numpy(np.array(mask))\n",
" mask = torch.squeeze(mask) # remove 1\n",
"\n",
" class_mask = mask\n",
"\n",
" class_mask = class_mask.permute(2, 0, 1).contiguous()\n",
" h, w = class_mask.shape[1], class_mask.shape[2]\n",
" mask_out = torch.zeros((h, w))\n",
"\n",
" threshold=200\n",
" for i in range(0,3):\n",
" class_mask[i][class_mask[i] < threshold] = 0\n",
"\n",
" for i in range(2, 3):\n",
" mask_out[class_mask[i] >= threshold]=1\n",
" return mask_out\n",
"\n",
"\n",
"def mask_to_class_rgb(mask):\n",
" #print('----mask->rgb----')\n",
" mask = torch.from_numpy(np.array(mask))\n",
" mask = torch.squeeze(mask) # remove 1\n",
"\n",
" class_mask = mask\n",
"\n",
" class_mask = class_mask.permute(2, 0, 1).contiguous()\n",
" h, w = class_mask.shape[1], class_mask.shape[2]\n",
" mask_out = torch.zeros((h, w))\n",
"\n",
" threshold=200\n",
" for i in range(0,3):\n",
" class_mask[i][class_mask[i] < threshold] = 0\n",
"\n",
" for i in range(2, 3):\n",
" mask_out[class_mask[i] >= threshold]=1\n",
" return mask_out\n",
"\n",
"def getitem2(path):\n",
" mask = Image.open(path)\n",
" mask = mask_to_class_rgb(mask)\n",
" mask=mask.long()\n",
" #mask = (torch.tensor(mask) > 128).long()\n",
" return mask\n",
"\n",
"def getitem1(path):\n",
" mask = Image.open(path)\n",
" mask = mask.resize((image_size,image_size))\n",
" mask = np.array(mask)\n",
" mask = (torch.tensor(mask) > 128).long()\n",
" return mask\n",
"\n",
"\n",
"idx=1\n",
"phase = \"test\"\n",
"fig = plt.figure(figsize=(20, 10))\n",
"input=0\n",
"if dataset_name!=\"dataset\":\n",
" samples=[\"ID00015637202177877247924_110.jpg\",\n",
" \"ID00009637202177434476278_173.jpg\",\n",
" \"ID00009637202177434476278_316.jpg\",\n",
" \"ID00009637202177434476278_204.jpg\",]\n",
" masks = [mask_name.replace(\"_\", \"_mask_\").replace(\"images\", \"masks\") for mask_name in samples]\n",
"\n",
"else:\n",
" samples=[\"CHNCXR_0060_0.png\",\n",
" \"CHNCXR_0074_0.png\",\n",
" \"CHNCXR_0129_0.png\",\n",
" \"CHNCXR_0167_0.png\",]\n",
" masks = [mask_name.replace(\"_0.png\", \"_0_mask.png\").replace(\"images\", \"masks\") for mask_name in samples]\n",
"\n",
"\n",
"samples=[base_path + f\"input/{dataset_name}/images/\"+ sample_name for sample_name in samples]\n",
"masks=[base_path + f\"input/{dataset_name}/masks/\"+ mask_name for mask_name in masks]\n",
"models=[\"ResNetDUCHDC\",\"OueNetNew3\",\"NestedUNet\",\"ResNetDUC\",\"FCN_GCN\",\"SegNet2\",\"UNet\"]\n",
"\n",
"for input in range(0,len(samples)) :\n",
" for m in range(0,len(models)):\n",
" origin_filename = samples[input]\n",
" origin = Image.open(origin_filename).convert(\"P\")\n",
" origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))\n",
" origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n",
" if dataset_name!=\"dataset\":\n",
" mask= getitem2(masks[input])\n",
" else:\n",
" mask= getitem1(masks[input])\n",
" version=models[m]\n",
" if dataset_name!=\"dataset\":\n",
" version=version+\"_\"+dataset_name\n",
"    model = selectModel()\n",
" model_name=\"ournet_\"+version+\".pt\"\n",
" model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
" model.to(device)\n",
" with torch.no_grad():\n",
" origin = torch.stack([origin])\n",
" origin = origin.to(device)\n",
" out = model(origin)\n",
" softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
" out = torch.argmax(softmax, dim=1)\n",
"\n",
" origin = origin[0].to(\"cpu\")\n",
" out = out[0].to(\"cpu\")\n",
"\n",
" pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n",
" plt.subplots_adjust(hspace=0)\n",
" if m==0:\n",
" ax=fig.add_subplot(len(samples), len(models)+2,idx)\n",
" ax.set_axis_off()\n",
" #plt.title(\"origin image\")\n",
" plt.imshow(np.array(pil_origin))\n",
" idx=idx+1\n",
" ax=fig.add_subplot(len(samples), len(models)+2,idx)\n",
" ax.set_axis_off()\n",
" plt.imshow(np.array(blend(origin, mask,amount=0.4)))\n",
" idx=idx+1\n",
" ax=fig.add_subplot(len(samples), len(models)+2,idx)\n",
" ax.set_axis_off()\n",
" #plt.title(\"blended origin + predict\")\n",
" plt.imshow(np.array(blend(origin, out,amount=0.5)))\n",
" #plt.savefig(images_folder / f\"results/{version} {input}\", bbox_inches='tight')\n",
" idx=idx+1\n",
"\n",
"plt.show()\n"
],
"metadata": {
"id": "ksO_-uVR6V4F"
},
"execution_count": null,
"outputs": []
}
]
}