# --- Environment & data setup (flattened from Training.ipynb) ---
# Mount Google Drive so the dataset, the split file and model checkpoints are reachable.
from google.colab import drive
drive.mount('/content/drive')
# !unzip '/content/drive/MyDrive/Batoul_Code/input/CT/images.zip' -d '/content/drive/MyDrive/Batoul_Code/input/CT'

import torch
import torchvision
import os
import glob
import time
import pickle
import sys
sys.path.append('/content/drive/MyDrive/Batoul_Code/')
sys.path.append('/content/drive/MyDrive/Batoul_Code/src')
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split

# Project-local modules, resolved via the sys.path entries above.
from data import LungDataset, blend, Pad, Crop, Resize
from OurModel import CxlNet
from metrics import jaccard, dice

# ---- Model configuration ----
in_channels = 1            # single-channel (grayscale) input images
out_channels = 2           # per-pixel two-class logits (background / lung)
batch_norm = True
upscale_mode = "bilinear"
image_size = 512

def selectModel():
    """Build a fresh CxlNet segmentation model from the module-level config."""
    return CxlNet(
        in_channels=in_channels,
        out_channels=out_channels,
        batch_norm=batch_norm,
        upscale_mode=upscale_mode,
        image_size=image_size)

# ---- Paths & run configuration ----
dataset_name = "dataset"
dataset_type = "png"
split_file = "/content/drive/MyDrive/Batoul_Code/splits.pk"
version = "CxlNet"
approach = "contour"
model = selectModel()

base_path = "/content/drive/MyDrive/Batoul_Code/"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# BUG FIX: the original used Path(base_path+"input", base_path+"input/"+dataset_name);
# because the second argument is absolute, pathlib silently discards the first one.
# Build the same resulting path ("/content/drive/MyDrive/Batoul_Code/input/dataset")
# explicitly instead.
data_folder = Path(base_path) / "input" / dataset_name
origins_folder = data_folder / "images"
masks_folder = data_folder / "masks"
masks_contour_folder = data_folder / "masks_contour"
models_folder = Path(base_path + "models")
images_folder = Path(base_path + "images")

# ---- Train/val/test split ----
batch_size = 4
torch.cuda.empty_cache()
origins_list = [f.stem for f in origins_folder.glob(f"*.{dataset_type}")]
# masks_list = [f.stem for f in masks_folder.glob(f"*.{dataset_type}")]
masks_list = [f.stem for f in masks_contour_folder.glob(f"*.{dataset_type}")]

# Pair each mask with its origin image ("x_mask" belongs to origin "x").
origin_mask_list = [(mask_name.replace("_mask", ""), mask_name) for mask_name in masks_list]

if os.path.isfile(split_file):
    # NOTE(review): pickle.load can execute arbitrary code if the split file is
    # tampered with — acceptable here only because the file is self-produced below.
    with open(split_file, "rb") as f:
        splits = pickle.load(f)
else:
    # 80/20 train/test split, then 10% of train held out for validation;
    # fixed random_state so the split is reproducible across runs.
    splits = {}
    splits["train"], splits["test"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)
    splits["train"], splits["val"] = train_test_split(splits["train"], test_size=0.1, random_state=42)
    with open(split_file, "wb") as f:
        pickle.dump(splits, f)

# Validation/test samples are only resized; training samples are additionally
# padded and cropped (light spatial augmentation) before the shared resize.
val_test_transforms = torchvision.transforms.Compose([
    Resize((image_size, image_size)),
])

train_transforms = torchvision.transforms.Compose([
    Pad(200),
    Crop(300),
    val_test_transforms,
])
# One LungDataset per phase; only "train" gets the augmenting transform chain.
datasets = {x: LungDataset(
    splits[x],
    origins_folder,
    # masks_folder,
    masks_contour_folder,
    train_transforms if x == "train" else val_test_transforms,
    dataset_type=dataset_type
) for x in ["train", "test", "val"]}

dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=batch_size)
               for x in ["train", "test", "val"]}

print(len(dataloaders['train']))

# ---- Visual sanity check of one training sample ----
idx = 0
phase = "train"

plt.figure(figsize=(20, 20))
origin, mask = datasets[phase][idx]

# Inputs are stored in [-0.5, 0.5]; shift back to [0, 1] for display.
pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert("RGB")
print(origin.size())
print(mask.size())
pil_origin.save("1.png")

print(mask.size())
pil_mask = torchvision.transforms.functional.to_pil_image(mask.float())
pil_mask.save("2.png")
plt.subplot(1, 3, 1)
plt.title("origin image")
plt.imshow(np.array(pil_origin))

