[1fc123]: / Segthor_up.ipynb

Download this file

1273 lines (1272 with data), 29.8 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FRm2ZqbPpr60"
   },
   "source": [
    "# Image segmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KQbc5Xun-Vlf"
   },
   "source": [
    "## Data preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pdb\n",
    "import time\n",
    "\n",
    "import cv2\n",
    "import matplotlib.image\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "#import nibabel as nib\n",
    "from matplotlib import pyplot as plt\n",
    "from nibabel.testing import data_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncated_range(img, min_hu=-384, max_hu=384):\n",
    "    '''Clip intensities to [min_hu, max_hu] and rescale them linearly to [0, 255].\n",
    "\n",
    "    # Arguments\n",
    "        img: array-like of intensities; it is NOT modified (a float copy is clipped).\n",
    "        min_hu: lower bound of the window, mapped to 0.\n",
    "        max_hu: upper bound of the window, mapped to 255.\n",
    "\n",
    "    # Returns\n",
    "        float64 array with the same shape as `img` and values in [0, 255].\n",
    "    '''\n",
    "    # Work on a float copy so the caller's array keeps its original values.\n",
    "    clipped = np.clip(np.asarray(img, dtype=np.float64), min_hu, max_hu)\n",
    "    return (clipped - min_hu) / (max_hu - min_hu) * 255.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert every training volume under `path` into per-slice .npy files (2.5D image + label)\n",
    "# and report intensity statistics of the saved images.\n",
    "path = './train/'\n",
    "save_path = './data/data_npy/'\n",
    "\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "\n",
    "# Only directories hold a patient volume; sorting makes the processing order reproducible.\n",
    "files = sorted(f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f)))\n",
    "count = 0\n",
    "print('begin processing data')\n",
    "\n",
    "means = []\n",
    "stds = []\n",
    "\n",
    "for i, volume in enumerate(files):\n",
    "    cur_file = os.path.join(path, volume)\n",
    "    print(i, cur_file)\n",
    "    cur_save_path = os.path.join(save_path, volume)\n",
    "    os.makedirs(cur_save_path, exist_ok=True)\n",
    "    # np.asanyarray(img.dataobj) is nibabel's documented replacement for the deprecated get_data().\n",
    "    img = np.asanyarray(nib.load(os.path.join(cur_file, volume + '.nii')).dataobj)\n",
    "    label = np.asanyarray(nib.load(os.path.join(cur_file, 'GT.nii')).dataobj)\n",
    "    img = truncated_range(img)\n",
    "    total_imgs = []\n",
    "    # The first and last slices lack a neighbour on one side, so they are skipped.\n",
    "    for idx in range(1, img.shape[2] - 1):\n",
    "        # 2.5D data, using adjacent 3 images\n",
    "        cur_img = img[:, :, idx - 1:idx + 2].astype('uint8')\n",
    "        total_imgs.append(cur_img)\n",
    "        cur_label = label[:, :, idx].astype('uint8')\n",
    "        count += 1\n",
    "        np.save(os.path.join(cur_save_path, volume + '_' + str(idx) + '_image.npy'), cur_img)\n",
    "        np.save(os.path.join(cur_save_path, volume + '_' + str(idx) + '_label.npy'), cur_label)\n",
    "\n",
    "    if not total_imgs:\n",
    "        # Fewer than 3 slices: nothing was saved and there is nothing to stack.\n",
    "        continue\n",
    "    total_imgs = np.stack(total_imgs, 3) / 255.\n",
    "    means.append(np.mean(total_imgs))\n",
    "    stds.append(np.std(total_imgs))\n",
    "\n",
    "\n",
    "print('data mean is %f' % np.mean(means))\n",
    "# Average of the per-volume stds (np.std(stds) would report the spread of the stds instead).\n",
    "print('data std is %f' % np.mean(stds))\n",
    "print('total data size is %d' % count)\n",
    "print('processing data end !')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export every saved label slice of training patients 1-40 as a PNG under ./testing/.\n",
    "for i in range(1, 41):\n",
    "    patient = f'Patient_{i:02d}'  # zero-padded: Patient_01 ... Patient_40\n",
    "    src_dir = './data/data_npy/' + patient\n",
    "    dst_dir = './testing/' + patient\n",
    "    # makedirs(exist_ok=True) also creates ./testing and tolerates re-running the cell.\n",
    "    os.makedirs(dst_dir, exist_ok=True)\n",
    "    # Slices are stored as image/label pairs, so half of the files are labels.\n",
    "    n_slices = len(os.listdir(src_dir)) // 2\n",
    "    for l in range(1, n_slices + 1):\n",
    "        lbl_array = np.load(src_dir + '/' + patient + '_' + str(l) + '_label.npy')\n",
    "        # imsave takes the array directly; no intermediate image object is needed.\n",
    "        matplotlib.image.imsave(dst_dir + '/' + patient + '_' + str(l) + '_label.png', lbl_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export every saved image slice of test patients 41-60 as a PNG under ./testing_2d/.\n",
    "for i in range(41, 61):\n",
    "    patient = 'Patient_' + str(i)\n",
    "    src_dir = './data_test/data_npy/' + patient + '.nii'\n",
    "    dst_dir = './testing_2d/' + patient\n",
    "    # makedirs(exist_ok=True) also creates ./testing_2d and tolerates re-running the cell.\n",
    "    os.makedirs(dst_dir, exist_ok=True)\n",
    "    # Every file in src_dir is counted as one slice; slice files are numbered from 1.\n",
    "    n_slices = len(os.listdir(src_dir))\n",
    "    for l in range(1, n_slices + 1):\n",
    "        img_array = np.load(src_dir + '/' + patient + '.nii_' + str(l) + '_image.npy')\n",
    "        matplotlib.image.imsave(dst_dir + '/' + patient + '_' + str(l) + '_image.png', img_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gMZFGT4Qpr64"
   },
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5mXYcsKgpr7A"
   },
   "outputs": [],
   "source": [
    "from fastai.vision import *\n",
    "from fastai.callbacks.hooks import *\n",
    "from fastai.utils.mem import *\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "XcCEVOyHpr7G",
    "outputId": "6119067e-b766-41b7-d650-dedf872a42d8",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "cd train_2d/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xLTbzylqpr7a"
   },
   "outputs": [],
   "source": [
    "path_lbl = 'labels1/'\n",
    "path_img = 'ct_images/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8hALaOJzpr7m"
   },
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 69
    },
    "colab_type": "code",
    "id": "r-rnyaxSpr7o",
    "outputId": "7cebd9b4-93f3-4ae8-a27f-55fc75e8a1a4"
   },
   "outputs": [],
   "source": [
    "fnames = get_image_files(path_img)\n",
    "fnames[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 69
    },
    "colab_type": "code",
    "id": "Uipy1i0Hpr73",
    "outputId": "ca5ec985-3777-4fde-b1d8-e7072c0c4a19"
   },
   "outputs": [],
   "source": [
    "lbl_names = get_image_files(path_lbl)\n",
    "lbl_names[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 323
    },
    "colab_type": "code",
    "id": "aVJ8i78Dpr8A",
    "outputId": "be54247f-31fd-4951-e28c-5f52a30b26de"
   },
   "outputs": [],
   "source": [
    "img_f = fnames[4]\n",
    "img = open_image(img_f)\n",
    "img.show(figsize=(5,5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_j57JSjipr8G"
   },
   "outputs": [],
   "source": [
    "def get_y_fn(x):\n",
    "    '''Map an image path `x` to its label path: path_lbl + '<stem>_label<suffix>'.'''\n",
    "    label_name = f'{(x.stem)}_label{x.suffix}'\n",
    "    return path_lbl + label_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arr=np.load('Patient_01_100_image_label.npy')\n",
    "x=plt.imshow(arr, cmap='Greys')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 323
    },
    "colab_type": "code",
    "id": "vneNAIeipr8L",
    "outputId": "49aba5d2-c715-4f53-a9fe-87a9aba4eb5b",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "mask = open_mask(get_y_fn(img_f))\n",
    "mask.show(figsize=(5,5), alpha=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "colab_type": "code",
    "id": "nVPrFLrupr8Q",
    "outputId": "14a2f066-73ff-4298-8ab3-304b684f59f9",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "src_size = np.array(mask.shape[1:])\n",
    "src_size,mask.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "f3OXtvkSpr8b",
    "outputId": "55f3ff6c-b31b-49c9-a269-7304c8bf9cdf"
   },
   "outputs": [],
   "source": [
    "codes = np.loadtxt('./codes.txt', dtype=str);codes#= np.delete(codes,4);codes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rrxWfznwpr8h"
   },
   "source": [
    "## Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lT9uhZeQpr8i"
   },
   "outputs": [],
   "source": [
    "size = src_size\n",
    "\n",
    "free = gpu_mem_get_free_no_cache()\n",
    "# the max size of bs depends on the available GPU RAM\n",
    "suggested_bs = 8 if free > 8200 else 4\n",
    "# Manual override of the suggestion; set before printing so the log shows the value actually used.\n",
    "bs = 2\n",
    "print(f\"using bs={bs} (memory-based suggestion was {suggested_bs}), have {free}MB of GPU RAM free\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "2wT3ahktpr8t",
    "outputId": "e1d22796-9f8d-4904-de7c-b601cde0e6fe"
   },
   "outputs": [],
   "source": [
    "free = gpu_mem_get_free_no_cache()\n",
    "free"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xTeb9f-Npr84"
   },
   "outputs": [],
   "source": [
    "# A fixed seed keeps the train/validation split identical across kernel restarts.\n",
    "# NOTE(review): the split is per slice image, so slices of one patient can land in both sets — confirm this is acceptable.\n",
    "src = (SegmentationItemList.from_folder(path_img)\n",
    "       .split_by_rand_pct(seed=42)\n",
    "       .label_from_func(get_y_fn, classes=codes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "p6p9s3-Ypr88"
   },
   "outputs": [],
   "source": [
    "data = (src.transform(get_transforms(), size=size, tfm_y=True)\n",
    "        .databunch(bs=bs)\n",
    "        .normalize(imagenet_stats))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 513
    },
    "colab_type": "code",
    "id": "S6yarrSKpr9E",
    "outputId": "3c80296c-dfca-4e2c-fa32-d2b2746da3f8"
   },
   "outputs": [],
   "source": [
    "data.show_batch(2, figsize=(10,7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 513
    },
    "colab_type": "code",
    "id": "uhT57Xkgpr9M",
    "outputId": "d9597c0e-7ba0-49af-826a-032fb469a77e"
   },
   "outputs": [],
   "source": [
    "data.show_batch(2, figsize=(10,7), ds_type=DatasetType.Valid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "zgo7kt5Bpr9R"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "b5y3FjM9pr9R"
   },
   "outputs": [],
   "source": [
    "def dice_coeff(input, target):\n",
    "    '''Negative foreground Dice score (lower is better, so it can be minimised).\n",
    "\n",
    "    input is a torch tensor of size BatchxnclassesxHxW holding raw class scores;\n",
    "    softmax over the class axis is applied here.\n",
    "    target is a 1-hot representation of the groundtruth, should have same size as the input.\n",
    "    Returns a scalar tensor: minus the Dice of classes 1..nclasses-1, summed over classes\n",
    "    and averaged over the batch.\n",
    "    '''\n",
    "    assert input.size() == target.size(), 'Input sizes must be equal.'\n",
    "    assert input.dim() == 4, 'Input must be a 4D Tensor.'\n",
    "    # torch.unique also works for tensors on the GPU, where target.numpy() raises.\n",
    "    uniques = torch.unique(target.detach()).cpu().tolist()\n",
    "    assert set(uniques) <= {0, 1}, 'target must only contain zeros and ones'\n",
    "\n",
    "    # dim=1 is the class axis; it is what the implicit-dim call picked for 4D input.\n",
    "    probs = F.softmax(input, dim=1)\n",
    "    num = torch.sum(probs * target, dim=(2, 3))     # b,c  -- p*g\n",
    "    den1 = torch.sum(probs * probs, dim=(2, 3))     # b,c  -- p^2\n",
    "    den2 = torch.sum(target * target, dim=(2, 3))   # b,c  -- g^2\n",
    "\n",
    "    dice = 2 * (num / (den1 + den2))\n",
    "    dice_eso = dice[:, 1:]  # we ignore bg dice val, and take the fg\n",
    "\n",
    "    dice_total = -1 * torch.sum(dice_eso) / dice_eso.size(0)  # divide by batch_sz\n",
    "\n",
    "    return dice_total\n",
    "\n",
    "\n",
    "def iou(target, prediction):\n",
    "    '''Intersection over union of two masks, compared element-wise by truthiness.\n",
    "\n",
    "    Yields nan (with a numpy warning) when both masks are entirely empty.\n",
    "    '''\n",
    "    intersection = np.logical_and(target, prediction)\n",
    "    union = np.logical_or(target, prediction)\n",
    "    iou_score = np.sum(intersection) / np.sum(union)\n",
    "    return iou_score\n",
    "\n",
    "\n",
    "def soft_dice_loss(y_true, y_pred, epsilon=1e-6):\n",
    "    '''\n",
    "    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.\n",
    "    Assumes the `channels_last` format.\n",
    "\n",
    "    # Arguments\n",
    "        y_true: b x X x Y( x Z...) x c One hot encoding of ground truth\n",
    "        y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax)\n",
    "        epsilon: Used for numerical stability to avoid divide by zero errors\n",
    "\n",
    "    # References\n",
    "        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation\n",
    "        https://arxiv.org/abs/1606.04797\n",
    "        More details on Dice loss formulation\n",
    "        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)\n",
    "\n",
    "        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022\n",
    "    '''\n",
    "\n",
    "    # skip the batch and class axis for calculating Dice score\n",
    "    axes = tuple(range(1, len(y_pred.shape)-1))\n",
    "    numerator = 2. * np.sum(y_pred * y_true, axes)\n",
    "    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)\n",
    "\n",
    "    return 1 - np.mean(numerator / (denominator + epsilon)) # average over classes and batch\n",
    "\n",
    "\n",
    "# Class name -> integer id, in the order the names appear in `codes`.\n",
    "name2id = {v:k for k,v in enumerate(codes)}\n",
    "void_code = name2id['Void']\n",
    "\n",
    "\n",
    "def acc(input, target):\n",
    "    '''Pixel accuracy of the argmax prediction, ignoring pixels labelled `void_code`.'''\n",
    "    target = target.squeeze(1)\n",
    "    mask = target != void_code\n",
    "    return (input.argmax(dim=1)[mask]==target[mask]).float().mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bfaw3BcKpr9V"
   },
   "outputs": [],
   "source": [
    "metrics=acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "n0EDeSXPpr9X"
   },
   "outputs": [],
   "source": [
    "wd=1e-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-lMHeE3epr9Z"
   },
   "outputs": [],
   "source": [
    "learn = unet_learner(data, models.resnet18, metrics=metrics, wd=wd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "ib8EtiesiI3J",
    "outputId": "8ae75823-8e78-40bc-d16a-34381763d4ba"
   },
   "outputs": [],
   "source": [
    "from torchsummary import summary\n",
    "summary(learn.model, input_size=(3, 512, 512))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "G7smN9QO9M0k",
    "outputId": "0db355f0-588d-450e-a057-474c480456ba",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "learn.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lyYm11rFrbnk"
   },
   "outputs": [],
   "source": [
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 132
    },
    "colab_type": "code",
    "id": "EzPVb8-9pr9i",
    "outputId": "4fd99a5b-5147-4590-b22d-e1847a8c7102"
   },
   "outputs": [],
   "source": [
    "lr_find(learn)\n",
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tZztkldvpr9r"
   },
   "outputs": [],
   "source": [
    "lr=3e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 391
    },
    "colab_type": "code",
    "id": "y_zi_wBVpr9u",
    "outputId": "e25606c6-cadf-474e-b99d-7f98a7da304a"
   },
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(10, slice(lr), pct_start=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SZdqGQrVpr9y"
   },
   "outputs": [],
   "source": [
    "learn.save('stage-1-imp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZDIptQS_pr91"
   },
   "outputs": [],
   "source": [
    "learn.load('stage-1-imp');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "FTEniKzepr95"
   },
   "outputs": [],
   "source": [
    "learn.show_results(rows=3, figsize=(8,9))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## POST-PROCESSING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = open_image('./testing_2d/Patient_41/Patient_41_1_image.png')\n",
    "img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "im_g=learn.predict(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Show the predicted mask just computed (not `arr`, the label array loaded in an earlier cell).\n",
    "na1=(np.array(im_g[1][0]))\n",
    "plt.imshow(na1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.array(im_g[1][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict a mask for every exported slice of test patients 41-60 and save it as a PNG under ./testing/.\n",
    "for i in range(41, 61):\n",
    "    patient = 'Patient_' + str(i)\n",
    "    src_dir = './testing_2d/' + patient\n",
    "    dst_dir = './testing/' + patient\n",
    "    # makedirs(exist_ok=True) also creates ./testing and tolerates re-running the cell.\n",
    "    os.makedirs(dst_dir, exist_ok=True)\n",
    "    # Every file in src_dir is counted as one slice; slice files are numbered from 1.\n",
    "    n_slices = len(os.listdir(src_dir))\n",
    "    for l in range(1, n_slices + 1):\n",
    "        img = open_image(src_dir + '/' + patient + '_' + str(l) + '_image.png')\n",
    "        im_g = learn.predict(img)\n",
    "        # im_g[1][0] is saved as the mask, exactly the array the preview cells display.\n",
    "        matplotlib.image.imsave(dst_dir + '/' + patient + '_' + str(l) + '_image_label.png', np.array(im_g[1][0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nibabel as nib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import nibabel as nib\n",
    "\n",
    "# Stack the per-slice predictions of each test patient into one volume and save it as NIfTI,\n",
    "# reusing the affine of the patient's original scan.\n",
    "os.makedirs('./testing3d', exist_ok=True)\n",
    "# NOTE(review): this starts at 42 while the other test loops start at 41 — confirm Patient_41 is skipped on purpose.\n",
    "for i in range(42, 61):\n",
    "    patient = 'Patient_' + str(i)\n",
    "    src_dir = './testing_2d/' + patient\n",
    "    # Every file in src_dir is counted as one slice; slice files are numbered from 1.\n",
    "    n_slices = len(os.listdir(src_dir))\n",
    "    slices = []\n",
    "    for l in range(1, n_slices + 1):\n",
    "        img = open_image(src_dir + '/' + patient + '_' + str(l) + '_image.png')\n",
    "        im_g = learn.predict(img)\n",
    "        slices.append(np.array(im_g[1][0]))\n",
    "    # Stack once along a new last axis instead of re-copying the growing volume for every slice.\n",
    "    na = np.stack(slices, axis=-1)\n",
    "    ex = \"test/\" + patient + \".nii\"\n",
    "    img = nib.load(ex)\n",
    "    # `.affine` is the attribute that replaces the deprecated get_affine() accessor.\n",
    "    im1 = np.array(img.affine)\n",
    "    new_image = nib.Nifti1Image(np.asarray(na, dtype=\"uint8\"), affine=im1)\n",
    "    nib.save(new_image, './testing3d/' + patient + '_GT.nii.gz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "na.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "IwakG2rrpr9_"
   },
   "outputs": [],
   "source": [
    "learn.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uTsRJkcTpr-A"
   },
   "outputs": [],
   "source": [
    "lrs = slice(lr/400,lr/4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SAYQI7JDpr-C",
    "outputId": "0dea8106-7ea3-4e30-e5d6-b6d4cf18c9f8",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(12, lrs, pct_start=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8mGVVW3ipr-H"
   },
   "outputs": [],
   "source": [
    "learn.save('stage-2');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BoNKO9vHpr-K"
   },
   "source": [
    "## Go big"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CUYh7uAFpr-L"
   },
   "source": [
    "You may have to restart your kernel and come back to this stage if you run out of memory, and may also need to decrease `bs`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "ntd0bQLxpr-M",
    "outputId": "413ebf26-5f04-4e04-dae7-97a925d2e28e"
   },
   "outputs": [],
   "source": [
    "#learn.destroy() # uncomment once 1.0.46 is out\n",
    "\n",
    "size = src_size\n",
    "\n",
    "free = gpu_mem_get_free_no_cache()\n",
    "# the max size of bs depends on the available GPU RAM\n",
    "if free > 8200: bs=3\n",
    "else:           bs=1\n",
    "print(f\"using bs={bs}, have {free}MB of GPU RAM free\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "1VnuPjdzpr-P"
   },
   "outputs": [],
   "source": [
    "data = (src.transform(get_transforms(), size=size, tfm_y=True)\n",
    "        .databunch(bs=bs)\n",
    "        .normalize(imagenet_stats))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "colab_type": "code",
    "id": "LqnUbPoTpr-T",
    "outputId": "7c8d7f71-63e5-4a97-f3a0-4a2f87c0eed2",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "learn = unet_learner(data, models.resnet50, metrics=metrics, wd=wd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ksf4YOD1pr-X"
   },
   "outputs": [],
   "source": [
    "# NOTE(review): 'stage-2' was saved from the resnet18 learner built earlier in this notebook, while\n",
    "# `learn` here is a resnet50 U-Net; the weights presumably do not fit this model — confirm the intent.\n",
    "learn.load('stage-2');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 434
    },
    "colab_type": "code",
    "id": "K8hZUjFzpr-Z",
    "outputId": "7a03f3c3-968d-4270-c027-c1f534c79b0b"
   },
   "outputs": [],
   "source": [
    "lr_find(learn)\n",
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "aRYdIM0lpr-b"
   },
   "outputs": [],
   "source": [
    "lr=1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5u-cUGYtpr-c",
    "outputId": "a66d77c2-4267-4b98-fe52-c54aec433552",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(10, slice(lr), pct_start=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "TU38pO5gpr-f"
   },
   "outputs": [],
   "source": [
    "learn.save('stage-1-big')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "F2czKtoFpr-g"
   },
   "outputs": [],
   "source": [
    "learn.load('stage-1-big');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Oicfx00fpr-i"
   },
   "outputs": [],
   "source": [
    "learn.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_6xQhCXApr-j"
   },
   "outputs": [],
   "source": [
    "lrs = slice(1e-6,lr/10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sxTN1UYSpr-l",
    "outputId": "17f9b0c4-6a43-4182-a1f4-57b7ea235d4c"
   },
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(10, lrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tvm_yXCJpr-o"
   },
   "outputs": [],
   "source": [
    "learn.save('stage-2-big')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7Kts9ls7pr-p"
   },
   "outputs": [],
   "source": [
    "learn.load('stage-2-big');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LektCRD2pr-y",
    "outputId": "c73598d8-a91c-406d-a96e-022428a092cb"
   },
   "outputs": [],
   "source": [
    "learn.show_results(rows=3, figsize=(10,10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dNCx0BaYpr-8"
   },
   "source": [
    "## VISUALIZING Axial, Sagittal and Coronal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nii='Patient_41_GT.nii.gz'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ayG7b91Mpr-8",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import nilearn\n",
    "from nilearn import plotting\n",
    "plotting.plot_stat_map(nii)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nii='GT.nii.gz'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nilearn\n",
    "from nilearn import plotting\n",
    "plotting.plot_stat_map(nii)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "BoNKO9vHpr-K",
    "dNCx0BaYpr-8"
   ],
   "name": "SegThor.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}