--- a +++ b/Experimentations/Exp14-Label smoothing and noisy ground truth.ipynb @@ -0,0 +1,519 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 14. Experimenting with label smoothing and CEL to deal with noisy ground truth\n", + "We tried to train a network with label smoothing, which is generally done when the ground truth is noisy or involves a lot of subjectivity. The practice of label smoothing is tried for classification problem but never for a segmentation problem. We tried it for segmentation problem. It didn't seem to work well and hence training was stopped mid-way as the dice-scores were no where close to acceptable levels. As a future work, we would like to engineer a way to segment images based on noisy ground truth" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.autograd import Variable\n", + "import torch.optim as optim\n", + "import torchvision\n", + "from torchvision import datasets, models\n", + "from torchvision import transforms as T\n", + "from torch.utils.data import DataLoader, Dataset\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import os\n", + "import time\n", + "import pandas as pd\n", + "from skimage import io, transform\n", + "import matplotlib.image as mpimg\n", + "from PIL import Image\n", + "from sklearn.metrics import roc_auc_score\n", + "import torch.nn.functional as F\n", + "import scipy\n", + "import random\n", + "import pickle\n", + "import scipy.io as sio\n", + "import itertools\n", + "from scipy.ndimage.interpolation import shift\n", + "import copy\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "%matplotlib inline\n", + "plt.ion()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Import Dataloader Class and other utilities" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from dataloader_2d import *\n", + "from dataloader_3d import *" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Build Data loader objects" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "train_path = '/beegfs/ark576/new_knee_data/train'\n", + "val_path = '/beegfs/ark576/new_knee_data/val'\n", + "test_path = '/beegfs/ark576/new_knee_data/test'\n", + "\n", + "train_file_names = sorted(pickle.load(open(train_path + '/train_file_names.p','rb')))\n", + "val_file_names = sorted(pickle.load(open(val_path + '/val_file_names.p','rb')))\n", + "test_file_names = sorted(pickle.load(open(test_path + '/test_file_names.p','rb')))\n", + "\n", + "transformed_dataset = {'train': KneeMRIDataset(train_path,train_file_names, train_data= True, flipping=False, normalize= True),\n", + " 'validate': KneeMRIDataset(val_path,val_file_names, normalize= True),\n", + " 'test': KneeMRIDataset(test_path,test_file_names, normalize= True)\n", + " }\n", + "\n", + "dataloader = {x: DataLoader(transformed_dataset[x], batch_size=5,\n", + " shuffle=True, num_workers=0) for x in ['train', 'validate','test']}\n", + "data_sizes ={x: len(transformed_dataset[x]) for x in ['train', 'validate','test']}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "im, seg_F, seg_P, seg_T,_ = next(iter(dataloader['train']))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Find Max and min values of Images (all 7 contrasts), of Fractional Anisotropy maps and of Mean Diffusivity maps for image normalization" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "min_fa = np.inf\n", + "min_md = np.inf\n", + "min_image = np.inf\n", + "max_fa = 0\n", + "max_md = 0\n", + "max_image = 0\n", + "for data in dataloader['train']:\n", + " if min_fa > torch.min(data[0][:,7,:,:]):\n", + " min_fa = torch.min(data[0][:,7,:,:])\n", + " if min_md > torch.min(data[0][:,8,:,:]):\n", + " min_md = torch.min(data[0][:,8:,:])\n", + " if min_image > torch.min(data[0][:,:7,:,:]):\n", + " min_image = torch.min(data[0][:,:7,:,:])\n", + " if max_fa < torch.max(data[0][:,7,:,:]):\n", + " max_fa = torch.max(data[0][:,7,:,:])\n", + " if max_md < torch.max(data[0][:,8,:,:]):\n", + " max_md = torch.max(data[0][:,8,:,:])\n", + " if max_image < torch.max(data[0][:,:7,:,:]):\n", + " max_image = torch.max(data[0][:,:7,:,:])\n", + "norm_values = (max_image, min_image, max_fa, min_fa, max_md, min_md)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Import Models" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from unet_3d import *\n", + "from unet_basic_dilated import *\n", + "from vnet import *\n", + "from ensemble_model import *" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "seg_sum = torch.zeros(3)\n", + "for i, data in enumerate(dataloader['train']):\n", + " input, segF, segP, segT,_ = data\n", + " seg_sum[0] += torch.sum(segF)\n", + " seg_sum[1] += torch.sum(segP)\n", + " seg_sum[2] += torch.sum(segT)\n", + "mean_s_sum = seg_sum/i" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import Loss functions and all other utility functions like functions for saving models, for visualizing images, etc." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from utils import *" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import all the Training and evaluate functions to evaluate the models" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from train_2d import *\n", + "from train_3d import *\n", + "from train_ensemble import *\n", + "from evaluate_2d import *\n", + "from evaluate_3d import *\n", + "from evaluate_ensemble import *" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 14. Experimenting with label smoothing and CEL to deal with noisy ground truth\n", + "We tried to train a network with label smoothing, which is generally done when the ground truth is noisy or involves a lot of subjectivity. The practice of label smoothing is tried for classification problem but never for a segmentation problem. We tried it for segmentation problem. It didn't seem to work well and hence training was stopped mid-way as the dice-scores were no where close to acceptable levels. As a future work, we would like to engineer a way to segment images based on noisy ground truth" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "unet_exp_noisy = Unet_dilated_small(9,4,int_var=40,dilated=False).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "410524" + ] + }, + "execution_count": 69, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "count_parameters(unet_exp_noisy)" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "optimizer_unet_exp_noisy = optim.Adam(unet_exp_noisy.parameters(),lr = 1e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 0, Phase: train, epoch loss: 0.0069, Dice Score (class 0): 0.4548, Dice Score (class 1): 0.3627,Dice Score (class 2): 0.2732\n", + "----------\n", + "Epoch: 0, Phase: validate, epoch loss: 0.0107, Dice Score (class 0): 0.4652, Dice Score (class 1): 0.2242,Dice Score (class 2): 0.2691\n", + "----------\n", + "Epoch: 1, Phase: train, epoch loss: 0.0066, Dice Score (class 0): 0.4464, Dice Score (class 1): 0.3631,Dice Score (class 2): 0.2697\n", + "----------\n", + "Epoch: 1, Phase: validate, epoch loss: 0.0116, Dice Score (class 0): 0.4738, Dice Score (class 1): 0.1400,Dice Score (class 2): 0.2687\n", + "----------\n", + "Epoch: 2, Phase: train, epoch loss: 0.0063, Dice Score (class 0): 0.4417, Dice Score (class 1): 0.3004,Dice Score (class 2): 0.2691\n", + "----------\n", + "Epoch: 2, Phase: validate, epoch loss: 0.0083, Dice Score (class 0): 0.4532, Dice Score (class 1): 0.1267,Dice Score (class 2): 0.2386\n", + "----------\n", + "Epoch: 3, Phase: train, epoch loss: 0.0058, Dice Score (class 0): 0.4365, Dice Score (class 1): 0.3346,Dice Score (class 2): 0.2603\n", + "----------\n", + "Epoch: 3, Phase: validate, epoch loss: 0.0090, Dice Score (class 0): 0.4624, Dice Score (class 1): 0.1272,Dice Score (class 2): 0.2762\n", + "----------\n", + "Epoch: 4, Phase: train, epoch loss: 0.0058, Dice Score (class 0): 0.4264, Dice Score (class 1): 0.2908,Dice Score (class 2): 0.2789\n", + "----------\n", + "Epoch: 4, Phase: validate, epoch loss: 0.0106, Dice Score (class 0): 0.4784, Dice Score (class 1): 0.0917,Dice Score (class 2): 0.2196\n", + "----------\n", + "Epoch: 5, Phase: train, epoch loss: 0.0053, Dice Score (class 0): 0.4341, Dice Score (class 1): 0.3044,Dice Score (class 2): 0.2540\n", + "----------\n", + "Epoch: 5, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4828, Dice Score (class 1): 0.1812,Dice Score (class 2): 0.2497\n", + "----------\n", + "Epoch: 6, Phase: train, epoch loss: 0.0052, Dice Score (class 0): 0.4285, Dice Score (class 1): 0.2792,Dice Score (class 2): 0.2567\n", + "----------\n", + "Epoch: 6, Phase: validate, epoch loss: 0.0081, Dice Score (class 0): 0.4304, Dice Score (class 1): 0.0803,Dice Score (class 2): 0.2250\n", + "----------\n", + "Epoch: 7, Phase: train, epoch loss: 0.0049, Dice Score (class 0): 0.4271, Dice Score (class 1): 0.3015,Dice Score (class 2): 0.2550\n", + "----------\n", + "Epoch: 7, Phase: validate, epoch loss: 0.0094, Dice Score (class 0): 0.4339, Dice Score (class 1): 0.1199,Dice Score (class 2): 0.2402\n", + "----------\n", + "Epoch: 8, Phase: train, epoch loss: 0.0046, Dice Score (class 0): 0.4336, Dice Score (class 1): 0.3101,Dice Score (class 2): 0.2492\n", + "----------\n", + "Epoch: 8, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4146, Dice Score (class 1): 0.0966,Dice Score (class 2): 0.2273\n", + "----------\n", + "Epoch: 9, Phase: train, epoch loss: 0.0047, Dice Score (class 0): 0.4042, Dice Score (class 1): 0.3042,Dice Score (class 2): 0.2402\n", + "----------\n", + "Epoch: 9, Phase: validate, epoch loss: 0.0085, Dice Score (class 0): 0.4546, Dice Score (class 1): 0.0937,Dice Score (class 2): 0.2365\n", + "----------\n", + "Epoch: 10, Phase: train, epoch loss: 0.0048, Dice Score (class 0): 0.4017, Dice Score (class 1): 0.2660,Dice Score (class 2): 0.2579\n", + "----------\n", + "Epoch: 10, Phase: validate, epoch loss: 0.0093, Dice Score (class 0): 0.3297, Dice Score (class 1): 0.1008,Dice Score (class 2): 0.3376\n", + "----------\n", + "Epoch: 11, Phase: train, epoch loss: 0.0052, Dice Score (class 0): 0.3732, Dice Score (class 1): 0.2240,Dice Score (class 2): 0.2164\n", + "----------\n", + "Epoch: 11, Phase: validate, epoch loss: 0.0080, Dice Score (class 0): 0.3507, Dice Score (class 1): 0.1192,Dice Score (class 2): 0.1992\n", + "----------\n", + "Epoch: 12, Phase: train, epoch loss: 0.0044, Dice Score (class 0): 0.4045, Dice Score (class 1): 0.2663,Dice Score (class 2): 0.2398\n", + "----------\n", + "Epoch: 12, Phase: validate, epoch loss: 0.0054, Dice Score (class 0): 0.3737, Dice Score (class 1): 0.1181,Dice Score (class 2): 0.2016\n", + "----------\n", + "Epoch: 13, Phase: train, epoch loss: 0.0039, Dice Score (class 0): 0.4121, Dice Score (class 1): 0.2953,Dice Score (class 2): 0.2487\n", + "----------\n", + "Epoch: 13, Phase: validate, epoch loss: 0.0062, Dice Score (class 0): 0.4482, Dice Score (class 1): 0.1157,Dice Score (class 2): 0.2612\n", + "----------\n", + "Epoch: 14, Phase: train, epoch loss: 0.0037, Dice Score (class 0): 0.4235, Dice Score (class 1): 0.3123,Dice Score (class 2): 0.2620\n", + "----------\n", + "Epoch: 14, Phase: validate, epoch loss: 0.0067, Dice Score (class 0): 0.4270, Dice Score (class 1): 0.0954,Dice Score (class 2): 0.2708\n", + "----------\n", + "Epoch: 15, Phase: train, epoch loss: 0.0036, Dice Score (class 0): 0.4144, Dice Score (class 1): 0.2824,Dice Score (class 2): 0.2676\n", + "----------\n", + "Epoch: 15, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4503, Dice Score (class 1): 0.1200,Dice Score (class 2): 0.2609\n", + "----------\n", + "Epoch: 16, Phase: train, epoch loss: 0.0034, Dice Score (class 0): 0.4183, Dice Score (class 1): 0.2898,Dice Score (class 2): 0.2592\n", + "----------\n", + "Epoch: 16, Phase: validate, epoch loss: 0.0061, Dice Score (class 0): 0.4537, Dice Score (class 1): 0.1223,Dice Score (class 2): 0.3014\n", + "----------\n", + "Epoch: 17, Phase: train, epoch loss: 0.0035, Dice Score (class 0): 0.4004, Dice Score (class 1): 0.2839,Dice Score (class 2): 0.2320\n", + "----------\n", + "Epoch: 17, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.4409, Dice Score (class 1): 0.1723,Dice Score (class 2): 0.3229\n", + "----------\n", + "Epoch: 18, Phase: train, epoch loss: 0.0034, Dice Score (class 0): 0.4095, Dice Score (class 1): 0.2638,Dice Score (class 2): 0.2377\n", + "----------\n", + "Epoch: 18, Phase: validate, epoch loss: 0.0124, Dice Score (class 0): 0.4778, Dice Score (class 1): 0.0729,Dice Score (class 2): 0.2672\n", + "----------\n", + "Epoch: 19, Phase: train, epoch loss: 0.0032, Dice Score (class 0): 0.4091, Dice Score (class 1): 0.3173,Dice Score (class 2): 0.2508\n", + "----------\n", + "Epoch: 19, Phase: validate, epoch loss: 0.0087, Dice Score (class 0): 0.4247, Dice Score (class 1): 0.2357,Dice Score (class 2): 0.3106\n", + "----------\n", + "Epoch: 20, Phase: train, epoch loss: 0.0030, Dice Score (class 0): 0.4090, Dice Score (class 1): 0.3206,Dice Score (class 2): 0.2621\n", + "----------\n", + "Epoch: 20, Phase: validate, epoch loss: 0.0053, Dice Score (class 0): 0.4091, Dice Score (class 1): 0.1603,Dice Score (class 2): 0.2504\n", + "----------\n", + "Epoch: 21, Phase: train, epoch loss: 0.0039, Dice Score (class 0): 0.3843, Dice Score (class 1): 0.2829,Dice Score (class 2): 0.2470\n", + "----------\n", + "Epoch: 21, Phase: validate, epoch loss: 0.0112, Dice Score (class 0): 0.3247, Dice Score (class 1): 0.0295,Dice Score (class 2): 0.2692\n", + "----------\n", + "Epoch: 22, Phase: train, epoch loss: 0.0036, Dice Score (class 0): 0.3750, Dice Score (class 1): 0.1967,Dice Score (class 2): 0.2209\n", + "----------\n", + "Epoch: 22, Phase: validate, epoch loss: 0.0075, Dice Score (class 0): 0.4240, Dice Score (class 1): 0.1111,Dice Score (class 2): 0.3162\n", + "----------\n", + "Epoch: 23, Phase: train, epoch loss: 0.0030, Dice Score (class 0): 0.3883, Dice Score (class 1): 0.2905,Dice Score (class 2): 0.2495\n", + "----------\n", + "Epoch: 23, Phase: validate, epoch loss: 0.0079, Dice Score (class 0): 0.4247, Dice Score (class 1): 0.1192,Dice Score (class 2): 0.2906\n", + "----------\n", + "Epoch: 24, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.4021, Dice Score (class 1): 0.3212,Dice Score (class 2): 0.2533\n", + "----------\n", + "Epoch: 24, Phase: validate, epoch loss: 0.0067, Dice Score (class 0): 0.4512, Dice Score (class 1): 0.2062,Dice Score (class 2): 0.3409\n", + "----------\n", + "Epoch: 25, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.4044, Dice Score (class 1): 0.3422,Dice Score (class 2): 0.2634\n", + "----------\n", + "Epoch: 25, Phase: validate, epoch loss: 0.0094, Dice Score (class 0): 0.4467, Dice Score (class 1): 0.3159,Dice Score (class 2): 0.3170\n", + "----------\n", + "Epoch: 26, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.4007, Dice Score (class 1): 0.3061,Dice Score (class 2): 0.2244\n", + "----------\n", + "Epoch: 26, Phase: validate, epoch loss: 0.0062, Dice Score (class 0): 0.4598, Dice Score (class 1): 0.0884,Dice Score (class 2): 0.1960\n", + "----------\n", + "Epoch: 27, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.3824, Dice Score (class 1): 0.2838,Dice Score (class 2): 0.2146\n", + "----------\n", + "Epoch: 27, Phase: validate, epoch loss: 0.0100, Dice Score (class 0): 0.4430, Dice Score (class 1): 0.1496,Dice Score (class 2): 0.2968\n", + "----------\n", + "Epoch: 28, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.3903, Dice Score (class 1): 0.3193,Dice Score (class 2): 0.2287\n", + "----------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 28, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.3152, Dice Score (class 1): 0.1829,Dice Score (class 2): 0.2807\n", + "----------\n", + "Epoch: 29, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.3628, Dice Score (class 1): 0.2731,Dice Score (class 2): 0.2147\n", + "----------\n", + "Epoch: 29, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.4576, Dice Score (class 1): 0.1791,Dice Score (class 2): 0.2617\n", + "----------\n", + "Epoch: 30, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.3856, Dice Score (class 1): 0.3332,Dice Score (class 2): 0.2106\n", + "----------\n", + "Epoch: 30, Phase: validate, epoch loss: 0.0088, Dice Score (class 0): 0.4605, Dice Score (class 1): 0.2360,Dice Score (class 2): 0.2981\n", + "----------\n", + "Epoch: 31, Phase: train, epoch loss: 0.0022, Dice Score (class 0): 0.4015, Dice Score (class 1): 0.3838,Dice Score (class 2): 0.2237\n", + "----------\n", + "Epoch: 31, Phase: validate, epoch loss: 0.0103, Dice Score (class 0): 0.4688, Dice Score (class 1): 0.4179,Dice Score (class 2): 0.3486\n", + "----------\n", + "Epoch: 32, Phase: train, epoch loss: 0.0022, Dice Score (class 0): 0.3904, Dice Score (class 1): 0.3807,Dice Score (class 2): 0.2404\n", + "----------\n", + "Epoch: 32, Phase: validate, epoch loss: 0.0085, Dice Score (class 0): 0.4480, Dice Score (class 1): 0.3601,Dice Score (class 2): 0.2842\n", + "----------\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-80-d560073d377c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdataloader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdata_sizes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'new_data_unet_exp_noisy_1'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0;36m50\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m dice_loss = dice_loss_3,noisy_labels = True)\n\u001b[0m", + "\u001b[0;32m<ipython-input-24-c237342f9bb3>\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, optimizer, dataloader, data_sizes, batch_size, name, num_epochs, verbose, dice_loss, noisy_labels)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdataloader\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegT\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/share/apps/pytorch/0.2.0_3/python3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# same-process loading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_iter\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpin_memory_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/share/apps/pytorch/0.2.0_3/python3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# same-process loading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_iter\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpin_memory_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m<ipython-input-3-676504d23560>\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0mfa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfliplr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfa\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0msegment_T\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_T\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0msegment_F\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_F\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0msegment_P\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_P\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/pyenv/py3.6.3/lib/python3.6/site-packages/skimage/transform/_warps.py\u001b[0m in \u001b[0;36mrotate\u001b[0;34m(image, angle, resize, center, order, mode, cval, clip, preserve_range)\u001b[0m\n\u001b[1;32m 298\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 299\u001b[0m return warp(image, tform, output_shape=output_shape, order=order,\n\u001b[0;32m--> 300\u001b[0;31m mode=mode, cval=cval, clip=clip, preserve_range=preserve_range)\n\u001b[0m\u001b[1;32m 301\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 302\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/pyenv/py3.6.3/lib/python3.6/site-packages/skimage/transform/_warps.py\u001b[0m in \u001b[0;36mwarp\u001b[0;34m(image, inverse_map, map_args, output_shape, order, mode, cval, clip, preserve_range)\u001b[0m\n\u001b[1;32m 767\u001b[0m warped = _warp_fast(image, matrix,\n\u001b[1;32m 768\u001b[0m \u001b[0moutput_shape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_shape\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 769\u001b[0;31m order=order, mode=mode, cval=cval)\n\u001b[0m\u001b[1;32m 770\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 771\u001b[0m \u001b[0mdims\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32mskimage/transform/_warps_cy.pyx\u001b[0m in \u001b[0;36mskimage.transform._warps_cy._warp_fast\u001b[0;34m()\u001b[0m\n", + "\u001b[0;32m/share/apps/python3/3.6.3/intel/lib/python3.6/site-packages/numpy-1.13.3-py3.6-linux-x86_64.egg/numpy/core/numeric.py\u001b[0m in \u001b[0;36masarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 463\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 464\u001b[0m \"\"\"Convert the input to an array.\n\u001b[1;32m 465\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "unet_exp_noisy, loss_hist_unet_exp, dc_hist_0_unet_exp, \\\n", + "dc_hist_1_unet_exp, dc_hist_2_unet_exp = train_model(unet_exp_noisy, optimizer_unet_exp_noisy,\n", + " dataloader,data_sizes,5,'new_data_unet_exp_noisy_1',\n", + " num_epochs= 50, verbose = True, \n", + " dice_loss = dice_loss_3,noisy_labels = True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# torch.save(model_gen_dilated_l4_n2_new_data_dp,'new_data_dilated_net_l4_n2_nd_dp_1')\n", + "pickle.dump(loss_hist_unet_exp, open('loss_hist_new_data_unet_exp_noisy_1','wb'))\n", + "pickle.dump(dc_hist_0_unet_exp, open('dc_hist_0_new_data_unet_exp_noisy_1','wb'))\n", + "pickle.dump(dc_hist_1_unet_exp, open('dc_hist_1_new_data_unet_exp_noisy_1','wb'))\n", + "pickle.dump(dc_hist_2_unet_exp, open('dc_hist_2_new_data_unet_exp_noisy_1','wb'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "scrolled": true + }, + "outputs": [], + "source": [ + "plot_hist(loss_hist_unet_exp,'Loss')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "scrolled": true + }, + "outputs": [], + "source": [ + "evaluate(unet_exp_noisy, dataloader, data_sizes, 5, 'validate', dice_loss=dice_loss_3, noisy_labels = True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}