plt.subplot(1, 3, 2)
plt.title("manually labeled mask")
plt.imshow(np.array(pil_mask))

plt.subplot(1, 3, 3)
plt.title("blended origin + mask")
plt.imshow(np.array(blend(origin, mask)))

plt.savefig(images_folder / "data-example.png", bbox_inches='tight')
plt.show()

# ---- Training branch ----
train = True
model_name = "ournet_" + version + ".pt"
if train == True:

    # Resume from an existing checkpoint when one is present.
    if os.path.isfile(models_folder / model_name):
        model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device("cpu")))
        print("load_state_dict")

    model = model.to(device)
    # optimizer = torch.optim.SGD(unet.parameters(), lr=0.0005, momentum=0.9)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

    train_log_filename = base_path + "train-log-" + version + ".txt"
    epochs = 50
    best_val_loss = np.inf

    hist = []

    for e in range(epochs):
        start_t = time.time()

        print("Epoch " + str(e))
        model.train()

        train_loss = 0.0

        for origins, masks in dataloaders["train"]:
            num = origins.size(0)

            origins = origins.to(device)
            masks = masks.to(device)
            optimizer.zero_grad()
            outs = model(origins)
            # Two-channel logits -> log-probabilities; NLL over per-pixel class ids.
            softmax = torch.nn.functional.log_softmax(outs, dim=1)
            loss = torch.nn.functional.nll_loss(softmax, masks)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch value is a true per-sample mean.
            train_loss += loss.item() * num
            print(".", end="")

        train_loss = train_loss / len(datasets['train'])
        print()

        print("validation phase")
        model.eval()
        val_loss = 0.0
        val_jaccard = 0.0
        val_dice = 0.0

        for origins, masks in dataloaders["val"]:
            num = origins.size(0)
            origins = origins.to(device)
            masks = masks.to(device)

            with torch.no_grad():
                outs = model(origins)
                softmax = torch.nn.functional.log_softmax(outs, dim=1)
                val_loss += torch.nn.functional.nll_loss(softmax, masks).item() * num

                # Hard per-pixel predictions for the overlap metrics.
                outs = torch.argmax(softmax, dim=1).float()
                masks = masks.float()
                # CONSISTENCY FIX: original called jaccard(masks, outs.float()) on an
                # already-float tensor while dice() got the bare tensor; both metrics
                # now receive identical arguments.
                val_jaccard += jaccard(masks, outs).item() * num
                val_dice += dice(masks, outs).item() * num

            print(".", end="")
        val_loss = val_loss / len(datasets["val"])
        val_jaccard = val_jaccard / len(datasets["val"])
        val_dice = val_dice / len(datasets["val"])
        print()

        end_t = time.time()
        spended_t = end_t - start_t

        with open(train_log_filename, "a") as train_log_file:
            report = f"epoch: {e + 1}/{epochs}, time: {spended_t}, train loss: {train_loss}, \n" \
                + f"val loss: {val_loss}, val jaccard: {val_jaccard}, val dice: {val_dice}"

            hist.append({
                "time": spended_t,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "val_jaccard": val_jaccard,
                "val_dice": val_dice,
            })

            print(report)
            train_log_file.write(report + "\n")

            # Keep only the checkpoint with the best validation loss so far.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), models_folder / model_name)
                print("model saved")
                train_log_file.write("model saved\n")
            print()

        # if val_jaccard >= 0.9179:
        #     break

    # ---- Training-history plot ----
    plt.figure(figsize=(15, 7))
    epoch_axis = range(len(hist))
    plt.plot(epoch_axis, [h["train_loss"] for h in hist], "b", label="train loss")
    plt.plot(epoch_axis, [h["val_loss"] for h in hist], "r", label="val loss")
    plt.plot(epoch_axis, [h["val_dice"] for h in hist], "g", label="val dice")
    plt.plot(epoch_axis, [h["val_jaccard"] for h in hist], "y", label="val jaccard")

    plt.legend()
    plt.xlabel("epoch")
    plt.savefig(images_folder / model_name.replace(".pt", "-train-hist.png"))

    time_hist = [h["time"] for h in hist]
    overall_time = sum(time_hist) // 60
    mean_epoch_time = sum(time_hist) / len(hist)
    print(f"epochs: {len(hist)}, overall time: {overall_time}m, mean epoch time: {mean_epoch_time}s")

    torch.cuda.empty_cache()
else:
    model_name = "ournet_" + version + ".pt"
