{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 14. Experimenting with label smoothing and CEL to deal with noisy ground truth\n",
"We tried to train a network with label smoothing, which is generally done when the ground truth is noisy or involves a lot of subjectivity. The practice of label smoothing is tried for classification problem but never for a segmentation problem. We tried it for segmentation problem. It didn't seem to work well and hence training was stopped mid-way as the dice-scores were no where close to acceptable levels. As a future work, we would like to engineer a way to segment images based on noisy ground truth"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.autograd import Variable\n",
"import torch.optim as optim\n",
"import torchvision\n",
"from torchvision import datasets, models\n",
"from torchvision import transforms as T\n",
"from torch.utils.data import DataLoader, Dataset\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"import time\n",
"import pandas as pd\n",
"from skimage import io, transform\n",
"import matplotlib.image as mpimg\n",
"from PIL import Image\n",
"from sklearn.metrics import roc_auc_score\n",
"import torch.nn.functional as F\n",
"import scipy\n",
"import random\n",
"import pickle\n",
"import scipy.io as sio\n",
"import itertools\n",
"from scipy.ndimage.interpolation import shift\n",
"import copy\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"%matplotlib inline\n",
"plt.ion()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Dataloader Class and other utilities"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from dataloader_2d import *\n",
"from dataloader_3d import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Build Data loader objects"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"train_path = '/beegfs/ark576/new_knee_data/train'\n",
"val_path = '/beegfs/ark576/new_knee_data/val'\n",
"test_path = '/beegfs/ark576/new_knee_data/test'\n",
"\n",
"train_file_names = sorted(pickle.load(open(train_path + '/train_file_names.p','rb')))\n",
"val_file_names = sorted(pickle.load(open(val_path + '/val_file_names.p','rb')))\n",
"test_file_names = sorted(pickle.load(open(test_path + '/test_file_names.p','rb')))\n",
"\n",
"transformed_dataset = {'train': KneeMRIDataset(train_path,train_file_names, train_data= True, flipping=False, normalize= True),\n",
" 'validate': KneeMRIDataset(val_path,val_file_names, normalize= True),\n",
" 'test': KneeMRIDataset(test_path,test_file_names, normalize= True)\n",
" }\n",
"\n",
"dataloader = {x: DataLoader(transformed_dataset[x], batch_size=5,\n",
" shuffle=True, num_workers=0) for x in ['train', 'validate','test']}\n",
"data_sizes ={x: len(transformed_dataset[x]) for x in ['train', 'validate','test']}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"im, seg_F, seg_P, seg_T,_ = next(iter(dataloader['train']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Find Max and min values of Images (all 7 contrasts), of Fractional Anisotropy maps and of Mean Diffusivity maps for image normalization"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"min_fa = np.inf\n",
"min_md = np.inf\n",
"min_image = np.inf\n",
"max_fa = 0\n",
"max_md = 0\n",
"max_image = 0\n",
"for data in dataloader['train']:\n",
" if min_fa > torch.min(data[0][:,7,:,:]):\n",
" min_fa = torch.min(data[0][:,7,:,:])\n",
" if min_md > torch.min(data[0][:,8,:,:]):\n",
" min_md = torch.min(data[0][:,8:,:])\n",
" if min_image > torch.min(data[0][:,:7,:,:]):\n",
" min_image = torch.min(data[0][:,:7,:,:])\n",
" if max_fa < torch.max(data[0][:,7,:,:]):\n",
" max_fa = torch.max(data[0][:,7,:,:])\n",
" if max_md < torch.max(data[0][:,8,:,:]):\n",
" max_md = torch.max(data[0][:,8,:,:])\n",
" if max_image < torch.max(data[0][:,:7,:,:]):\n",
" max_image = torch.max(data[0][:,:7,:,:])\n",
"norm_values = (max_image, min_image, max_fa, min_fa, max_md, min_md)"
]
},
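{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal sketch (not part of the original pipeline) of how the `norm_values` tuple could be applied for min-max normalization. The channel layout (channels 0-6: image contrasts, 7: fractional anisotropy, 8: mean diffusivity) follows the indexing used in the loop above; the helper name `min_max_normalize` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def min_max_normalize(batch, norm_values):\n",
"    # Hypothetical helper: scale a (N, 9, H, W) batch to [0, 1] per channel group\n",
"    # using the training-set extrema computed above.\n",
"    max_image, min_image, max_fa, min_fa, max_md, min_md = norm_values\n",
"    out = batch.clone()\n",
"    out[:, :7] = (out[:, :7] - min_image) / (max_image - min_image)  # image contrasts\n",
"    out[:, 7] = (out[:, 7] - min_fa) / (max_fa - min_fa)             # FA map\n",
"    out[:, 8] = (out[:, 8] - min_md) / (max_md - min_md)             # MD map\n",
"    return out"
]
},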
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Models"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from unet_3d import *\n",
"from unet_basic_dilated import *\n",
"from vnet import *\n",
"from ensemble_model import *"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"seg_sum = torch.zeros(3)\n",
"for i, data in enumerate(dataloader['train']):\n",
" input, segF, segP, segT,_ = data\n",
" seg_sum[0] += torch.sum(segF)\n",
" seg_sum[1] += torch.sum(segP)\n",
" seg_sum[2] += torch.sum(segT)\n",
"mean_s_sum = seg_sum/i"
]
},
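{
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-class voxel totals above are presumably used for class balancing. As a hypothetical illustration (not from the original code), they could be converted into inverse-frequency weights for a weighted loss:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Hypothetical: inverse-frequency class weights from the mean per-batch voxel counts.\n",
"# Rarer classes get larger weights; normalized so the weights sum to 1.\n",
"class_weights = torch.ones(3) / (mean_s_sum + 1e-8)\n",
"class_weights = class_weights / class_weights.sum()\n",
"print(class_weights)"
]
},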
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import Loss functions and all other utility functions like functions for saving models, for visualizing images, etc."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from utils import *"
]
},
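{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, a minimal sketch of a soft Dice score for a single class, assuming a predicted probability map and a binary ground-truth mask of the same shape; the actual `dice_loss_3` imported from `utils` may be implemented differently."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def soft_dice_score(prob, target, eps=1e-6):\n",
"    # Sketch only (assumption, not the project's dice_loss_3): prob holds\n",
"    # predicted probabilities and target the binary ground truth, same shape.\n",
"    prob = prob.contiguous().view(-1)\n",
"    target = target.contiguous().view(-1)\n",
"    intersection = (prob * target).sum()\n",
"    return (2.0 * intersection + eps) / (prob.sum() + target.sum() + eps)"
]
},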
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import all the Training and evaluate functions to evaluate the models"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from train_2d import *\n",
"from train_3d import *\n",
"from train_ensemble import *\n",
"from evaluate_2d import *\n",
"from evaluate_3d import *\n",
"from evaluate_ensemble import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 14. Experimenting with label smoothing and CEL to deal with noisy ground truth\n",
"We tried to train a network with label smoothing, which is generally done when the ground truth is noisy or involves a lot of subjectivity. The practice of label smoothing is tried for classification problem but never for a segmentation problem. We tried it for segmentation problem. It didn't seem to work well and hence training was stopped mid-way as the dice-scores were no where close to acceptable levels. As a future work, we would like to engineer a way to segment images based on noisy ground truth"
]
},
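{
"cell_type": "markdown",
"metadata": {},
"source": [
"For illustration, a minimal sketch (written for a recent PyTorch version) of how label smoothing could be combined with a per-pixel cross-entropy loss for segmentation, assuming logits of shape (N, C, H, W) and integer class labels of shape (N, H, W). The helper name and the smoothing value are hypothetical; the actual loss used by `train_model` with `noisy_labels=True` may differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def smoothed_cross_entropy_2d(logits, labels, num_classes, smoothing=0.1):\n",
"    # Hypothetical sketch of label-smoothed cross-entropy for segmentation:\n",
"    # each pixel's one-hot target is mixed with a uniform distribution over classes.\n",
"    log_probs = F.log_softmax(logits, dim=1)  # (N, C, H, W)\n",
"    log_probs = log_probs.permute(0, 2, 3, 1).contiguous().view(-1, num_classes)\n",
"    labels = labels.view(-1, 1)  # integer class indices, dtype long\n",
"    one_hot = torch.zeros_like(log_probs).scatter_(1, labels, 1.0)\n",
"    smoothed = one_hot * (1.0 - smoothing) + smoothing / num_classes\n",
"    return -(smoothed * log_probs).sum(dim=1).mean()"
]
},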
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"unet_exp_noisy = Unet_dilated_small(9,4,int_var=40,dilated=False).cuda()"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"410524"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"count_parameters(unet_exp_noisy)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"optimizer_unet_exp_noisy = optim.Adam(unet_exp_noisy.parameters(),lr = 1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0, Phase: train, epoch loss: 0.0069, Dice Score (class 0): 0.4548, Dice Score (class 1): 0.3627,Dice Score (class 2): 0.2732\n",
"----------\n",
"Epoch: 0, Phase: validate, epoch loss: 0.0107, Dice Score (class 0): 0.4652, Dice Score (class 1): 0.2242,Dice Score (class 2): 0.2691\n",
"----------\n",
"Epoch: 1, Phase: train, epoch loss: 0.0066, Dice Score (class 0): 0.4464, Dice Score (class 1): 0.3631,Dice Score (class 2): 0.2697\n",
"----------\n",
"Epoch: 1, Phase: validate, epoch loss: 0.0116, Dice Score (class 0): 0.4738, Dice Score (class 1): 0.1400,Dice Score (class 2): 0.2687\n",
"----------\n",
"Epoch: 2, Phase: train, epoch loss: 0.0063, Dice Score (class 0): 0.4417, Dice Score (class 1): 0.3004,Dice Score (class 2): 0.2691\n",
"----------\n",
"Epoch: 2, Phase: validate, epoch loss: 0.0083, Dice Score (class 0): 0.4532, Dice Score (class 1): 0.1267,Dice Score (class 2): 0.2386\n",
"----------\n",
"Epoch: 3, Phase: train, epoch loss: 0.0058, Dice Score (class 0): 0.4365, Dice Score (class 1): 0.3346,Dice Score (class 2): 0.2603\n",
"----------\n",
"Epoch: 3, Phase: validate, epoch loss: 0.0090, Dice Score (class 0): 0.4624, Dice Score (class 1): 0.1272,Dice Score (class 2): 0.2762\n",
"----------\n",
"Epoch: 4, Phase: train, epoch loss: 0.0058, Dice Score (class 0): 0.4264, Dice Score (class 1): 0.2908,Dice Score (class 2): 0.2789\n",
"----------\n",
"Epoch: 4, Phase: validate, epoch loss: 0.0106, Dice Score (class 0): 0.4784, Dice Score (class 1): 0.0917,Dice Score (class 2): 0.2196\n",
"----------\n",
"Epoch: 5, Phase: train, epoch loss: 0.0053, Dice Score (class 0): 0.4341, Dice Score (class 1): 0.3044,Dice Score (class 2): 0.2540\n",
"----------\n",
"Epoch: 5, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4828, Dice Score (class 1): 0.1812,Dice Score (class 2): 0.2497\n",
"----------\n",
"Epoch: 6, Phase: train, epoch loss: 0.0052, Dice Score (class 0): 0.4285, Dice Score (class 1): 0.2792,Dice Score (class 2): 0.2567\n",
"----------\n",
"Epoch: 6, Phase: validate, epoch loss: 0.0081, Dice Score (class 0): 0.4304, Dice Score (class 1): 0.0803,Dice Score (class 2): 0.2250\n",
"----------\n",
"Epoch: 7, Phase: train, epoch loss: 0.0049, Dice Score (class 0): 0.4271, Dice Score (class 1): 0.3015,Dice Score (class 2): 0.2550\n",
"----------\n",
"Epoch: 7, Phase: validate, epoch loss: 0.0094, Dice Score (class 0): 0.4339, Dice Score (class 1): 0.1199,Dice Score (class 2): 0.2402\n",
"----------\n",
"Epoch: 8, Phase: train, epoch loss: 0.0046, Dice Score (class 0): 0.4336, Dice Score (class 1): 0.3101,Dice Score (class 2): 0.2492\n",
"----------\n",
"Epoch: 8, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4146, Dice Score (class 1): 0.0966,Dice Score (class 2): 0.2273\n",
"----------\n",
"Epoch: 9, Phase: train, epoch loss: 0.0047, Dice Score (class 0): 0.4042, Dice Score (class 1): 0.3042,Dice Score (class 2): 0.2402\n",
"----------\n",
"Epoch: 9, Phase: validate, epoch loss: 0.0085, Dice Score (class 0): 0.4546, Dice Score (class 1): 0.0937,Dice Score (class 2): 0.2365\n",
"----------\n",
"Epoch: 10, Phase: train, epoch loss: 0.0048, Dice Score (class 0): 0.4017, Dice Score (class 1): 0.2660,Dice Score (class 2): 0.2579\n",
"----------\n",
"Epoch: 10, Phase: validate, epoch loss: 0.0093, Dice Score (class 0): 0.3297, Dice Score (class 1): 0.1008,Dice Score (class 2): 0.3376\n",
"----------\n",
"Epoch: 11, Phase: train, epoch loss: 0.0052, Dice Score (class 0): 0.3732, Dice Score (class 1): 0.2240,Dice Score (class 2): 0.2164\n",
"----------\n",
"Epoch: 11, Phase: validate, epoch loss: 0.0080, Dice Score (class 0): 0.3507, Dice Score (class 1): 0.1192,Dice Score (class 2): 0.1992\n",
"----------\n",
"Epoch: 12, Phase: train, epoch loss: 0.0044, Dice Score (class 0): 0.4045, Dice Score (class 1): 0.2663,Dice Score (class 2): 0.2398\n",
"----------\n",
"Epoch: 12, Phase: validate, epoch loss: 0.0054, Dice Score (class 0): 0.3737, Dice Score (class 1): 0.1181,Dice Score (class 2): 0.2016\n",
"----------\n",
"Epoch: 13, Phase: train, epoch loss: 0.0039, Dice Score (class 0): 0.4121, Dice Score (class 1): 0.2953,Dice Score (class 2): 0.2487\n",
"----------\n",
"Epoch: 13, Phase: validate, epoch loss: 0.0062, Dice Score (class 0): 0.4482, Dice Score (class 1): 0.1157,Dice Score (class 2): 0.2612\n",
"----------\n",
"Epoch: 14, Phase: train, epoch loss: 0.0037, Dice Score (class 0): 0.4235, Dice Score (class 1): 0.3123,Dice Score (class 2): 0.2620\n",
"----------\n",
"Epoch: 14, Phase: validate, epoch loss: 0.0067, Dice Score (class 0): 0.4270, Dice Score (class 1): 0.0954,Dice Score (class 2): 0.2708\n",
"----------\n",
"Epoch: 15, Phase: train, epoch loss: 0.0036, Dice Score (class 0): 0.4144, Dice Score (class 1): 0.2824,Dice Score (class 2): 0.2676\n",
"----------\n",
"Epoch: 15, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4503, Dice Score (class 1): 0.1200,Dice Score (class 2): 0.2609\n",
"----------\n",
"Epoch: 16, Phase: train, epoch loss: 0.0034, Dice Score (class 0): 0.4183, Dice Score (class 1): 0.2898,Dice Score (class 2): 0.2592\n",
"----------\n",
"Epoch: 16, Phase: validate, epoch loss: 0.0061, Dice Score (class 0): 0.4537, Dice Score (class 1): 0.1223,Dice Score (class 2): 0.3014\n",
"----------\n",
"Epoch: 17, Phase: train, epoch loss: 0.0035, Dice Score (class 0): 0.4004, Dice Score (class 1): 0.2839,Dice Score (class 2): 0.2320\n",
"----------\n",
"Epoch: 17, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.4409, Dice Score (class 1): 0.1723,Dice Score (class 2): 0.3229\n",
"----------\n",
"Epoch: 18, Phase: train, epoch loss: 0.0034, Dice Score (class 0): 0.4095, Dice Score (class 1): 0.2638,Dice Score (class 2): 0.2377\n",
"----------\n",
"Epoch: 18, Phase: validate, epoch loss: 0.0124, Dice Score (class 0): 0.4778, Dice Score (class 1): 0.0729,Dice Score (class 2): 0.2672\n",
"----------\n",
"Epoch: 19, Phase: train, epoch loss: 0.0032, Dice Score (class 0): 0.4091, Dice Score (class 1): 0.3173,Dice Score (class 2): 0.2508\n",
"----------\n",
"Epoch: 19, Phase: validate, epoch loss: 0.0087, Dice Score (class 0): 0.4247, Dice Score (class 1): 0.2357,Dice Score (class 2): 0.3106\n",
"----------\n",
"Epoch: 20, Phase: train, epoch loss: 0.0030, Dice Score (class 0): 0.4090, Dice Score (class 1): 0.3206,Dice Score (class 2): 0.2621\n",
"----------\n",
"Epoch: 20, Phase: validate, epoch loss: 0.0053, Dice Score (class 0): 0.4091, Dice Score (class 1): 0.1603,Dice Score (class 2): 0.2504\n",
"----------\n",
"Epoch: 21, Phase: train, epoch loss: 0.0039, Dice Score (class 0): 0.3843, Dice Score (class 1): 0.2829,Dice Score (class 2): 0.2470\n",
"----------\n",
"Epoch: 21, Phase: validate, epoch loss: 0.0112, Dice Score (class 0): 0.3247, Dice Score (class 1): 0.0295,Dice Score (class 2): 0.2692\n",
"----------\n",
"Epoch: 22, Phase: train, epoch loss: 0.0036, Dice Score (class 0): 0.3750, Dice Score (class 1): 0.1967,Dice Score (class 2): 0.2209\n",
"----------\n",
"Epoch: 22, Phase: validate, epoch loss: 0.0075, Dice Score (class 0): 0.4240, Dice Score (class 1): 0.1111,Dice Score (class 2): 0.3162\n",
"----------\n",
"Epoch: 23, Phase: train, epoch loss: 0.0030, Dice Score (class 0): 0.3883, Dice Score (class 1): 0.2905,Dice Score (class 2): 0.2495\n",
"----------\n",
"Epoch: 23, Phase: validate, epoch loss: 0.0079, Dice Score (class 0): 0.4247, Dice Score (class 1): 0.1192,Dice Score (class 2): 0.2906\n",
"----------\n",
"Epoch: 24, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.4021, Dice Score (class 1): 0.3212,Dice Score (class 2): 0.2533\n",
"----------\n",
"Epoch: 24, Phase: validate, epoch loss: 0.0067, Dice Score (class 0): 0.4512, Dice Score (class 1): 0.2062,Dice Score (class 2): 0.3409\n",
"----------\n",
"Epoch: 25, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.4044, Dice Score (class 1): 0.3422,Dice Score (class 2): 0.2634\n",
"----------\n",
"Epoch: 25, Phase: validate, epoch loss: 0.0094, Dice Score (class 0): 0.4467, Dice Score (class 1): 0.3159,Dice Score (class 2): 0.3170\n",
"----------\n",
"Epoch: 26, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.4007, Dice Score (class 1): 0.3061,Dice Score (class 2): 0.2244\n",
"----------\n",
"Epoch: 26, Phase: validate, epoch loss: 0.0062, Dice Score (class 0): 0.4598, Dice Score (class 1): 0.0884,Dice Score (class 2): 0.1960\n",
"----------\n",
"Epoch: 27, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.3824, Dice Score (class 1): 0.2838,Dice Score (class 2): 0.2146\n",
"----------\n",
"Epoch: 27, Phase: validate, epoch loss: 0.0100, Dice Score (class 0): 0.4430, Dice Score (class 1): 0.1496,Dice Score (class 2): 0.2968\n",
"----------\n",
"Epoch: 28, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.3903, Dice Score (class 1): 0.3193,Dice Score (class 2): 0.2287\n",
"----------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 28, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.3152, Dice Score (class 1): 0.1829,Dice Score (class 2): 0.2807\n",
"----------\n",
"Epoch: 29, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.3628, Dice Score (class 1): 0.2731,Dice Score (class 2): 0.2147\n",
"----------\n",
"Epoch: 29, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.4576, Dice Score (class 1): 0.1791,Dice Score (class 2): 0.2617\n",
"----------\n",
"Epoch: 30, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.3856, Dice Score (class 1): 0.3332,Dice Score (class 2): 0.2106\n",
"----------\n",
"Epoch: 30, Phase: validate, epoch loss: 0.0088, Dice Score (class 0): 0.4605, Dice Score (class 1): 0.2360,Dice Score (class 2): 0.2981\n",
"----------\n",
"Epoch: 31, Phase: train, epoch loss: 0.0022, Dice Score (class 0): 0.4015, Dice Score (class 1): 0.3838,Dice Score (class 2): 0.2237\n",
"----------\n",
"Epoch: 31, Phase: validate, epoch loss: 0.0103, Dice Score (class 0): 0.4688, Dice Score (class 1): 0.4179,Dice Score (class 2): 0.3486\n",
"----------\n",
"Epoch: 32, Phase: train, epoch loss: 0.0022, Dice Score (class 0): 0.3904, Dice Score (class 1): 0.3807,Dice Score (class 2): 0.2404\n",
"----------\n",
"Epoch: 32, Phase: validate, epoch loss: 0.0085, Dice Score (class 0): 0.4480, Dice Score (class 1): 0.3601,Dice Score (class 2): 0.2842\n",
"----------\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-80-d560073d377c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdataloader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdata_sizes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'new_data_unet_exp_noisy_1'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0;36m50\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m dice_loss = dice_loss_3,noisy_labels = True)\n\u001b[0m",
"\u001b[0;32m<ipython-input-24-c237342f9bb3>\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, optimizer, dataloader, data_sizes, batch_size, name, num_epochs, verbose, dice_loss, noisy_labels)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdataloader\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegT\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/share/apps/pytorch/0.2.0_3/python3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# same-process loading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_iter\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpin_memory_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/share/apps/pytorch/0.2.0_3/python3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# same-process loading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_iter\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpin_memory_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-3-676504d23560>\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0mfa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfliplr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfa\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0msegment_T\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_T\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0msegment_F\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_F\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0msegment_P\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_P\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/pyenv/py3.6.3/lib/python3.6/site-packages/skimage/transform/_warps.py\u001b[0m in \u001b[0;36mrotate\u001b[0;34m(image, angle, resize, center, order, mode, cval, clip, preserve_range)\u001b[0m\n\u001b[1;32m 298\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 299\u001b[0m return warp(image, tform, output_shape=output_shape, order=order,\n\u001b[0;32m--> 300\u001b[0;31m mode=mode, cval=cval, clip=clip, preserve_range=preserve_range)\n\u001b[0m\u001b[1;32m 301\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 302\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/pyenv/py3.6.3/lib/python3.6/site-packages/skimage/transform/_warps.py\u001b[0m in \u001b[0;36mwarp\u001b[0;34m(image, inverse_map, map_args, output_shape, order, mode, cval, clip, preserve_range)\u001b[0m\n\u001b[1;32m 767\u001b[0m warped = _warp_fast(image, matrix,\n\u001b[1;32m 768\u001b[0m \u001b[0moutput_shape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_shape\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 769\u001b[0;31m order=order, mode=mode, cval=cval)\n\u001b[0m\u001b[1;32m 770\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 771\u001b[0m \u001b[0mdims\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32mskimage/transform/_warps_cy.pyx\u001b[0m in \u001b[0;36mskimage.transform._warps_cy._warp_fast\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m/share/apps/python3/3.6.3/intel/lib/python3.6/site-packages/numpy-1.13.3-py3.6-linux-x86_64.egg/numpy/core/numeric.py\u001b[0m in \u001b[0;36masarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 463\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 464\u001b[0m \"\"\"Convert the input to an array.\n\u001b[1;32m 465\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"unet_exp_noisy, loss_hist_unet_exp, dc_hist_0_unet_exp, \\\n",
"dc_hist_1_unet_exp, dc_hist_2_unet_exp = train_model(unet_exp_noisy, optimizer_unet_exp_noisy,\n",
" dataloader,data_sizes,5,'new_data_unet_exp_noisy_1',\n",
" num_epochs= 50, verbose = True, \n",
" dice_loss = dice_loss_3,noisy_labels = True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# torch.save(model_gen_dilated_l4_n2_new_data_dp,'new_data_dilated_net_l4_n2_nd_dp_1')\n",
"pickle.dump(loss_hist_unet_exp, open('loss_hist_new_data_unet_exp_noisy_1','wb'))\n",
"pickle.dump(dc_hist_0_unet_exp, open('dc_hist_0_new_data_unet_exp_noisy_1','wb'))\n",
"pickle.dump(dc_hist_1_unet_exp, open('dc_hist_1_new_data_unet_exp_noisy_1','wb'))\n",
"pickle.dump(dc_hist_2_unet_exp, open('dc_hist_2_new_data_unet_exp_noisy_1','wb'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"plot_hist(loss_hist_unet_exp,'Loss')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"evaluate(unet_exp_noisy, dataloader, data_sizes, 5, 'validate', dice_loss=dice_loss_3, noisy_labels = True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}