--- /dev/null
+++ b/train.py
@@ -0,0 +1,420 @@
+import sys
+sys.path.append('architectures/deeplab_3D/')
+sys.path.append('architectures/unet_3D/')
+sys.path.append('architectures/hrnet_3D/')
+sys.path.append('architectures/experiment_nets_3D/')
+sys.path.append('utils/')
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import torch.backends.cudnn as cudnn
+import torch.nn.functional as F
+import torch.optim as optim
+
+import numpy as np
+import scipy.misc
+import os
+from tqdm import *
+import random
+from random import randint
+from docopt import docopt
+
+import deeplab_resnet_3D
+import unet_3D
+import highresnet_3D
+import exp_net_3D
+
+import lossF
+import PP
+import augmentations as AUG
+
+import nibabel as nib
+import evalF as EF
+import evalFP as EFP
+
+
+docstr = """Train 3D segmentation networks (DeepLab 3D, U-Net 3D, HighResNet 3D, or experiment nets) on the MS2017b data.
+
+Usage:
+    train.py [options]
+
+Options:
+    -h, --help               Print this message
+    --archId=<int>           Architecture to run: 0 is DeepLab 3D, 1 is U-Net 3D, 2 is HRNet [default: 2]
+    --trainMethod=<int>      0 is full image, 1 is by patches (random), 2 is by patches (center pixel) [default: 1]
+    --lossFunction=<str>     Loss function name. 'dice' option available [default: dice]
+    --imgSize=<str>          Image size [default: 200x200x100]
+    --mainFolderPath=<str>   Main folder path [default: ../Data/MS2017b/]
+    --patchSize=<int>        Size of the patch [default: 60]
+    --patchSizeStage0=<int>  Size of the patch at stage 0 [default: 41]
+    --namePostfix=<str>      Postfix of FLAIR, e.g. to use FLAIR_s the postfix is _s. This also determines the train list file [default: _200x200x100orig]
+    --modelPath=<str>        Path of a model to continue training on [default: none]
+    --NoLabels=<int>         The number of different labels in the training data, including background [default: 2]
+    --maxIter=<int>          Maximum number of iterations [default: 20000]
+    --maxIterStage0=<int>    Maximum number of iterations for stage 0 training [default: -1]
+    -i, --iterSize=<int>     Number of iterations to accumulate gradients over [default: 1]
+    --lr=<float>             Learning rate [default: 0.0001]
+    --gpu0=<int>             GPU number [default: 0]
+    --useGPU=<int>           Use GPU or not [default: 0]
+    --experiment=<str>       Run a specified experiment net instead, e.g. 1x1x1x1x1x1_1_0 means dilation 1 in all 6 blocks, with priv, no ASPP [default: None]
+"""
+args = docopt(docstr, version='v0.1')
+print(args)
+
+arch_id = int(args['--archId'])
+train_method = int(args['--trainMethod'])
+loss_name = args['--lossFunction']
+img_dims = np.array(args['--imgSize'].split('x'), dtype=np.int64)
+main_folder_path = args['--mainFolderPath']
+patch_size = int(args['--patchSize'])
+
+postfix = args['--namePostfix']
+model_path = args['--modelPath']
+num_labels = int(args['--NoLabels'])
+max_iter = int(args['--maxIter'])
+
+iter_size = int(args['--iterSize'])
+base_lr = float(args['--lr'])
+experiment = str(args['--experiment'])
+gpu0 = int(args['--gpu0'])
+useGPU = int(args['--useGPU'])
+batch_size = 1
+#img_dims = [197, 233, 189]
+list_path = main_folder_path + 'train' + postfix + '.txt'
+print('READING from ', list_path)
+img_type_path = 'pre/FLAIR' + postfix + '.nii.gz'
+gt_type_path = 'wmh' + postfix + '.nii.gz'
+
+#stage 0 trains on smaller patches for the first max_iter_stage0 iterations; the default of -1 skips it
+patch_size_stage0 = int(args['--patchSizeStage0'])
+max_iter_stage0 = int(args['--maxIterStage0'])
+
+iter_low = 1
+iter_high = max_iter + 1
+
+#when resuming, recover the iteration count from the snapshot file name
+if model_path != 'none':
+    iter_low = int(model_path.split('iter_')[-1].replace('.pth', '')) + 1
+    if iter_low >= iter_high:
+        print('Model already at ' + str(iter_low) + ' iterations. Change max iter size')
+        sys.exit()
+
+num_labels2 = 209 #number of labels for the secondary (privileged) output head
+
+if num_labels == 2:
+    onlyLesions = True
+else:
+    onlyLesions = False
+
+if useGPU:
+    cudnn.enabled = True
+else:
+    cudnn.enabled = False
+
+if experiment != 'None':
+    snapshot_prefix = 'EXP3D' + '_' + experiment + '_' + loss_name + '_' + str(train_method)
+else:
+    if arch_id == 0:
+        snapshot_prefix = 'DL3D_' + loss_name + '_' + str(train_method) + '_' + PP.getTime()
+    elif arch_id == 1:
+        snapshot_prefix = 'UNET3D_' + loss_name + '_' + str(train_method) + '_' + PP.getTime()
+    elif arch_id == 2:
+        snapshot_prefix = 'HR3D_' + loss_name + '_' + str(train_method) + '_' + PP.getTime()
+
+to_center_pixel = False
+center_pixel_folder_path, locs_lesion, locs_other = (None, None, None)
+if train_method == 2:
+    to_center_pixel = True
+    if not os.path.exists(os.path.join(main_folder_path, 'centerPixelPatches' + postfix + '_' + str(patch_size))):
+        print('Pixel patch folder does not exist')
+        sys.exit()
+
+#load the list of training images
+img_list = PP.read_file(list_path)
+
+results_folder = 'train_results/'
+log_file_path = os.path.join(results_folder, 'logs', snapshot_prefix + '_log.txt')
+model_file_path = os.path.join(results_folder, 'models', snapshot_prefix + '_best.pth')
+
+logfile = open(log_file_path, 'w+')
+info_run = "arch ID: {:d} | max iters: {:10d} | max iters stage 0 : {:10d} | train method : {} | lr : {}".format(arch_id, max_iter, max_iter_stage0, train_method, base_lr)
+logfile.write(info_run + '\n')
+logfile.flush()
+
+#polynomial LR decay helper, unused in this patch (usage sketch after the patch)
+def lr_poly(base_lr, iter, max_iter, power):
+    return base_lr * ((1 - float(iter) / max_iter) ** power)
+
+def modelInit():
+    isPriv = False
+    if arch_id > 10:
+        isPriv = True
+
+    if experiment != 'None':
+        dilation_arr, isPriv, withASPP = PP.getExperimentInfo(experiment)
+        model = exp_net_3D.getExpNet(num_labels, dilation_arr, isPriv, NoLabels2 = num_labels2, withASPP = withASPP)
+    elif arch_id == 0:
+        model = deeplab_resnet_3D.Res_Deeplab(num_labels)
+    elif arch_id == 1:
+        model = unet_3D.UNet3D(1, num_labels)
+    elif arch_id == 2:
+        model = highresnet_3D.getHRNet(num_labels)
+
+    if model_path != 'none':
+        if useGPU:
+            #loading on GPU when the model was saved on GPU
+            saved_state_dict = torch.load(model_path)
+        else:
+            #loading on CPU when the model was saved on GPU
+            saved_state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
+        model.load_state_dict(saved_state_dict)
+
+    model.float()
+    model.eval() #keep BatchNorm in eval mode (use_global_stats = True)
+    return model, isPriv
+
+def trainModel(model):
+    if useGPU:
+        model.cuda(gpu0)
+    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = base_lr)
+
+    optimizer.zero_grad()
+    print(model)
+    curr_val = 0
+    best_val = 0
+    val_change = False
+    loss_arr = np.zeros([iter_size])
+    loss_arr_i = 0
+    batch_size = 1 #patch training is memory intensive, so train one patch at a time
+    stage = 0
+    print('---------------')
+    print('STAGE ' + str(stage))
+    print('---------------')
+
+    for iter in range(iter_low, iter_high):
+        if iter > max_iter_stage0 and stage != 1:
+            print('---------------')
+            print('STAGE 1')
+            print('---------------')
+            stage = 1
+
+        if train_method == 0:
+            img_b, label_b, _ = PP.extractImgBatch(batch_size, img_list, img_dims, onlyLesions,
+                main_folder_path = '../Data/MS2017b/')
+        elif train_method == 1 or train_method == 2:
+            if stage == 0:
+                img_b, label_b, _ = PP.extractPatchBatch(batch_size, patch_size_stage0, img_list, onlyLesions, center_pixel = to_center_pixel, main_folder_path = '../Data/MS2017b/', postfix=postfix)
+            else:
+                img_b, label_b, _ = PP.extractPatchBatch(batch_size, patch_size, img_list, onlyLesions, center_pixel = to_center_pixel, main_folder_path = '../Data/MS2017b/', postfix=postfix)
+        else:
+            print('Invalid training method format')
+            sys.exit()
+
+        if stage == 0:
+            img_b, label_b = AUG.augmentPatchLossLess([img_b, label_b])
+            img_b, label_b = AUG.augmentPatchLossy([img_b, label_b])
+        #img_b, label_b = AUG.augmentPatchLossless(img_b, label_b)
+        #img_b is of shape (batch_num) x 1 x dim1 x dim2 x dim3
+        #label_b is of shape (batch_num) x 1 x dim1 x dim2 x dim3
+        #batch_num should be 1 since larger batches are too memory intensive
+
+        label_b = label_b.astype(np.int64)
+        #convert label from (batch_num x 1 x dim1 x dim2 x dim3)
+        #to ((batch_num*dim1*dim2*dim3) x num_labels) one hot (a worked NumPy example follows the patch)
+        temp = label_b.reshape([-1])
+        label_b = np.zeros([temp.size, num_labels])
+        label_b[np.arange(temp.size), temp] = 1
+        label_b = torch.from_numpy(label_b).float()
+
+        imgs = torch.from_numpy(img_b).float()
+
+        if useGPU:
+            imgs, label_b = Variable(imgs).cuda(gpu0), Variable(label_b).cuda(gpu0)
+        else:
+            imgs, label_b = Variable(imgs), Variable(label_b)
+
+        #---------------------------------------------
+        #out size is (1, num_labels, dim1, dim2, dim3)
+        #---------------------------------------------
+        out = model(imgs)
+        out = out.permute(0,2,3,4,1).contiguous()
+        out = out.view(-1, num_labels)
+        #---------------------------------------------
+        #out size is (1 * dim1 * dim2 * dim3, num_labels)
+        #---------------------------------------------
+
+        #loss function (a generic soft dice sketch follows the patch)
+        m = nn.Softmax(dim=1)
+        loss = lossF.simple_dice_loss3D(m(out), label_b)
+
+        #pre-scale so gradients accumulated over iter_size iterations average out (sketch after the patch)
+        loss /= iter_size
+        loss.backward()
+
+        loss_val = loss.data.cpu().numpy()
+        loss_arr[loss_arr_i] = loss_val
+        loss_arr_i = (loss_arr_i + 1) % iter_size
+
+        if iter % 1 == 0: #print a status line every iteration
+            if val_change:
+                print("iter = {:6d}/{:6d} Loss: {:1.6f} Val Score: {:1.6f}".format(iter-1, max_iter, float(loss_val)*iter_size, curr_val))
+                val_change = False
+            print("iter = {:6d}/{:6d} Loss: {:1.6f} Val Score: {:1.6f}".format(iter, max_iter, float(loss_val)*iter_size, curr_val), end='\r')
+            sys.stdout.flush()
+        if iter % 1000 == 0:
+            val_change = True
+            curr_val = EF.evalModelX(model, num_labels, postfix, main_folder_path, (train_method != 0), gpu0, useGPU, eval_metric = 'iou', patch_size = patch_size, extra_patch = 5)
+            if curr_val > best_val:
+                best_val = curr_val
+                print('\nSaving better model...')
+                torch.save(model.state_dict(), model_file_path)
+            logfile.write("iter = {:6d}/{:6d} Loss: {:1.6f} Val Score: {:1.6f} \n".format(iter, max_iter, np.sum(loss_arr), curr_val))
+            logfile.flush()
+        if iter % iter_size == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        del out, loss
+
+def setupGIFVar(gif_b):
+    gif_b = gif_b.astype(np.int64)
+    gif_b = gif_b.reshape([-1])
+    gif_b = torch.from_numpy(gif_b).long()
+
+    if useGPU:
+        gif_b = Variable(gif_b).cuda(gpu0)
+    else:
+        gif_b = Variable(gif_b)
+    return gif_b
+
+def trainModelPriv(model):
+    if useGPU:
+        model.cuda(gpu0)
+    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = base_lr)
+    optimizer.zero_grad()
+    print(model)
+    curr_val1 = 0
+    curr_val2 = 0
+    best_val2 = 0
+    val_change = False
+    loss_arr1 = np.zeros([iter_size])
+    loss_arr2 = np.zeros([iter_size])
+    loss_arr_i = 0
+    batch_size = 1 #overridden per stage below
+
+    stage = 0
+    print('---------------')
+    print('STAGE ' + str(stage))
+    print('---------------')
+
+    for iter in range(iter_low, iter_high):
+        if iter > max_iter_stage0 and stage != 1:
+            print('---------------')
+            print('STAGE 1')
+            print('---------------')
+            stage = 1
+
+        if train_method == 0:
+            img_b, label_b, gif_b = PP.extractImgBatch(batch_size, img_list, img_dims, onlyLesions,
+                main_folder_path = '../Data/MS2017b/', with_priv = True)
+        elif train_method == 1 or train_method == 2:
+            if stage == 0:
+                batch_size = 5
+                img_b, label_b, gif_b = PP.extractPatchBatch(batch_size, patch_size_stage0, img_list, onlyLesions,
+                    center_pixel = to_center_pixel,
+                    main_folder_path = '../Data/MS2017b/',
+                    postfix=postfix, with_priv = True)
+            else:
+                batch_size = 1
+                img_b, label_b, gif_b = PP.extractPatchBatch(batch_size, patch_size, img_list, onlyLesions,
+                    center_pixel = to_center_pixel,
+                    main_folder_path = '../Data/MS2017b/',
+                    postfix=postfix, with_priv = True)
+        else:
+            print('Invalid training method format')
+            sys.exit()
+
+        img_b, label_b, gif_b = AUG.augmentPatchLossy([img_b, label_b, gif_b])
+
+        #img_b is of shape (batch_num) x 1 x dim1 x dim2 x dim3
+        #label_b is of shape (batch_num) x 1 x dim1 x dim2 x dim3
+
+        label_b = label_b.astype(np.int64)
+
+        #convert label from (batch_num x 1 x dim1 x dim2 x dim3)
+        #to ((batch_num*dim1*dim2*dim3) x num_labels) one hot
+        temp = label_b.reshape([-1])
+        label_b = np.zeros([temp.size, num_labels])
+        label_b[np.arange(temp.size), temp] = 1
+        label_b = torch.from_numpy(label_b).float()
+
+        imgs = torch.from_numpy(img_b).float()
+
+        if useGPU:
+            imgs, label_b = Variable(imgs).cuda(gpu0), Variable(label_b).cuda(gpu0)
+        else:
+            imgs, label_b = Variable(imgs), Variable(label_b)
+
+        gif_b = setupGIFVar(gif_b)
+
+        #---------------------------------------------
+        #out1 is the secondary (privileged) prediction, out2 is the main segmentation
+        #---------------------------------------------
+        out1, out2 = model(imgs)
+
+        out1 = out1.permute(0,2,3,4,1).contiguous()
+        out1 = out1.view(-1, num_labels2)
+
+        out2 = out2.permute(0,2,3,4,1).contiguous()
+        out2 = out2.view(-1, num_labels)
+        #---------------------------------------------
+        #out size is (1 * dim1 * dim2 * dim3, num_labels)
+        #---------------------------------------------
+        m2 = nn.Softmax(dim=1)
+        loss2 = lossF.simple_dice_loss3D(m2(out2), label_b)
+        m1 = nn.LogSoftmax(dim=1)
+        loss1 = F.nll_loss(m1(out1), gif_b)
+
+        loss1 /= iter_size
+        loss2 /= iter_size
+
+        #backpropagate both losses in one pass (see the note after the patch)
+        torch.autograd.backward([loss1, loss2])
+
+        loss_val1 = float(loss1.data.cpu().numpy())
+        loss_arr1[loss_arr_i] = loss_val1
+
+        loss_val2 = float(loss2.data.cpu().numpy())
+        loss_arr2[loss_arr_i] = loss_val2
+
+        loss_arr_i = (loss_arr_i + 1) % iter_size
+
+        if iter % 1 == 0: #print a status line every iteration
+            if val_change:
+                print("iter = {:6d}/{:6d} Loss_main: {:1.6f} Loss_secondary: {:1.6f} Val Score main: {:1.6f} Val Score secondary: {:1.6f}".format(iter-1, max_iter, loss_val2*iter_size, loss_val1*iter_size, curr_val2, curr_val1))
+                val_change = False
+            print("iter = {:6d}/{:6d} Loss_main: {:1.6f} Loss_secondary: {:1.6f} Val Score main: {:1.6f} Val Score secondary: {:1.6f}".format(iter, max_iter, loss_val2*iter_size, loss_val1*iter_size, curr_val2, curr_val1), end='\r')
+            sys.stdout.flush()
+        if iter % 2000 == 0:
+            val_change = True
+            curr_val1, curr_val2 = EFP.evalModelX(model, num_labels, num_labels2, postfix, main_folder_path, (train_method != 0), gpu0, useGPU, eval_metric = 'iou', patch_size = patch_size, extra_patch = 5, priv_eval = True)
+            if curr_val2 > best_val2:
+                best_val2 = curr_val2
+                torch.save(model.state_dict(), model_file_path)
+                print('\nSaving better model...')
+            logfile.write("iter = {:6d}/{:6d} Loss_main: {:1.6f} Loss_secondary: {:1.6f} Val Score main: {:1.6f} Val Score secondary: {:1.6f} \n".format(iter, max_iter, np.sum(loss_arr2), np.sum(loss_arr1), curr_val2, curr_val1))
+            logfile.flush()
+        if iter % iter_size == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        del out1, out2, loss1, loss2
+
+if __name__ == "__main__":
+    model, with_priv = modelInit()
+    if with_priv:
+        trainModelPriv(model)
+    else:
+        trainModel(model)
+    logfile.close()
\ No newline at end of file
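
Note on the one-hot conversion in trainModel/trainModelPriv: the lines around `np.arange(temp.size)` flatten the integer label volume and scatter ones into a (voxels x num_labels) matrix. A standalone NumPy example of the same conversion, using a toy 2x2x1 volume with invented values:

    import numpy as np

    num_labels = 2
    # toy label volume: batch 1, one channel, 2x2x1 voxels
    label_b = np.array([[[[[0], [1]], [[1], [0]]]]]).astype(np.int64)

    temp = label_b.reshape([-1])                 # flatten to (N,) = [0, 1, 1, 0]
    one_hot = np.zeros([temp.size, num_labels])  # (N, num_labels)
    one_hot[np.arange(temp.size), temp] = 1      # row i gets a 1 in column temp[i]
    # one_hot is now [[1,0],[0,1],[0,1],[1,0]], matching the flattened network output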
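
The main loss is lossF.simple_dice_loss3D applied to softmax probabilities and the one-hot labels. Its implementation is not part of this patch; the sketch below is a common soft dice formulation over (N, C) torch tensors, shown only to make the data flow concrete, and is an assumption rather than the repo's definition:

    def soft_dice_sketch(probs, one_hot, eps=1e-5):
        # probs, one_hot: (N, C) softmax outputs and one-hot targets
        inter = (probs * one_hot).sum(0)       # per-class overlap
        denom = probs.sum(0) + one_hot.sum(0)  # per-class mass
        dice = (2 * inter + eps) / (denom + eps)
        return 1 - dice.mean()                 # loss: lower is better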
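
Both training loops pre-divide the loss by iter_size, call backward() every iteration, and only step/zero the optimizer every iter_size iterations, so gradients are averaged over iter_size micro-batches. A minimal sketch of the same pattern; `model`, `criterion`, `next_batch` and `next_target` are placeholders, not names from this patch:

    for it in range(1, max_iter + 1):
        out = model(next_batch())
        loss = criterion(out, next_target()) / iter_size  # pre-scale so grads average
        loss.backward()                                   # grads accumulate in .grad
        if it % iter_size == 0:
            optimizer.step()                              # apply the averaged update
            optimizer.zero_grad()                         # reset for the next group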
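
In trainModelPriv, torch.autograd.backward([loss1, loss2]) accumulates gradients from both heads into the shared trunk in one pass. For scalar losses this is equivalent to backpropagating their sum:

    torch.autograd.backward([loss1, loss2])  # as in the patch
    # (loss1 + loss2).backward()             # equivalent for scalar losses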
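
lr_poly is defined but never called in this patch, so the learning rate stays at --lr throughout. If it were wired in, a typical polynomial schedule would look like the sketch below; power = 0.9 is the usual DeepLab-style choice (an assumption, nothing in this patch sets it):

    for it in range(1, max_iter + 1):
        lr = lr_poly(base_lr, it, max_iter, 0.9)  # base_lr * (1 - it/max_iter)**0.9
        for group in optimizer.param_groups:      # update every parameter group
            group['lr'] = lr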