if train == False:
    # ---- Test-split evaluation (runs only when training above is skipped) ----
    model_name = "ournet_" + version + ".pt"
    model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device("cpu")))
    model.to(device)
    model.eval()

    test_loss = 0.0
    test_jaccard = 0.0
    test_dice = 0.0

    for origins, masks in dataloaders["test"]:
        num = origins.size(0)
        origins = origins.to(device)
        masks = masks.to(device)

        with torch.no_grad():
            outs = model(origins)
            softmax = torch.nn.functional.log_softmax(outs, dim=1)
            test_loss += torch.nn.functional.nll_loss(softmax, masks).item() * num

            # Hard per-pixel predictions for the overlap metrics.
            outs = torch.argmax(softmax, dim=1).float()
            masks = masks.float()
            test_jaccard += jaccard(masks, outs).item() * num
            test_dice += dice(masks, outs).item() * num
        print(".", end="")

    # Per-sample averages (loop accumulated batch-size-weighted sums).
    test_loss = test_loss / len(datasets["test"])
    test_jaccard = test_jaccard / len(datasets["test"])
    test_dice = test_dice / len(datasets["test"])

    print()
    print(f"avg test loss: {test_loss}")
    print(f"avg test jaccard: {test_jaccard}")
    print(f"avg test dice: {test_dice}")

    # ---- Qualitative results on random test samples ----
    num_samples = 9
    phase = "test"

    # NOTE(review): np.random.randint draws WITH replacement and is unseeded, so
    # the panel may repeat images and is not reproducible between runs.
    subset = torch.utils.data.Subset(
        datasets[phase],
        np.random.randint(0, len(datasets[phase]), num_samples)
    )
    random_samples_loader = torch.utils.data.DataLoader(subset, batch_size=1)
    plt.figure(figsize=(20, 25))

    for idx, (origin, mask) in enumerate(random_samples_loader):
        plt.subplot((num_samples // 3) + 1, 3, idx + 1)

        origin = origin.to(device)
        mask = mask.to(device)

        with torch.no_grad():
            out = model(origin)
            softmax = torch.nn.functional.log_softmax(out, dim=1)
            out = torch.argmax(softmax, dim=1)

        jaccard_score = jaccard(mask.float(), out.float()).item()
        dice_score = dice(mask.float(), out.float()).item()

        # Drop the batch dimension and move to CPU for blending/plotting.
        origin = origin[0].to("cpu")
        out = out[0].to("cpu")
        mask = mask[0].to("cpu")

        plt.imshow(np.array(blend(origin, mask, out)))
        plt.title(f"jaccard: {jaccard_score:.4f}, dice: {dice_score:.4f}")
        print(".", end="")

    # BUG FIX: the original called plt.show() BEFORE plt.savefig(), which saves a
    # blank canvas in inline/notebook backends (show() releases the current figure).
    # Save first, then display.
    plt.savefig(images_folder / "obtained-results.png", bbox_inches='tight')
    plt.show()
    print()
    print("red area - predict")
    print("green area - ground truth")
    print("yellow area - intersection")

    # ---- Single-image inference demo ----
    # (Reloading is redundant — the same checkpoint was loaded above — kept for parity
    # with the original notebook cell, which was runnable on its own.)
    model_name = "ournet_" + version + ".pt"
    model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device("cpu")))
    model.to(device)
    model.eval()

    # NOTE(review): relative path — resolves against the current working directory,
    # unlike the absolute Drive paths used elsewhere; confirm intended.
    origin_filename = "input/dataset/images/CHNCXR_0042_0.png"

    origin = Image.open(origin_filename).convert("P")
    origin = torchvision.transforms.functional.resize(origin, (200, 200))
    # Match training normalization: tensor in [0, 1] shifted to [-0.5, 0.5].
    origin = torchvision.transforms.functional.to_tensor(origin) - 0.5

    with torch.no_grad():
        origin = torch.stack([origin])  # add batch dimension
        origin = origin.to(device)
        out = model(origin)
        softmax = torch.nn.functional.log_softmax(out, dim=1)
        out = torch.argmax(softmax, dim=1)

    origin = origin[0].to("cpu")
    out = out[0].to("cpu")

    plt.figure(figsize=(20, 10))

    pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert("RGB")

    plt.subplot(1, 2, 1)
    plt.title("origin image")
    plt.imshow(np.array(pil_origin))

    plt.subplot(1, 2, 2)
    plt.title("blended origin + predict")
    plt.imshow(np.array(blend(origin, out)))
    plt.show()