442 lines (442 with data), 17.8 kB
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fqPJUT863Ps2"
},
"outputs": [],
"source": [
"# Mount Google Drive: the project code, dataset, split file and checkpoints all live under MyDrive/Batoul_Code.\n",
"from google.colab import drive\n",
"\n",
"drive.mount(mountpoint='/content/drive')\n",
"\n",
"# One-off extraction of the CT image archive; uncomment and run once if images.zip is still packed.\n",
"#!unzip '/content/drive/MyDrive/Batoul_Code/input/CT/images.zip' -d '/content/drive/MyDrive/Batoul_Code/input/CT'"
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torchvision\n",
"import os\n",
"import glob\n",
"import time\n",
"import pickle\n",
"import sys\n",
"sys.path.append('/content/drive/MyDrive/Batoul_Code/')\n",
"sys.path.append('/content/drive/MyDrive/Batoul_Code/src')\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
"from data import LungDataset, blend, Pad, Crop, Resize\n",
"from OurModel import CxlNet\n",
"from metrics import jaccard, dice"
],
"metadata": {
"id": "N-IZfgid3Xwc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# CxlNet hyper-parameters; image_size is also the resize target used by the data transforms.\n",
"in_channels = 1\n",
"out_channels = 2\n",
"batch_norm = True\n",
"upscale_mode = \"bilinear\"\n",
"image_size = 512\n",
"\n",
"\n",
"def selectModel():\n",
"    \"\"\"Return a freshly constructed CxlNet configured from the module-level settings above.\"\"\"\n",
"    settings = {\n",
"        \"in_channels\": in_channels,\n",
"        \"out_channels\": out_channels,\n",
"        \"batch_norm\": batch_norm,\n",
"        \"upscale_mode\": upscale_mode,\n",
"        \"image_size\": image_size,\n",
"    }\n",
"    return CxlNet(**settings)"
],
"metadata": {
"id": "J349vBjr31Ir"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Run configuration: dataset choice, file locations and compute device.\n",
"dataset_name = \"dataset\"\n",
"dataset_type = \"png\"  # file extension globbed for images and masks\n",
"version = \"CxlNet\"  # tag used in the checkpoint and training-log file names\n",
"approach = \"contour\"  # NOTE(review): not read by any visible cell; contour masks are selected explicitly later\n",
"\n",
"base_path = \"/content/drive/MyDrive/Batoul_Code/\"\n",
"split_file = base_path + \"splits.pk\"\n",
"\n",
"model = selectModel()\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Path(a, b) drops `a` whenever `b` is absolute, so join relative parts onto base_path instead.\n",
"data_folder = Path(base_path, \"input\", dataset_name)\n",
"origins_folder = data_folder / \"images\"\n",
"masks_folder = data_folder / \"masks\"\n",
"masks_contour_folder = data_folder / \"masks_contour\"\n",
"models_folder = Path(base_path, \"models\")\n",
"images_folder = Path(base_path, \"images\")\n",
"\n",
"# Checkpoints and figures are written into these folders later; make sure they exist.\n",
"models_folder.mkdir(parents=True, exist_ok=True)\n",
"images_folder.mkdir(parents=True, exist_ok=True)\n",
"\n",
"device"
],
"metadata": {
"id": "MhOSosnw4Q96"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data: splits, transforms and loaders"
]
},
{
"cell_type": "code",
"source": [
"#@title\n",
"batch_size = 4\n",
"torch.cuda.empty_cache()\n",
"\n",
"# Targets are the contour masks (masks_contour_folder); masks_folder holds the alternative mask set.\n",
"origins_list = [f.stem for f in origins_folder.glob(f\"*.{dataset_type}\")]\n",
"masks_list = [f.stem for f in masks_contour_folder.glob(f\"*.{dataset_type}\")]\n",
"\n",
"# Pair every mask with its image: the image stem is the mask stem with \"_mask\" removed.\n",
"origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n",
"\n",
"# Reuse the cached split so train/val/test membership stays fixed across runs.\n",
"# NOTE: pickle.load can execute arbitrary code -- only point split_file at a file you created.\n",
"if os.path.isfile(split_file):\n",
"    with open(split_file, \"rb\") as f:\n",
"        splits = pickle.load(f)\n",
"else:\n",
"    splits = {}\n",
"    splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n",
"    splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n",
"    with open(split_file, \"wb\") as f:\n",
"        pickle.dump(splits, f)\n",
"\n",
"val_test_transforms = torchvision.transforms.Compose([\n",
"    Resize((image_size, image_size)),\n",
"])\n",
"\n",
"# Training additionally applies Pad(200) and Crop(300) before the shared resize\n",
"# (presumably random-crop augmentation -- see data.py).\n",
"train_transforms = torchvision.transforms.Compose([\n",
"    Pad(200),\n",
"    Crop(300),\n",
"    val_test_transforms,\n",
"])\n",
"\n",
"phases = [\"train\", \"test\", \"val\"]\n",
"datasets = {\n",
"    phase: LungDataset(\n",
"        splits[phase],\n",
"        origins_folder,\n",
"        masks_contour_folder,\n",
"        train_transforms if phase == \"train\" else val_test_transforms,\n",
"        dataset_type=dataset_type,\n",
"    )\n",
"    for phase in phases\n",
"}\n",
"\n",
"# Shuffle only the training loader so batches differ between epochs; val/test keep a fixed order.\n",
"dataloaders = {\n",
"    phase: torch.utils.data.DataLoader(datasets[phase], batch_size=batch_size, shuffle=(phase == \"train\"))\n",
"    for phase in phases\n",
"}\n",
"\n",
"print(len(dataloaders['train']))"
],
"metadata": {
"id": "sGk2UEtw4JLr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sample preview"
]
},
{
"cell_type": "code",
"source": [
"# Inspect one training sample: the image, its mask, and the two blended.\n",
"idx = 0\n",
"phase = \"train\"\n",
"\n",
"plt.figure(figsize=(20, 20))\n",
"origin, mask = datasets[phase][idx]\n",
"print(origin.size())\n",
"print(mask.size())\n",
"\n",
"# +0.5 undoes the -0.5 input shift (the inference code applies the same shift after to_tensor).\n",
"pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n",
"pil_mask = torchvision.transforms.functional.to_pil_image(mask.float())\n",
"\n",
"panels = [\n",
"    (\"origin image\", np.array(pil_origin)),\n",
"    (\"manually labeled mask\", np.array(pil_mask)),\n",
"    (\"blended origin + mask\", np.array(blend(origin, mask))),\n",
"]\n",
"for panel_idx, (title, image) in enumerate(panels, start=1):\n",
"    plt.subplot(1, 3, panel_idx)\n",
"    plt.title(title)\n",
"    plt.imshow(image)\n",
"\n",
"plt.savefig(images_folder / \"data-example.png\", bbox_inches='tight')\n",
"plt.show()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"source": [
"# Set train = False to skip training and evaluate the saved checkpoint in the next cell instead.\n",
"train = True\n",
"model_name = \"ournet_\" + version + \".pt\"\n",
"\n",
"if train:\n",
"    # Resume from an existing checkpoint when one is present.\n",
"    if os.path.isfile(models_folder / model_name):\n",
"        model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
"        print(\"load_state_dict\")\n",
"\n",
"    model = model.to(device)\n",
"    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)\n",
"\n",
"    train_log_filename = base_path + \"train-log-\" + version + \".txt\"\n",
"    epochs = 50\n",
"    # NOTE(review): when resuming, best_val_loss restarts at +inf, so the first epoch always\n",
"    # overwrites the loaded checkpoint even if its validation loss is worse.\n",
"    best_val_loss = np.inf\n",
"    hist = []\n",
"\n",
"    for e in range(epochs):\n",
"        start_t = time.time()\n",
"        print(\"Epoch \" + str(e))\n",
"\n",
"        # Training pass.\n",
"        model.train()\n",
"        train_loss = 0.0\n",
"        for origins, masks in dataloaders[\"train\"]:\n",
"            num = origins.size(0)\n",
"            origins = origins.to(device)\n",
"            masks = masks.to(device)\n",
"\n",
"            optimizer.zero_grad()\n",
"            outs = model(origins)\n",
"            # log_softmax over the class dim followed by NLL is per-pixel cross-entropy.\n",
"            log_probs = torch.nn.functional.log_softmax(outs, dim=1)\n",
"            loss = torch.nn.functional.nll_loss(log_probs, masks)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            # Weight by batch size so the epoch average is per sample even when the last batch is short.\n",
"            train_loss += loss.item() * num\n",
"            print(\".\", end=\"\")\n",
"\n",
"        train_loss = train_loss / len(datasets['train'])\n",
"        print()\n",
"\n",
"        # Validation pass.\n",
"        print(\"validation phase\")\n",
"        model.eval()\n",
"        val_loss = 0.0\n",
"        val_jaccard = 0.0\n",
"        val_dice = 0.0\n",
"        for origins, masks in dataloaders[\"val\"]:\n",
"            num = origins.size(0)\n",
"            origins = origins.to(device)\n",
"            masks = masks.to(device)\n",
"\n",
"            with torch.no_grad():\n",
"                outs = model(origins)\n",
"                log_probs = torch.nn.functional.log_softmax(outs, dim=1)\n",
"                val_loss += torch.nn.functional.nll_loss(log_probs, masks).item() * num\n",
"\n",
"                preds = torch.argmax(log_probs, dim=1).float()\n",
"                masks = masks.float()\n",
"                val_jaccard += jaccard(masks, preds).item() * num\n",
"                val_dice += dice(masks, preds).item() * num\n",
"\n",
"            print(\".\", end=\"\")\n",
"        val_loss = val_loss / len(datasets[\"val\"])\n",
"        val_jaccard = val_jaccard / len(datasets[\"val\"])\n",
"        val_dice = val_dice / len(datasets[\"val\"])\n",
"        print()\n",
"\n",
"        epoch_time = time.time() - start_t\n",
"        report = (\n",
"            f\"epoch: {e + 1}/{epochs}, time: {epoch_time}, train loss: {train_loss}, \\n\"\n",
"            f\"val loss: {val_loss}, val jaccard: {val_jaccard}, val dice: {val_dice}\"\n",
"        )\n",
"        hist.append({\n",
"            \"time\": epoch_time,\n",
"            \"train_loss\": train_loss,\n",
"            \"val_loss\": val_loss,\n",
"            \"val_jaccard\": val_jaccard,\n",
"            \"val_dice\": val_dice,\n",
"        })\n",
"        print(report)\n",
"\n",
"        with open(train_log_filename, \"a\") as train_log_file:\n",
"            train_log_file.write(report + \"\\n\")\n",
"            # Keep only the checkpoint with the lowest validation loss seen so far.\n",
"            if val_loss < best_val_loss:\n",
"                best_val_loss = val_loss\n",
"                torch.save(model.state_dict(), models_folder / model_name)\n",
"                print(\"model saved\")\n",
"                train_log_file.write(\"model saved\\n\")\n",
"        print()\n",
"\n",
"    # Learning curves: losses and validation metrics per epoch.\n",
"    plt.figure(figsize=(15, 7))\n",
"    epoch_axis = range(len(hist))\n",
"    for key, color, label in [\n",
"        (\"train_loss\", \"b\", \"train loss\"),\n",
"        (\"val_loss\", \"r\", \"val loss\"),\n",
"        (\"val_dice\", \"g\", \"val dice\"),\n",
"        (\"val_jaccard\", \"y\", \"val jaccard\"),\n",
"    ]:\n",
"        plt.plot(epoch_axis, [h[key] for h in hist], color, label=label)\n",
"    plt.legend()\n",
"    plt.xlabel(\"epoch\")\n",
"    plt.savefig(images_folder / model_name.replace(\".pt\", \"-train-hist.png\"))\n",
"\n",
"    time_hist = [h[\"time\"] for h in hist]\n",
"    overall_time = sum(time_hist) // 60\n",
"    mean_epoch_time = sum(time_hist) / len(hist)\n",
"    print(f\"epochs: {len(hist)}, overall time: {overall_time}m, mean epoch time: {mean_epoch_time}s\")\n",
"\n",
"    torch.cuda.empty_cache()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation and single-image inference"
]
},
{
"cell_type": "code",
"source": [
"# Runs only when train = False: score the saved checkpoint on the test split and visualise predictions.\n",
"if not train:\n",
"    model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
"    model.to(device)\n",
"    model.eval()\n",
"\n",
"    test_loss = 0.0\n",
"    test_jaccard = 0.0\n",
"    test_dice = 0.0\n",
"    for origins, masks in dataloaders[\"test\"]:\n",
"        num = origins.size(0)\n",
"        origins = origins.to(device)\n",
"        masks = masks.to(device)\n",
"\n",
"        with torch.no_grad():\n",
"            outs = model(origins)\n",
"            log_probs = torch.nn.functional.log_softmax(outs, dim=1)\n",
"            test_loss += torch.nn.functional.nll_loss(log_probs, masks).item() * num\n",
"\n",
"            preds = torch.argmax(log_probs, dim=1).float()\n",
"            masks = masks.float()\n",
"            test_jaccard += jaccard(masks, preds).item() * num\n",
"            test_dice += dice(masks, preds).item() * num\n",
"        print(\".\", end=\"\")\n",
"\n",
"    test_loss = test_loss / len(datasets[\"test\"])\n",
"    test_jaccard = test_jaccard / len(datasets[\"test\"])\n",
"    test_dice = test_dice / len(datasets[\"test\"])\n",
"\n",
"    print()\n",
"    print(f\"avg test loss: {test_loss}\")\n",
"    print(f\"avg test jaccard: {test_jaccard}\")\n",
"    print(f\"avg test dice: {test_dice}\")\n",
"\n",
"    # Overlay predictions on a few random test images (drawn with replacement, so repeats can occur).\n",
"    num_samples = 9\n",
"    phase = \"test\"\n",
"    subset = torch.utils.data.Subset(\n",
"        datasets[phase],\n",
"        np.random.randint(0, len(datasets[phase]), num_samples)\n",
"    )\n",
"    random_samples_loader = torch.utils.data.DataLoader(subset, batch_size=1)\n",
"\n",
"    n_cols = 3\n",
"    n_rows = -(-num_samples // n_cols)  # ceiling division: 9 samples -> 3 full rows, no empty trailing row\n",
"    plt.figure(figsize=(20, 25))\n",
"    for idx, (origin, mask) in enumerate(random_samples_loader):\n",
"        plt.subplot(n_rows, n_cols, idx + 1)\n",
"        origin = origin.to(device)\n",
"        mask = mask.to(device)\n",
"\n",
"        with torch.no_grad():\n",
"            out = model(origin)\n",
"            log_probs = torch.nn.functional.log_softmax(out, dim=1)\n",
"            out = torch.argmax(log_probs, dim=1)\n",
"\n",
"        jaccard_score = jaccard(mask.float(), out.float()).item()\n",
"        dice_score = dice(mask.float(), out.float()).item()\n",
"\n",
"        origin = origin[0].to(\"cpu\")\n",
"        out = out[0].to(\"cpu\")\n",
"        mask = mask[0].to(\"cpu\")\n",
"\n",
"        plt.imshow(np.array(blend(origin, mask, out)))\n",
"        plt.title(f\"jaccard: {jaccard_score:.4f}, dice: {dice_score:.4f}\")\n",
"        print(\".\", end=\"\")\n",
"    # Save before show(): with the inline backend show() closes the figure, so saving afterwards writes an empty canvas.\n",
"    plt.savefig(images_folder / \"obtained-results.png\", bbox_inches='tight')\n",
"    plt.show()\n",
"    print()\n",
"    print(\"red area - predict\")\n",
"    print(\"green area - ground truth\")\n",
"    print(\"yellow area - intersection\")\n",
"\n",
"    # Single-image inference. The path is anchored at base_path (a bare relative path resolves against\n",
"    # the kernel's working directory), and the image is resized to image_size, the resolution the\n",
"    # model is configured with and the data transforms resize to.\n",
"    origin_filename = Path(base_path, \"input\", \"dataset\", \"images\", \"CHNCXR_0042_0.png\")\n",
"    origin = Image.open(origin_filename).convert(\"P\")\n",
"    origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))\n",
"    origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n",
"\n",
"    with torch.no_grad():\n",
"        origin = torch.stack([origin]).to(device)\n",
"        out = model(origin)\n",
"        log_probs = torch.nn.functional.log_softmax(out, dim=1)\n",
"        out = torch.argmax(log_probs, dim=1)\n",
"\n",
"    origin = origin[0].to(\"cpu\")\n",
"    out = out[0].to(\"cpu\")\n",
"\n",
"    pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n",
"    plt.figure(figsize=(20, 10))\n",
"    plt.subplot(1, 2, 1)\n",
"    plt.title(\"origin image\")\n",
"    plt.imshow(np.array(pil_origin))\n",
"\n",
"    plt.subplot(1, 2, 2)\n",
"    plt.title(\"blended origin + predict\")\n",
"    plt.imshow(np.array(blend(origin, out)))\n",
"    plt.show()"
],
"metadata": {},
"execution_count": null,
"outputs": []
}
]
}