Switch to side-by-side view

--- a
+++ b/Experimentations/Exp14-Label smoothing and noisy ground truth.ipynb
@@ -0,0 +1,519 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 14. Experimenting with label smoothing and CEL to deal with noisy ground truth\n",
+    "We tried to train a network with label smoothing, which is generally done when the ground truth is noisy or involves a lot of subjectivity. The practice of label smoothing is tried for classification problem but never for a segmentation problem. We tried it for segmentation problem. It didn't seem to work well and hence training was stopped mid-way as the dice-scores were no where close to acceptable levels. As a future work, we would like to engineer a  way to segment images based on noisy ground truth"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch.autograd import Variable\n",
+    "import torch.optim as optim\n",
+    "import torchvision\n",
+    "from torchvision import datasets, models\n",
+    "from torchvision import transforms as T\n",
+    "from torch.utils.data import DataLoader, Dataset\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "import os\n",
+    "import time\n",
+    "import pandas as pd\n",
+    "from skimage import io, transform\n",
+    "import matplotlib.image as mpimg\n",
+    "from PIL import Image\n",
+    "from sklearn.metrics import roc_auc_score\n",
+    "import torch.nn.functional as F\n",
+    "import scipy\n",
+    "import random\n",
+    "import pickle\n",
+    "import scipy.io as sio\n",
+    "import itertools\n",
+    "from scipy.ndimage.interpolation import shift\n",
+    "import copy\n",
+    "import warnings\n",
+    "warnings.filterwarnings(\"ignore\")\n",
+    "%matplotlib inline\n",
+    "plt.ion()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Import Dataloader Class and other utilities"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "from dataloader_2d import *\n",
+    "from dataloader_3d import *"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Build Data loader objects"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "train_path = '/beegfs/ark576/new_knee_data/train'\n",
+    "val_path = '/beegfs/ark576/new_knee_data/val'\n",
+    "test_path = '/beegfs/ark576/new_knee_data/test'\n",
+    "\n",
+    "train_file_names = sorted(pickle.load(open(train_path + '/train_file_names.p','rb')))\n",
+    "val_file_names = sorted(pickle.load(open(val_path + '/val_file_names.p','rb')))\n",
+    "test_file_names = sorted(pickle.load(open(test_path + '/test_file_names.p','rb')))\n",
+    "\n",
+    "transformed_dataset = {'train': KneeMRIDataset(train_path,train_file_names, train_data= True, flipping=False, normalize= True),\n",
+    "                       'validate': KneeMRIDataset(val_path,val_file_names, normalize= True),\n",
+    "                       'test': KneeMRIDataset(test_path,test_file_names, normalize= True)\n",
+    "                                          }\n",
+    "\n",
+    "dataloader = {x: DataLoader(transformed_dataset[x], batch_size=5,\n",
+    "                        shuffle=True, num_workers=0) for x in ['train', 'validate','test']}\n",
+    "data_sizes ={x: len(transformed_dataset[x]) for x in ['train', 'validate','test']}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "im, seg_F, seg_P, seg_T,_ = next(iter(dataloader['train']))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Find Max and min values of Images (all 7 contrasts), of Fractional Anisotropy maps and of Mean Diffusivity maps for image normalization"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "min_fa = np.inf\n",
+    "min_md = np.inf\n",
+    "min_image = np.inf\n",
+    "max_fa = 0\n",
+    "max_md = 0\n",
+    "max_image = 0\n",
+    "for data in dataloader['train']:\n",
+    "    if min_fa > torch.min(data[0][:,7,:,:]):\n",
+    "        min_fa = torch.min(data[0][:,7,:,:])\n",
+    "    if min_md > torch.min(data[0][:,8,:,:]):\n",
+    "        min_md = torch.min(data[0][:,8:,:])\n",
+    "    if min_image > torch.min(data[0][:,:7,:,:]):\n",
+    "        min_image = torch.min(data[0][:,:7,:,:])\n",
+    "    if max_fa < torch.max(data[0][:,7,:,:]):\n",
+    "        max_fa = torch.max(data[0][:,7,:,:])\n",
+    "    if max_md < torch.max(data[0][:,8,:,:]):\n",
+    "        max_md = torch.max(data[0][:,8,:,:])\n",
+    "    if max_image < torch.max(data[0][:,:7,:,:]):\n",
+    "        max_image = torch.max(data[0][:,:7,:,:])\n",
+    "norm_values = (max_image, min_image, max_fa, min_fa, max_md, min_md)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Import Models"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "from unet_3d import *\n",
+    "from unet_basic_dilated import *\n",
+    "from vnet import *\n",
+    "from ensemble_model import *"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "seg_sum = torch.zeros(3)\n",
+    "for i, data in enumerate(dataloader['train']):\n",
+    "    input, segF, segP, segT,_ = data\n",
+    "    seg_sum[0] += torch.sum(segF)\n",
+    "    seg_sum[1] += torch.sum(segP)\n",
+    "    seg_sum[2] += torch.sum(segT)\n",
+    "mean_s_sum = seg_sum/i"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Import Loss functions and all other utility functions like functions for saving models, for visualizing images, etc."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "from utils import *"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Import all the Training and evaluate functions to evaluate the models"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "from train_2d import *\n",
+    "from train_3d import *\n",
+    "from train_ensemble import *\n",
+    "from evaluate_2d import *\n",
+    "from evaluate_3d import *\n",
+    "from evaluate_ensemble import *"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 14. Experimenting with label smoothing and CEL to deal with noisy ground truth\n",
+    "We tried to train a network with label smoothing, which is generally done when the ground truth is noisy or involves a lot of subjectivity. The practice of label smoothing is tried for classification problem but never for a segmentation problem. We tried it for segmentation problem. It didn't seem to work well and hence training was stopped mid-way as the dice-scores were no where close to acceptable levels. As a future work, we would like to engineer a  way to segment images based on noisy ground truth"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 68,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "unet_exp_noisy = Unet_dilated_small(9,4,int_var=40,dilated=False).cuda()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 69,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "410524"
+      ]
+     },
+     "execution_count": 69,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "count_parameters(unet_exp_noisy)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "optimizer_unet_exp_noisy = optim.Adam(unet_exp_noisy.parameters(),lr = 1e-4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 80,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch: 0, Phase: train, epoch loss: 0.0069, Dice Score (class 0): 0.4548, Dice Score (class 1): 0.3627,Dice Score (class 2): 0.2732\n",
+      "----------\n",
+      "Epoch: 0, Phase: validate, epoch loss: 0.0107, Dice Score (class 0): 0.4652, Dice Score (class 1): 0.2242,Dice Score (class 2): 0.2691\n",
+      "----------\n",
+      "Epoch: 1, Phase: train, epoch loss: 0.0066, Dice Score (class 0): 0.4464, Dice Score (class 1): 0.3631,Dice Score (class 2): 0.2697\n",
+      "----------\n",
+      "Epoch: 1, Phase: validate, epoch loss: 0.0116, Dice Score (class 0): 0.4738, Dice Score (class 1): 0.1400,Dice Score (class 2): 0.2687\n",
+      "----------\n",
+      "Epoch: 2, Phase: train, epoch loss: 0.0063, Dice Score (class 0): 0.4417, Dice Score (class 1): 0.3004,Dice Score (class 2): 0.2691\n",
+      "----------\n",
+      "Epoch: 2, Phase: validate, epoch loss: 0.0083, Dice Score (class 0): 0.4532, Dice Score (class 1): 0.1267,Dice Score (class 2): 0.2386\n",
+      "----------\n",
+      "Epoch: 3, Phase: train, epoch loss: 0.0058, Dice Score (class 0): 0.4365, Dice Score (class 1): 0.3346,Dice Score (class 2): 0.2603\n",
+      "----------\n",
+      "Epoch: 3, Phase: validate, epoch loss: 0.0090, Dice Score (class 0): 0.4624, Dice Score (class 1): 0.1272,Dice Score (class 2): 0.2762\n",
+      "----------\n",
+      "Epoch: 4, Phase: train, epoch loss: 0.0058, Dice Score (class 0): 0.4264, Dice Score (class 1): 0.2908,Dice Score (class 2): 0.2789\n",
+      "----------\n",
+      "Epoch: 4, Phase: validate, epoch loss: 0.0106, Dice Score (class 0): 0.4784, Dice Score (class 1): 0.0917,Dice Score (class 2): 0.2196\n",
+      "----------\n",
+      "Epoch: 5, Phase: train, epoch loss: 0.0053, Dice Score (class 0): 0.4341, Dice Score (class 1): 0.3044,Dice Score (class 2): 0.2540\n",
+      "----------\n",
+      "Epoch: 5, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4828, Dice Score (class 1): 0.1812,Dice Score (class 2): 0.2497\n",
+      "----------\n",
+      "Epoch: 6, Phase: train, epoch loss: 0.0052, Dice Score (class 0): 0.4285, Dice Score (class 1): 0.2792,Dice Score (class 2): 0.2567\n",
+      "----------\n",
+      "Epoch: 6, Phase: validate, epoch loss: 0.0081, Dice Score (class 0): 0.4304, Dice Score (class 1): 0.0803,Dice Score (class 2): 0.2250\n",
+      "----------\n",
+      "Epoch: 7, Phase: train, epoch loss: 0.0049, Dice Score (class 0): 0.4271, Dice Score (class 1): 0.3015,Dice Score (class 2): 0.2550\n",
+      "----------\n",
+      "Epoch: 7, Phase: validate, epoch loss: 0.0094, Dice Score (class 0): 0.4339, Dice Score (class 1): 0.1199,Dice Score (class 2): 0.2402\n",
+      "----------\n",
+      "Epoch: 8, Phase: train, epoch loss: 0.0046, Dice Score (class 0): 0.4336, Dice Score (class 1): 0.3101,Dice Score (class 2): 0.2492\n",
+      "----------\n",
+      "Epoch: 8, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4146, Dice Score (class 1): 0.0966,Dice Score (class 2): 0.2273\n",
+      "----------\n",
+      "Epoch: 9, Phase: train, epoch loss: 0.0047, Dice Score (class 0): 0.4042, Dice Score (class 1): 0.3042,Dice Score (class 2): 0.2402\n",
+      "----------\n",
+      "Epoch: 9, Phase: validate, epoch loss: 0.0085, Dice Score (class 0): 0.4546, Dice Score (class 1): 0.0937,Dice Score (class 2): 0.2365\n",
+      "----------\n",
+      "Epoch: 10, Phase: train, epoch loss: 0.0048, Dice Score (class 0): 0.4017, Dice Score (class 1): 0.2660,Dice Score (class 2): 0.2579\n",
+      "----------\n",
+      "Epoch: 10, Phase: validate, epoch loss: 0.0093, Dice Score (class 0): 0.3297, Dice Score (class 1): 0.1008,Dice Score (class 2): 0.3376\n",
+      "----------\n",
+      "Epoch: 11, Phase: train, epoch loss: 0.0052, Dice Score (class 0): 0.3732, Dice Score (class 1): 0.2240,Dice Score (class 2): 0.2164\n",
+      "----------\n",
+      "Epoch: 11, Phase: validate, epoch loss: 0.0080, Dice Score (class 0): 0.3507, Dice Score (class 1): 0.1192,Dice Score (class 2): 0.1992\n",
+      "----------\n",
+      "Epoch: 12, Phase: train, epoch loss: 0.0044, Dice Score (class 0): 0.4045, Dice Score (class 1): 0.2663,Dice Score (class 2): 0.2398\n",
+      "----------\n",
+      "Epoch: 12, Phase: validate, epoch loss: 0.0054, Dice Score (class 0): 0.3737, Dice Score (class 1): 0.1181,Dice Score (class 2): 0.2016\n",
+      "----------\n",
+      "Epoch: 13, Phase: train, epoch loss: 0.0039, Dice Score (class 0): 0.4121, Dice Score (class 1): 0.2953,Dice Score (class 2): 0.2487\n",
+      "----------\n",
+      "Epoch: 13, Phase: validate, epoch loss: 0.0062, Dice Score (class 0): 0.4482, Dice Score (class 1): 0.1157,Dice Score (class 2): 0.2612\n",
+      "----------\n",
+      "Epoch: 14, Phase: train, epoch loss: 0.0037, Dice Score (class 0): 0.4235, Dice Score (class 1): 0.3123,Dice Score (class 2): 0.2620\n",
+      "----------\n",
+      "Epoch: 14, Phase: validate, epoch loss: 0.0067, Dice Score (class 0): 0.4270, Dice Score (class 1): 0.0954,Dice Score (class 2): 0.2708\n",
+      "----------\n",
+      "Epoch: 15, Phase: train, epoch loss: 0.0036, Dice Score (class 0): 0.4144, Dice Score (class 1): 0.2824,Dice Score (class 2): 0.2676\n",
+      "----------\n",
+      "Epoch: 15, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4503, Dice Score (class 1): 0.1200,Dice Score (class 2): 0.2609\n",
+      "----------\n",
+      "Epoch: 16, Phase: train, epoch loss: 0.0034, Dice Score (class 0): 0.4183, Dice Score (class 1): 0.2898,Dice Score (class 2): 0.2592\n",
+      "----------\n",
+      "Epoch: 16, Phase: validate, epoch loss: 0.0061, Dice Score (class 0): 0.4537, Dice Score (class 1): 0.1223,Dice Score (class 2): 0.3014\n",
+      "----------\n",
+      "Epoch: 17, Phase: train, epoch loss: 0.0035, Dice Score (class 0): 0.4004, Dice Score (class 1): 0.2839,Dice Score (class 2): 0.2320\n",
+      "----------\n",
+      "Epoch: 17, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.4409, Dice Score (class 1): 0.1723,Dice Score (class 2): 0.3229\n",
+      "----------\n",
+      "Epoch: 18, Phase: train, epoch loss: 0.0034, Dice Score (class 0): 0.4095, Dice Score (class 1): 0.2638,Dice Score (class 2): 0.2377\n",
+      "----------\n",
+      "Epoch: 18, Phase: validate, epoch loss: 0.0124, Dice Score (class 0): 0.4778, Dice Score (class 1): 0.0729,Dice Score (class 2): 0.2672\n",
+      "----------\n",
+      "Epoch: 19, Phase: train, epoch loss: 0.0032, Dice Score (class 0): 0.4091, Dice Score (class 1): 0.3173,Dice Score (class 2): 0.2508\n",
+      "----------\n",
+      "Epoch: 19, Phase: validate, epoch loss: 0.0087, Dice Score (class 0): 0.4247, Dice Score (class 1): 0.2357,Dice Score (class 2): 0.3106\n",
+      "----------\n",
+      "Epoch: 20, Phase: train, epoch loss: 0.0030, Dice Score (class 0): 0.4090, Dice Score (class 1): 0.3206,Dice Score (class 2): 0.2621\n",
+      "----------\n",
+      "Epoch: 20, Phase: validate, epoch loss: 0.0053, Dice Score (class 0): 0.4091, Dice Score (class 1): 0.1603,Dice Score (class 2): 0.2504\n",
+      "----------\n",
+      "Epoch: 21, Phase: train, epoch loss: 0.0039, Dice Score (class 0): 0.3843, Dice Score (class 1): 0.2829,Dice Score (class 2): 0.2470\n",
+      "----------\n",
+      "Epoch: 21, Phase: validate, epoch loss: 0.0112, Dice Score (class 0): 0.3247, Dice Score (class 1): 0.0295,Dice Score (class 2): 0.2692\n",
+      "----------\n",
+      "Epoch: 22, Phase: train, epoch loss: 0.0036, Dice Score (class 0): 0.3750, Dice Score (class 1): 0.1967,Dice Score (class 2): 0.2209\n",
+      "----------\n",
+      "Epoch: 22, Phase: validate, epoch loss: 0.0075, Dice Score (class 0): 0.4240, Dice Score (class 1): 0.1111,Dice Score (class 2): 0.3162\n",
+      "----------\n",
+      "Epoch: 23, Phase: train, epoch loss: 0.0030, Dice Score (class 0): 0.3883, Dice Score (class 1): 0.2905,Dice Score (class 2): 0.2495\n",
+      "----------\n",
+      "Epoch: 23, Phase: validate, epoch loss: 0.0079, Dice Score (class 0): 0.4247, Dice Score (class 1): 0.1192,Dice Score (class 2): 0.2906\n",
+      "----------\n",
+      "Epoch: 24, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.4021, Dice Score (class 1): 0.3212,Dice Score (class 2): 0.2533\n",
+      "----------\n",
+      "Epoch: 24, Phase: validate, epoch loss: 0.0067, Dice Score (class 0): 0.4512, Dice Score (class 1): 0.2062,Dice Score (class 2): 0.3409\n",
+      "----------\n",
+      "Epoch: 25, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.4044, Dice Score (class 1): 0.3422,Dice Score (class 2): 0.2634\n",
+      "----------\n",
+      "Epoch: 25, Phase: validate, epoch loss: 0.0094, Dice Score (class 0): 0.4467, Dice Score (class 1): 0.3159,Dice Score (class 2): 0.3170\n",
+      "----------\n",
+      "Epoch: 26, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.4007, Dice Score (class 1): 0.3061,Dice Score (class 2): 0.2244\n",
+      "----------\n",
+      "Epoch: 26, Phase: validate, epoch loss: 0.0062, Dice Score (class 0): 0.4598, Dice Score (class 1): 0.0884,Dice Score (class 2): 0.1960\n",
+      "----------\n",
+      "Epoch: 27, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.3824, Dice Score (class 1): 0.2838,Dice Score (class 2): 0.2146\n",
+      "----------\n",
+      "Epoch: 27, Phase: validate, epoch loss: 0.0100, Dice Score (class 0): 0.4430, Dice Score (class 1): 0.1496,Dice Score (class 2): 0.2968\n",
+      "----------\n",
+      "Epoch: 28, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.3903, Dice Score (class 1): 0.3193,Dice Score (class 2): 0.2287\n",
+      "----------\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch: 28, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.3152, Dice Score (class 1): 0.1829,Dice Score (class 2): 0.2807\n",
+      "----------\n",
+      "Epoch: 29, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.3628, Dice Score (class 1): 0.2731,Dice Score (class 2): 0.2147\n",
+      "----------\n",
+      "Epoch: 29, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.4576, Dice Score (class 1): 0.1791,Dice Score (class 2): 0.2617\n",
+      "----------\n",
+      "Epoch: 30, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.3856, Dice Score (class 1): 0.3332,Dice Score (class 2): 0.2106\n",
+      "----------\n",
+      "Epoch: 30, Phase: validate, epoch loss: 0.0088, Dice Score (class 0): 0.4605, Dice Score (class 1): 0.2360,Dice Score (class 2): 0.2981\n",
+      "----------\n",
+      "Epoch: 31, Phase: train, epoch loss: 0.0022, Dice Score (class 0): 0.4015, Dice Score (class 1): 0.3838,Dice Score (class 2): 0.2237\n",
+      "----------\n",
+      "Epoch: 31, Phase: validate, epoch loss: 0.0103, Dice Score (class 0): 0.4688, Dice Score (class 1): 0.4179,Dice Score (class 2): 0.3486\n",
+      "----------\n",
+      "Epoch: 32, Phase: train, epoch loss: 0.0022, Dice Score (class 0): 0.3904, Dice Score (class 1): 0.3807,Dice Score (class 2): 0.2404\n",
+      "----------\n",
+      "Epoch: 32, Phase: validate, epoch loss: 0.0085, Dice Score (class 0): 0.4480, Dice Score (class 1): 0.3601,Dice Score (class 2): 0.2842\n",
+      "----------\n"
+     ]
+    },
+    {
+     "ename": "KeyboardInterrupt",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-80-d560073d377c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m                                                      \u001b[0mdataloader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdata_sizes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'new_data_unet_exp_noisy_1'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m                                                      \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0;36m50\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m                                                      dice_loss = dice_loss_3,noisy_labels = True)\n\u001b[0m",
+      "\u001b[0;32m<ipython-input-24-c237342f9bb3>\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, optimizer, dataloader, data_sizes, batch_size, name, num_epochs, verbose, dice_loss, noisy_labels)\u001b[0m\n\u001b[1;32m     24\u001b[0m                 \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m             \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdataloader\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     27\u001b[0m                 \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     28\u001b[0m                 \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegT\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/share/apps/pytorch/0.2.0_3/python3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    177\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# same-process loading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    178\u001b[0m             \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_iter\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m             \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    180\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    181\u001b[0m                 \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpin_memory_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/share/apps/pytorch/0.2.0_3/python3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    177\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# same-process loading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    178\u001b[0m             \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_iter\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m             \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    180\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    181\u001b[0m                 \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpin_memory_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m<ipython-input-3-676504d23560>\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m     47\u001b[0m                 \u001b[0mfa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfliplr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfa\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m                 \u001b[0msegment_T\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_T\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     50\u001b[0m                 \u001b[0msegment_F\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_F\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m                 \u001b[0msegment_P\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_P\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/pyenv/py3.6.3/lib/python3.6/site-packages/skimage/transform/_warps.py\u001b[0m in \u001b[0;36mrotate\u001b[0;34m(image, angle, resize, center, order, mode, cval, clip, preserve_range)\u001b[0m\n\u001b[1;32m    298\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    299\u001b[0m     return warp(image, tform, output_shape=output_shape, order=order,\n\u001b[0;32m--> 300\u001b[0;31m                 mode=mode, cval=cval, clip=clip, preserve_range=preserve_range)\n\u001b[0m\u001b[1;32m    301\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    302\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/pyenv/py3.6.3/lib/python3.6/site-packages/skimage/transform/_warps.py\u001b[0m in \u001b[0;36mwarp\u001b[0;34m(image, inverse_map, map_args, output_shape, order, mode, cval, clip, preserve_range)\u001b[0m\n\u001b[1;32m    767\u001b[0m                 warped = _warp_fast(image, matrix,\n\u001b[1;32m    768\u001b[0m                                  \u001b[0moutput_shape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_shape\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 769\u001b[0;31m                                  order=order, mode=mode, cval=cval)\n\u001b[0m\u001b[1;32m    770\u001b[0m             \u001b[0;32melif\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    771\u001b[0m                 \u001b[0mdims\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32mskimage/transform/_warps_cy.pyx\u001b[0m in \u001b[0;36mskimage.transform._warps_cy._warp_fast\u001b[0;34m()\u001b[0m\n",
+      "\u001b[0;32m/share/apps/python3/3.6.3/intel/lib/python3.6/site-packages/numpy-1.13.3-py3.6-linux-x86_64.egg/numpy/core/numeric.py\u001b[0m in \u001b[0;36masarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m    461\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    462\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 463\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    464\u001b[0m     \"\"\"Convert the input to an array.\n\u001b[1;32m    465\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+     ]
+    }
+   ],
+   "source": [
+    "unet_exp_noisy, loss_hist_unet_exp, dc_hist_0_unet_exp, \\\n",
+    "dc_hist_1_unet_exp, dc_hist_2_unet_exp = train_model(unet_exp_noisy, optimizer_unet_exp_noisy,\n",
+    "                                                     dataloader,data_sizes,5,'new_data_unet_exp_noisy_1',\n",
+    "                                                     num_epochs= 50, verbose = True, \n",
+    "                                                     dice_loss = dice_loss_3,noisy_labels = True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "# torch.save(model_gen_dilated_l4_n2_new_data_dp,'new_data_dilated_net_l4_n2_nd_dp_1')\n",
+    "pickle.dump(loss_hist_unet_exp, open('loss_hist_new_data_unet_exp_noisy_1','wb'))\n",
+    "pickle.dump(dc_hist_0_unet_exp, open('dc_hist_0_new_data_unet_exp_noisy_1','wb'))\n",
+    "pickle.dump(dc_hist_1_unet_exp, open('dc_hist_1_new_data_unet_exp_noisy_1','wb'))\n",
+    "pickle.dump(dc_hist_2_unet_exp, open('dc_hist_2_new_data_unet_exp_noisy_1','wb'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true,
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "plot_hist(loss_hist_unet_exp,'Loss')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true,
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "evaluate(unet_exp_noisy, dataloader, data_sizes, 5, 'validate', dice_loss=dice_loss_3, noisy_labels = True)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}