--- a
+++ b/Segthor_up.ipynb
@@ -0,0 +1,1272 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "FRm2ZqbPpr60"
+   },
+   "source": [
+    "# Image segmentation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "KQbc5Xun-Vlf"
+   },
+   "source": [
+    "## Data preprocessing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import os\n",
+    "import time\n",
+    "import cv2\n",
+    "#import nibabel as nib\n",
+    "import pdb\n",
+    "from matplotlib import pyplot as plt\n",
+    "import nibabel as nib\n",
+    "from nibabel.testing import data_path"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def truncated_range(img):\n",
+    "    max_hu = 384\n",
+    "    min_hu = -384\n",
+    "    img[np.where(img > max_hu)] = max_hu\n",
+    "    img[np.where(img < min_hu)] = min_hu\n",
+    "    return (img - min_hu) / (max_hu - min_hu) * 255.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Convert every training volume into per-slice .npy files (2.5D image + label)\n",
+    "# and report intensity statistics of the saved images.\n",
+    "path = './train/'\n",
+    "save_path = './data/data_npy/'\n",
+    "\n",
+    "os.makedirs(save_path, exist_ok=True)\n",
+    "\n",
+    "# Sorted so the processing order is reproducible; stray non-directory entries are skipped\n",
+    "# because each volume is expected to live in its own sub-folder of `path`.\n",
+    "files = sorted(f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f)))\n",
+    "count = 0\n",
+    "print('begin processing data')\n",
+    "\n",
+    "means = []\n",
+    "stds = []\n",
+    "\n",
+    "\n",
+    "for i, volume in enumerate(files):\n",
+    "    total_imgs = []\n",
+    "    cur_file = os.path.join(path, volume)\n",
+    "    print(i, cur_file)\n",
+    "    cur_save_path = os.path.join(save_path, volume)\n",
+    "    os.makedirs(cur_save_path, exist_ok=True)\n",
+    "    # `.dataobj` replaces the deprecated get_data(); np.array() makes an independent copy\n",
+    "    img = np.array(nib.load(os.path.join(cur_file, volume + '.nii')).dataobj)\n",
+    "    label = np.array(nib.load(os.path.join(cur_file, 'GT.nii')).dataobj)\n",
+    "    img = truncated_range(img)\n",
+    "    # the first and last slices are skipped: they lack one of the two neighbours\n",
+    "    for idx in range(1, img.shape[2] - 1):\n",
+    "        # 2.5D data, using adjacent 3 images\n",
+    "        cur_img = img[:, :, idx - 1:idx + 2].astype('uint8')\n",
+    "        total_imgs.append(cur_img)\n",
+    "        cur_label = label[:, :, idx].astype('uint8')\n",
+    "        count += 1\n",
+    "        np.save(os.path.join(cur_save_path, volume + '_' + str(idx) + '_image.npy'), cur_img)\n",
+    "        np.save(os.path.join(cur_save_path, volume + '_' + str(idx) + '_label.npy'), cur_label)\n",
+    "\n",
+    "    # per-volume statistics of the saved images, rescaled to [0, 1]\n",
+    "    total_imgs = np.stack(total_imgs, 3) / 255.\n",
+    "    means.append(np.mean(total_imgs))\n",
+    "    stds.append(np.std(total_imgs))\n",
+    "\n",
+    "\n",
+    "print('data mean is %f' % np.mean(means))\n",
+    "# average of the per-volume stds (np.std(stds) would report the spread of the stds instead)\n",
+    "print('data std is %f' % np.mean(stds))\n",
+    "print('total data size is %d' % count)\n",
+    "print('processing data end !')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Export every saved label slice of training patients 1-40 as a PNG.\n",
+    "for i in range(1, 41):\n",
+    "    patient = f'Patient_{i:02d}'  # zero-padded ids: Patient_01 ... Patient_40\n",
+    "    npy_dir = f'./data/data_npy/{patient}'\n",
+    "    out_dir = f'./testing/{patient}'\n",
+    "    # makedirs(exist_ok=True) also creates './testing' and lets the cell be re-run\n",
+    "    os.makedirs(out_dir, exist_ok=True)\n",
+    "    # preprocessing writes one image file and one label file per slice,\n",
+    "    # so half of the directory entries are labels\n",
+    "    n_labels = len(os.listdir(npy_dir)) // 2\n",
+    "    for slice_idx in range(1, n_labels + 1):\n",
+    "        lbl_array = np.load(f'{npy_dir}/{patient}_{slice_idx}_label.npy')\n",
+    "        # plt.imsave takes the array directly, so neither PIL's `Image` nor a separate\n",
+    "        # `import matplotlib` (both unavailable at this point of the notebook) is needed.\n",
+    "        # NOTE(review): imsave colour-maps 2D data, so the PNG presumably does not store\n",
+    "        # raw class indices - confirm this is what the label consumer expects.\n",
+    "        plt.imsave(f'{out_dir}/{patient}_{slice_idx}_label.png', lbl_array)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Export the preprocessed test slices of patients 41-60 as PNG images.\n",
+    "for i in range(41, 61):\n",
+    "    patient = f'Patient_{i}'\n",
+    "    npy_dir = f'./data_test/data_npy/{patient}.nii'\n",
+    "    out_dir = f'./testing_2d/{patient}'\n",
+    "    # makedirs(exist_ok=True) also creates './testing_2d' and lets the cell be re-run\n",
+    "    os.makedirs(out_dir, exist_ok=True)\n",
+    "    # assumes the directory holds only `<patient>.nii_<k>_image.npy` files, k = 1..n - TODO confirm\n",
+    "    n_slices = len(os.listdir(npy_dir))\n",
+    "    for slice_idx in range(1, n_slices + 1):\n",
+    "        img_array = np.load(f'{npy_dir}/{patient}.nii_{slice_idx}_image.npy')\n",
+    "        # plt (imported in the first cell) avoids relying on a later `import matplotlib`\n",
+    "        plt.imsave(f'{out_dir}/{patient}_{slice_idx}_image.png', img_array)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "gMZFGT4Qpr64"
+   },
+   "outputs": [],
+   "source": [
+    "%reload_ext autoreload\n",
+    "%autoreload 2\n",
+    "%matplotlib inline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "5mXYcsKgpr7A"
+   },
+   "outputs": [],
+   "source": [
+    "from fastai.vision import *\n",
+    "from fastai.callbacks.hooks import *\n",
+    "from fastai.utils.mem import *\n",
+    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "XcCEVOyHpr7G",
+    "outputId": "6119067e-b766-41b7-d650-dedf872a42d8",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "cd train_2d/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "xLTbzylqpr7a"
+   },
+   "outputs": [],
+   "source": [
+    "path_lbl = 'labels1/'\n",
+    "path_img = 'ct_images/'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "8hALaOJzpr7m"
+   },
+   "source": [
+    "## Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 69
+    },
+    "colab_type": "code",
+    "id": "r-rnyaxSpr7o",
+    "outputId": "7cebd9b4-93f3-4ae8-a27f-55fc75e8a1a4"
+   },
+   "outputs": [],
+   "source": [
+    "fnames = get_image_files(path_img)\n",
+    "fnames[:3]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 69
+    },
+    "colab_type": "code",
+    "id": "Uipy1i0Hpr73",
+    "outputId": "ca5ec985-3777-4fde-b1d8-e7072c0c4a19"
+   },
+   "outputs": [],
+   "source": [
+    "lbl_names = get_image_files(path_lbl)\n",
+    "lbl_names[:3]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 323
+    },
+    "colab_type": "code",
+    "id": "aVJ8i78Dpr8A",
+    "outputId": "be54247f-31fd-4951-e28c-5f52a30b26de"
+   },
+   "outputs": [],
+   "source": [
+    "img_f = fnames[4]\n",
+    "img = open_image(img_f)\n",
+    "img.show(figsize=(5,5))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "_j57JSjipr8G"
+   },
+   "outputs": [],
+   "source": [
+    "def get_y_fn(x):\n",
+    "    \"\"\"Return the mask path for image path `x`: path_lbl + '<stem>_label<suffix>'.\"\"\"\n",
+    "    return f'{path_lbl}{x.stem}_label{x.suffix}'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "arr=np.load('Patient_01_100_image_label.npy')\n",
+    "x=plt.imshow(arr, cmap='Greys')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 323
+    },
+    "colab_type": "code",
+    "id": "vneNAIeipr8L",
+    "outputId": "49aba5d2-c715-4f53-a9fe-87a9aba4eb5b",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "mask = open_mask(get_y_fn(img_f))\n",
+    "mask.show(figsize=(5,5), alpha=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 139
+    },
+    "colab_type": "code",
+    "id": "nVPrFLrupr8Q",
+    "outputId": "14a2f066-73ff-4298-8ab3-304b684f59f9",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "src_size = np.array(mask.shape[1:])\n",
+    "src_size,mask.data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "f3OXtvkSpr8b",
+    "outputId": "55f3ff6c-b31b-49c9-a269-7304c8bf9cdf"
+   },
+   "outputs": [],
+   "source": [
+    "# class names, read as strings from codes.txt\n",
+    "codes = np.loadtxt('./codes.txt', dtype=str)\n",
+    "codes"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "rrxWfznwpr8h"
+   },
+   "source": [
+    "## Datasets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "lT9uhZeQpr8i"
+   },
+   "outputs": [],
+   "source": [
+    "size = src_size\n",
+    "\n",
+    "free = gpu_mem_get_free_no_cache()\n",
+    "# The batch size is pinned to 2. The earlier memory-based choice (8 or 4) was always\n",
+    "# overridden after the message was printed, so the message reported an unused value.\n",
+    "bs = 2\n",
+    "print(f\"using bs={bs}, have {free}MB of GPU RAM free\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "2wT3ahktpr8t",
+    "outputId": "e1d22796-9f8d-4904-de7c-b601cde0e6fe"
+   },
+   "outputs": [],
+   "source": [
+    "free = gpu_mem_get_free_no_cache()\n",
+    "free"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "xTeb9f-Npr84"
+   },
+   "outputs": [],
+   "source": [
+    "# Fixed seed so the random train/validation split is identical on every run\n",
+    "src = (SegmentationItemList.from_folder(path_img)\n",
+    "       .split_by_rand_pct(seed=42)\n",
+    "       .label_from_func(get_y_fn, classes=codes))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "p6p9s3-Ypr88"
+   },
+   "outputs": [],
+   "source": [
+    "data = (src.transform(get_transforms(), size=size, tfm_y=True)\n",
+    "        .databunch(bs=bs)\n",
+    "        .normalize(imagenet_stats))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 513
+    },
+    "colab_type": "code",
+    "id": "S6yarrSKpr9E",
+    "outputId": "3c80296c-dfca-4e2c-fa32-d2b2746da3f8"
+   },
+   "outputs": [],
+   "source": [
+    "data.show_batch(2, figsize=(10,7))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 513
+    },
+    "colab_type": "code",
+    "id": "uhT57Xkgpr9M",
+    "outputId": "d9597c0e-7ba0-49af-826a-032fb469a77e"
+   },
+   "outputs": [],
+   "source": [
+    "data.show_batch(2, figsize=(10,7), ds_type=DatasetType.Valid)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "zgo7kt5Bpr9R"
+   },
+   "source": [
+    "## Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "b5y3FjM9pr9R"
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "def dice_coeff(input,target):\n",
+    "    \"\"\"\n",
+    "    input is a torch variable of size BatchxnclassesxHxW representing log probabilities for each class\n",
+    "    target is a 1-hot representation of the groundtruth, shoud have same size as the input\n",
+    "    \"\"\"\n",
+    "    assert input.size() == target.size(), \"Input sizes must be equal.\"\n",
+    "    assert input.dim() == 4, \"Input must be a 4D Tensor.\"\n",
+    "    uniques=np.unique(target.numpy())\n",
+    "    assert set(list(uniques))<=set([0,1]), \"target must only contain zeros and ones\"\n",
+    "\n",
+    "    probs=F.softmax(input)\n",
+    "    num=probs*target#b,c,h,w--p*g\n",
+    "    num=torch.sum(num,dim=3)#b,c,h\n",
+    "    num=torch.sum(num,dim=2)\n",
+    "    \n",
+    "\n",
+    "    den1=torch.sum(den1,dim=3)#b,c,h\n",
+    "    den1=torch.sum(den1,dim=2)\n",
+    "    \n",
+    "\n",
+    "    den2=target*target#--g^2\n",
+    "    den2=torch.sum(den2,dim=3)#b,c,h\n",
+    "    den2=torch.sum(den2,dim=2)#b,c\n",
+    "    \n",
+    "\n",
+    "    dice=2*(num/(den1+den2))\n",
+    "    dice_eso=dice[:,1:]#we ignore bg dice val, and take the fg\n",
+    "\n",
+    "    dice_total=-1*torch.sum(dice_eso)/dice_eso.size(0)#divide by batch_sz\n",
+    "\n",
+    "    return dice_total\n",
+    "def iou(target,prediction):\n",
+    "    intersection = np.logical_and(target, prediction)\n",
+    "    union = np.logical_or(target, prediction)\n",
+    "    iou_score = np.sum(intersection) / np.sum(union)\n",
+    "    return iou_score\n",
+    "  \n",
+    "def soft_dice_loss(y_true, y_pred, epsilon=1e-6): \n",
+    "    ''' \n",
+    "    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.\n",
+    "    Assumes the `channels_last` format.\n",
+    "  \n",
+    "    # Arguments\n",
+    "        y_true: b x X x Y( x Z...) x c One hot encoding of ground truth\n",
+    "        y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax) \n",
+    "        epsilon: Used for numerical stability to avoid divide by zero errors\n",
+    "    \n",
+    "    # References\n",
+    "        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation \n",
+    "        https://arxiv.org/abs/1606.04797\n",
+    "        More details on Dice loss formulation \n",
+    "        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)\n",
+    "        \n",
+    "        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022\n",
+    "    '''\n",
+    "    \n",
+    "    # skip the batch and class axis for calculating Dice score\n",
+    "    axes = tuple(range(1, len(y_pred.shape)-1)) \n",
+    "    numerator = 2. * np.sum(y_pred * y_true, axes)\n",
+    "    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)\n",
+    "    \n",
+    "    return 1 - np.mean(numerator / (denominator + epsilon)) # average over classes and batch\n",
+    "# class name -> integer id, following the order of `codes`\n",
+    "name2id = {name: idx for idx, name in enumerate(codes)}\n",
+    "void_code = name2id['Void']\n",
+    "\n",
+    "def acc(input, target):\n",
+    "    \"\"\"Fraction of correctly classified pixels among those whose target is not `void_code`.\"\"\"\n",
+    "    target = target.squeeze(1)\n",
+    "    keep = target != void_code\n",
+    "    predicted = input.argmax(dim=1)\n",
+    "    return (predicted[keep] == target[keep]).float().mean()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "bfaw3BcKpr9V"
+   },
+   "outputs": [],
+   "source": [
+    "metrics=acc"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "n0EDeSXPpr9X"
+   },
+   "outputs": [],
+   "source": [
+    "wd=1e-2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "-lMHeE3epr9Z"
+   },
+   "outputs": [],
+   "source": [
+    "learn = unet_learner(data, models.resnet18, metrics=metrics, wd=wd)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000
+    },
+    "colab_type": "code",
+    "id": "ib8EtiesiI3J",
+    "outputId": "8ae75823-8e78-40bc-d16a-34381763d4ba"
+   },
+   "outputs": [],
+   "source": [
+    "from torchsummary import summary\n",
+    "summary(learn.model, input_size=(3, 512, 512))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000
+    },
+    "colab_type": "code",
+    "id": "G7smN9QO9M0k",
+    "outputId": "0db355f0-588d-450e-a057-474c480456ba",
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "learn.model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "lyYm11rFrbnk"
+   },
+   "outputs": [],
+   "source": [
+    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 132
+    },
+    "colab_type": "code",
+    "id": "EzPVb8-9pr9i",
+    "outputId": "4fd99a5b-5147-4590-b22d-e1847a8c7102"
+   },
+   "outputs": [],
+   "source": [
+    "lr_find(learn)\n",
+    "learn.recorder.plot()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "tZztkldvpr9r"
+   },
+   "outputs": [],
+   "source": [
+    "lr=3e-3"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 391
+    },
+    "colab_type": "code",
+    "id": "y_zi_wBVpr9u",
+    "outputId": "e25606c6-cadf-474e-b99d-7f98a7da304a"
+   },
+   "outputs": [],
+   "source": [
+    "learn.fit_one_cycle(10, slice(lr), pct_start=0.9)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "SZdqGQrVpr9y"
+   },
+   "outputs": [],
+   "source": [
+    "learn.save('stage-1-imp')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "ZDIptQS_pr91"
+   },
+   "outputs": [],
+   "source": [
+    "learn.load('stage-1-imp');"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "FTEniKzepr95"
+   },
+   "outputs": [],
+   "source": [
+    "learn.show_results(rows=3, figsize=(8,9))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## POST-PROCESSING"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "img = open_image('./testing_2d/Patient_41/Patient_41_1_image.png')\n",
+    "img"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "im_g=learn.predict(img)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "na1 = np.array(im_g[1][0])\n",
+    "# show the prediction just computed (this used to plot `arr`, the label array loaded earlier)\n",
+    "plt.imshow(na1)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "np.array(im_g[1][0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib\n",
+    "from matplotlib import pyplot as plt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Predict a mask for every exported test slice of patients 41-60 and save it as a PNG.\n",
+    "# (The former `i < 10` branches could never run for this range and have been removed.)\n",
+    "for i in range(41, 61):\n",
+    "    patient = f'Patient_{i}'\n",
+    "    in_dir = f'./testing_2d/{patient}'\n",
+    "    out_dir = f'./testing/{patient}'\n",
+    "    # makedirs(exist_ok=True) lets the cell be re-run without failing on existing folders\n",
+    "    os.makedirs(out_dir, exist_ok=True)\n",
+    "    # assumes the folder holds only `<patient>_<k>_image.png` files, k = 1..n - TODO confirm\n",
+    "    n_slices = len(os.listdir(in_dir))\n",
+    "    for slice_idx in range(1, n_slices + 1):\n",
+    "        img = open_image(f'{in_dir}/{patient}_{slice_idx}_image.png')\n",
+    "        im_g = learn.predict(img)\n",
+    "        # im_g[1][0] is the array that was saved before; the unused PIL conversion is dropped\n",
+    "        plt.imsave(f'{out_dir}/{patient}_{slice_idx}_image_label.png', np.array(im_g[1][0]))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import nibabel as nib"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "import nibabel as nib\n",
+    "# Stack the per-slice predictions of each test patient into a NIfTI volume.\n",
+    "os.makedirs('./testing3d', exist_ok=True)\n",
+    "# Starts at 41 (was 42) so that every exported test patient, including the\n",
+    "# Patient_41 volume visualised later, gets a prediction file.\n",
+    "for i in range(41, 61):\n",
+    "    patient = f'Patient_{i}'\n",
+    "    slice_dir = f'./testing_2d/{patient}'\n",
+    "    n_slices = len(os.listdir(slice_dir))\n",
+    "    masks = []\n",
+    "    for slice_idx in range(1, n_slices + 1):\n",
+    "        img = open_image(f'{slice_dir}/{patient}_{slice_idx}_image.png')\n",
+    "        im_g = learn.predict(img)\n",
+    "        masks.append(np.array(im_g[1][0]))\n",
+    "    # stack once at the end instead of re-copying the growing volume for every slice\n",
+    "    na = np.dstack(masks)\n",
+    "    ex = \"test/Patient_\" + str(i) + \".nii\"\n",
+    "    img = nib.load(ex)\n",
+    "    # `.affine` replaces get_affine(), which current nibabel releases no longer provide\n",
+    "    im1 = np.array(img.affine)\n",
+    "    # NOTE(review): the stack has one layer per exported PNG; confirm this matches the\n",
+    "    # slice count of the source volume whose affine is reused here.\n",
+    "    new_image = nib.Nifti1Image(np.asarray(na, dtype=\"uint8\"), affine=im1)\n",
+    "    nib.save(new_image, './testing3d/Patient_' + str(i) + '_GT.nii.gz')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "na.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "IwakG2rrpr9_"
+   },
+   "outputs": [],
+   "source": [
+    "learn.unfreeze()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "uTsRJkcTpr-A"
+   },
+   "outputs": [],
+   "source": [
+    "lrs = slice(lr/400,lr/4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "SAYQI7JDpr-C",
+    "outputId": "0dea8106-7ea3-4e30-e5d6-b6d4cf18c9f8",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "learn.fit_one_cycle(12, lrs, pct_start=0.8)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "8mGVVW3ipr-H"
+   },
+   "outputs": [],
+   "source": [
+    "learn.save('stage-2');"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "BoNKO9vHpr-K"
+   },
+   "source": [
+    "## Go big"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "CUYh7uAFpr-L"
+   },
+   "source": [
+    "You may have to restart your kernel and come back to this stage if you run out of memory, and may also need to decrease `bs`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "ntd0bQLxpr-M",
+    "outputId": "413ebf26-5f04-4e04-dae7-97a925d2e28e"
+   },
+   "outputs": [],
+   "source": [
+    "#learn.destroy() # uncomment once 1.0.46 is out\n",
+    "\n",
+    "size = src_size\n",
+    "\n",
+    "free = gpu_mem_get_free_no_cache()\n",
+    "# the max size of bs depends on the available GPU RAM\n",
+    "if free > 8200: bs=3\n",
+    "else:           bs=1\n",
+    "print(f\"using bs={bs}, have {free}MB of GPU RAM free\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "1VnuPjdzpr-P"
+   },
+   "outputs": [],
+   "source": [
+    "data = (src.transform(get_transforms(), size=size, tfm_y=True)\n",
+    "        .databunch(bs=bs)\n",
+    "        .normalize(imagenet_stats))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 52
+    },
+    "colab_type": "code",
+    "id": "LqnUbPoTpr-T",
+    "outputId": "7c8d7f71-63e5-4a97-f3a0-4a2f87c0eed2",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "learn = unet_learner(data, models.resnet50, metrics=metrics, wd=wd)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "ksf4YOD1pr-X"
+   },
+   "outputs": [],
+   "source": [
+    "learn.load('stage-2');"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 434
+    },
+    "colab_type": "code",
+    "id": "K8hZUjFzpr-Z",
+    "outputId": "7a03f3c3-968d-4270-c027-c1f534c79b0b"
+   },
+   "outputs": [],
+   "source": [
+    "lr_find(learn)\n",
+    "learn.recorder.plot()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "aRYdIM0lpr-b"
+   },
+   "outputs": [],
+   "source": [
+    "lr=1e-3"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "5u-cUGYtpr-c",
+    "outputId": "a66d77c2-4267-4b98-fe52-c54aec433552",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "learn.fit_one_cycle(10, slice(lr), pct_start=0.8)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "TU38pO5gpr-f"
+   },
+   "outputs": [],
+   "source": [
+    "learn.save('stage-1-big')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "F2czKtoFpr-g"
+   },
+   "outputs": [],
+   "source": [
+    "learn.load('stage-1-big');"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "Oicfx00fpr-i"
+   },
+   "outputs": [],
+   "source": [
+    "learn.unfreeze()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "_6xQhCXApr-j"
+   },
+   "outputs": [],
+   "source": [
+    "lrs = slice(1e-6,lr/10)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "sxTN1UYSpr-l",
+    "outputId": "17f9b0c4-6a43-4182-a1f4-57b7ea235d4c"
+   },
+   "outputs": [],
+   "source": [
+    "learn.fit_one_cycle(10, lrs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "tvm_yXCJpr-o"
+   },
+   "outputs": [],
+   "source": [
+    "learn.save('stage-2-big')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "7Kts9ls7pr-p"
+   },
+   "outputs": [],
+   "source": [
+    "learn.load('stage-2-big');"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "LektCRD2pr-y",
+    "outputId": "c73598d8-a91c-406d-a96e-022428a092cb"
+   },
+   "outputs": [],
+   "source": [
+    "learn.show_results(rows=3, figsize=(10,10))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "dNCx0BaYpr-8"
+   },
+   "source": [
+    "## VISUALIZING Axial, Sagittal and Coronal"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "nii='Patient_41_GT.nii.gz'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "ayG7b91Mpr-8",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "import nilearn\n",
+    "from nilearn import plotting\n",
+    "plotting.plot_stat_map(nii)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "nii='GT.nii.gz'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import nilearn\n",
+    "from nilearn import plotting\n",
+    "plotting.plot_stat_map(nii)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "collapsed_sections": [
+    "BoNKO9vHpr-K",
+    "dNCx0BaYpr-8"
+   ],
+   "name": "SegThor.ipynb",
+   "provenance": [],
+   "version": "0.3.2"
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}