Diff of /train.py [000000] .. [f2ca4d]

--- a
+++ b/train.py
@@ -0,0 +1,420 @@
+import sys
+sys.path.append('architectures/deeplab_3D/')
+sys.path.append('architectures/unet_3D/')
+sys.path.append('architectures/hrnet_3D/')
+sys.path.append('architectures/experiment_nets_3D/')
+sys.path.append('utils/')
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import torch.backends.cudnn as cudnn
+import torch.nn.functional as F
+import torch.optim as optim
+
+import numpy as np
+import scipy.misc
+import os
+from tqdm import *
+import random
+from random import randint
+from docopt import docopt
+
+import deeplab_resnet_3D
+import unet_3D
+import highresnet_3D
+import exp_net_3D
+
+import lossF
+import PP
+import augmentations as AUG
+
+import nibabel as nib
+import evalF as EF
+import evalFP as EFP
+
+
+docstr = """Write something here
+
+Usage:
+    train.py [options]
+
+Options:
+    -h, --help                  Print this message
+    --archId=<int>              Architecture to run, 0 is DeepLab 3D, 1 is U-net 3D, 2 is HRNet; values above 10 use the privileged-information variant [default: 2]
+    --trainMethod=<int>         0 is full image, 1 is by patches (random), 2 is by patches (center pixel) [default: 1]
+    --lossFunction=<str>        Loss function name. 'dice' option available [default: dice]
+    --imgSize=<str>             Image size [default: 200x200x100]
+    --mainFolderPath=<str>      Main folder path [default: ../Data/MS2017b/]
+    --patchSize=<int>           Size of the patch [default: 60]
+    --patchSizeStage0=<int>     Size of the patch at stage 0 [default: 41]
+    --namePostfix=<str>         Postfix of FLAIR, e.g. to use FLAIR_s the postfix is _s. This also determines the train file [default: _200x200x100orig]
+    --modelPath=<str>           Path of model to continue training on [default: none]
+    --NoLabels=<int>            The number of different labels in training data, including background [default: 2]
+    --maxIter=<int>             Maximum number of iterations [default: 20000]
+    --maxIterStage0=<int>       Maximum number of iterations for stage 0 training [default: -1]
+    -i, --iterSize=<int>        Num iters to accumulate gradients over [default: 1]
+    --lr=<float>                Learning Rate [default: 0.0001]
+    --gpu0=<int>                GPU number [default: 0]
+    --useGPU=<int>              Use GPU or not [default: 0]
+    --experiment=<str>          Specify an experiment to run instead, e.g. 1x1x1x1x1x1_1_0 means dilation 1 in all 6 blocks, with priv, no ASPP [default: None]
+"""
+args = docopt(docstr, version='v0.1')
+print(args)
+
+arch_id = int(args['--archId'])
+train_method = int(args['--trainMethod'])
+loss_name = args['--lossFunction']
+img_dims = np.array(args['--imgSize'].split('x'), dtype=np.int64)
+main_folder_path = args['--mainFolderPath']
+patch_size = int(args['--patchSize'])
+
+postfix = args['--namePostfix']
+model_path = args['--modelPath']
+num_labels = int(args['--NoLabels'])
+max_iter = int(args['--maxIter']) 
+
+iter_size = int(args['--iterSize']) 
+base_lr = float(args['--lr'])
+experiment = str(args['--experiment'])
+gpu0 = int(args['--gpu0'])
+useGPU = int(args['--useGPU'])
+batch_size = 1
+#img_dims = [197, 233, 189]
+list_path = main_folder_path + 'train' + postfix + '.txt'
+print('READING from ', list_path)
+img_type_path = 'pre/FLAIR' + postfix + '.nii.gz'
+gt_type_path = 'wmh' + postfix + '.nii.gz'
+
+
+patch_size_stage0 = int(args['--patchSizeStage0'])
+max_iter_stage0 = int(args['--maxIterStage0']) #default -1 disables stage 0; set above 0 to train on stage-0 patches first
+
+iter_low = 1
+iter_high = max_iter + 1
+
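+#resume support: recover the starting iteration from a checkpoint named ..._iter_<N>.pth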
+if model_path != 'none':
+    iter_low = int(model_path.split('iter_')[-1].replace('.pth','')) + 1
+    if iter_low >= iter_high:
+        print('Model already at ' + str(iter_low) + ' iterations. Change max iter size')
+        sys.exit()
+
+num_labels2 = 209 #number of classes predicted by the secondary (privileged) output head
+
+onlyLesions = (num_labels == 2)
+
+cudnn.enabled = bool(useGPU)
+
+if experiment != 'None':
+    snapshot_prefix = 'EXP3D' + '_' + experiment + '_' + loss_name + '_' + str(train_method)
+else:
+    if arch_id == 0:
+        snapshot_prefix = 'DL3D_' + loss_name + '_' + str(train_method) + '_' + PP.getTime()
+    elif arch_id == 1:
+        snapshot_prefix = 'UNET3D_' + loss_name + '_' + str(train_method) + '_' + PP.getTime()
+    elif arch_id == 2:
+        snapshot_prefix = 'HR3D_' + loss_name + '_' + str(train_method) + '_' + PP.getTime()
+
+to_center_pixel = False
+center_pixel_folder_path, locs_lesion, locs_other = (None, None, None)
+if train_method == 2:
+    to_center_pixel = True
+    if not os.path.exists(os.path.join(main_folder_path, 'centerPixelPatches' + postfix + '_' + str(patch_size))):
+        print('Pixel patch folder does not exist')
+        sys.exit()
+#load the list of training image paths
+img_list = PP.read_file(list_path)
+
+results_folder = 'train_results/'
+log_file_path = os.path.join(results_folder, 'logs', snapshot_prefix + '_log.txt')
+model_file_path = os.path.join(results_folder, 'models', snapshot_prefix + '_best.pth')
+
+#create the output folders if they do not exist yet
+os.makedirs(os.path.join(results_folder, 'logs'), exist_ok=True)
+os.makedirs(os.path.join(results_folder, 'models'), exist_ok=True)
+logfile = open(log_file_path, 'w+')
+info_run = "arch ID: {:d} | max iters: {:10d} | max iters stage 0 : {:10d} | train method : {} | lr : {}".format(arch_id, max_iter, max_iter_stage0, train_method, base_lr)
+logfile.write(info_run + '\n')
+logfile.flush()
+
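+#polynomial learning-rate decay, lr = base_lr * (1 - iter/max_iter)^power
+#(defined here but not currently called; Adam is used with a fixed lr below)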
+def lr_poly(base_lr, iter, max_iter, power):
+    return base_lr * ((1 - float(iter) / max_iter) ** power)
+
+def modelInit():
+    isPriv = False
+    if arch_id > 10:
+        isPriv = True
+
+    if experiment != 'None':
+        dilation_arr, isPriv, withASPP = PP.getExperimentInfo(experiment)
+        model = exp_net_3D.getExpNet(num_labels, dilation_arr, isPriv, NoLabels2 = num_labels2, withASPP = withASPP)
+    elif arch_id == 0:
+        model = deeplab_resnet_3D.Res_Deeplab(num_labels)
+    elif arch_id == 1:
+        model = unet_3D.UNet3D(1, num_labels)
+    elif arch_id == 2:
+        model = highresnet_3D.getHRNet(num_labels)
+
+    if model_path != 'none':
+        if useGPU:
+            #loading on GPU when model was saved on GPU
+            saved_state_dict = torch.load(model_path)
+        else:
+            #loading on CPU when model was saved on GPU
+            saved_state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
+        model.load_state_dict(saved_state_dict)
+
+    model.float()
+    model.eval() #train with frozen batch-norm running statistics (use_global_stats = True)
+    return model, isPriv
+
+def trainModel(model):
+    if useGPU:
+        model.cuda(gpu0)
+    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = base_lr)
+
+    optimizer.zero_grad()
+    print(model)
+    curr_val = 0
+    best_val = 0
+    val_change = False
+    loss_arr = np.zeros([iter_size])
+    loss_arr_i = 0
+    stage = 0
+    print('---------------')
+    print('STAGE ' + str(stage))
+    print('---------------')
+
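+    #two-stage schedule: iterations up to max_iter_stage0 are stage 0 (smaller
+    #patches of size patch_size_stage0); all later iterations are stage 1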
+    for iter in range(iter_low, iter_high):
+        if iter > max_iter_stage0 and stage != 1:
+            print('---------------')
+            print('Stage 1')
+            print('---------------')
+            stage = 1
+
+        if train_method == 0:
+            img_b, label_b, _ = PP.extractImgBatch(batch_size, img_list, img_dims, onlyLesions,
+                                                    main_folder_path = main_folder_path)
+        elif train_method == 1 or train_method == 2:
+            batch_size = 1
+            curr_patch_size = patch_size_stage0 if stage == 0 else patch_size
+            img_b, label_b, _ = PP.extractPatchBatch(batch_size, curr_patch_size, img_list, onlyLesions,
+                                                    center_pixel = to_center_pixel,
+                                                    main_folder_path = main_folder_path, postfix=postfix)
+        else:
+            print('Invalid training method format')
+            sys.exit()
+
+        if stage == 0:
+            img_b, label_b = AUG.augmentPatchLossLess([img_b, label_b])
+        img_b, label_b = AUG.augmentPatchLossy([img_b, label_b])
+        #img_b is of shape      (batch_num) x 1 x dim1 x dim2 x dim3
+        #label_b is of shape    (batch_num) x 1 x dim1 x dim2 x dim3
+        #batch_num is kept at 1 since larger volumes are too memory intensive
+
+        label_b = label_b.astype(np.int64)
+        #convert label from (batch_num x 1 x dim1 x dim2 x dim3)
+        #               to  ((batch_num*dim1*dim2*dim3) x num_labels) (one hot)
+        temp = label_b.reshape([-1])
+        label_b = np.zeros([temp.size, num_labels])
+        label_b[np.arange(temp.size),temp] = 1
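+        #(equivalent to: label_b = np.eye(num_labels)[temp])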
+        label_b = torch.from_numpy(label_b).float()
+
+        imgs = torch.from_numpy(img_b).float()
+
+        if useGPU:
+            imgs, label_b = Variable(imgs).cuda(gpu0), Variable(label_b).cuda(gpu0)
+        else:
+            imgs, label_b = Variable(imgs), Variable(label_b)
+
+        #---------------------------------------------
+        #out size is      (batch, num_labels, dim1, dim2, dim3)
+        #---------------------------------------------
+        out = model(imgs)
+        out = out.permute(0,2,3,4,1).contiguous()
+        out = out.view(-1, num_labels)
+        #---------------------------------------------
+        #out size is      (batch * dim1 * dim2 * dim3, num_labels)
+        #---------------------------------------------
+
+        #dice loss on the softmax probabilities of the flattened output
+        m = nn.Softmax(dim=1)
+        loss = lossF.simple_dice_loss3D(m(out), label_b)
+
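+        #gradient accumulation: the loss is scaled by 1/iter_size and gradients
+        #accumulate, so stepping every iter_size iterations (below) matches one
+        #update on an effective batch of iter_size patches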
+        loss /= iter_size
+        loss.backward()
+
+        loss_val = loss.data.cpu().numpy()
+        loss_arr[loss_arr_i] = loss_val
+        loss_arr_i = (loss_arr_i + 1) % iter_size
+
+        if val_change:
+            print("iter = {:6d}/{:6d}       Loss: {:1.6f}       Val Score: {:1.6f}     ".format(iter - 1, max_iter, float(loss_val) * iter_size, curr_val))
+            val_change = False
+        print("iter = {:6d}/{:6d}       Loss: {:1.6f}       Val Score: {:1.6f}     ".format(iter, max_iter, float(loss_val) * iter_size, curr_val), end='\r')
+        sys.stdout.flush()
+        if iter % 1000 == 0: #validate every 1000 iterations and keep the best model
+            val_change = True
+            curr_val = EF.evalModelX(model, num_labels, postfix, main_folder_path, (train_method != 0), gpu0, useGPU, eval_metric = 'iou', patch_size = patch_size, extra_patch = 5)
+            if curr_val > best_val:
+                best_val = curr_val
+                print('\nSaving better model...')
+                torch.save(model.state_dict(), model_file_path)
+            logfile.write("iter = {:6d}/{:6d}       Loss: {:1.6f}       Val Score: {:1.6f}     \n".format(iter, max_iter, np.sum(loss_arr), curr_val))
+            logfile.flush()
+        if iter % iter_size == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        del out, loss
+
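+#flatten the GIF label volume into a 1D LongTensor target for nll_loss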
+def setupGIFVar(gif_b):
+    gif_b = gif_b.astype(np.int64)
+    gif_b = gif_b.reshape([-1])
+    gif_b = torch.from_numpy(gif_b).long()
+
+    if useGPU:
+        gif_b = Variable(gif_b).cuda(gpu0)
+    else:
+        gif_b = Variable(gif_b)
+    return gif_b
+
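+#training loop for networks with a privileged-information head: the main head is
+#trained with dice loss on the lesion labels, the secondary head with NLL loss on
+#the GIF labels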
+def trainModelPriv(model):
+    if useGPU:
+        model.cuda(gpu0)
+    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = base_lr)
+    optimizer.zero_grad()
+    print(model)
+    curr_val1 = 0
+    curr_val2 = 0
+    best_val2 = 0
+    val_change = False
+    loss_arr1 = np.zeros([iter_size])
+    loss_arr2 = np.zeros([iter_size])
+    loss_arr_i = 0
+
+    stage = 0
+    print('---------------')
+    print('STAGE ' + str(stage))
+    print('---------------')
+
+    for iter in range(iter_low, iter_high):
+        if iter > max_iter_stage0 and stage != 1:
+            print('---------------')
+            print('Stage 1')
+            print('---------------')
+            stage = 1
+
+        if train_method == 0:
+            img_b, label_b, gif_b = PP.extractImgBatch(batch_size, img_list, img_dims, onlyLesions,
+                                                    main_folder_path = main_folder_path, with_priv = True)
+        elif train_method == 1 or train_method == 2:
+            if stage == 0:
+                batch_size = 5
+                img_b, label_b, gif_b = PP.extractPatchBatch(batch_size, patch_size_stage0, img_list, onlyLesions,
+                                                            center_pixel = to_center_pixel,
+                                                            main_folder_path = main_folder_path,
+                                                            postfix=postfix, with_priv = True)
+            else:
+                batch_size = 1
+                img_b, label_b, gif_b = PP.extractPatchBatch(batch_size, patch_size, img_list, onlyLesions,
+                                                    center_pixel = to_center_pixel,
+                                                    main_folder_path = main_folder_path,
+                                                    postfix=postfix, with_priv = True)
+        else:
+            print('Invalid training method format')
+            sys.exit()
+
+        img_b, label_b, gif_b = AUG.augmentPatchLossy([img_b, label_b, gif_b])
+
+        #img_b is of shape      (batch_num) x 1 x dim1 x dim2 x dim3
+        #label_b is of shape    (batch_num) x 1 x dim1 x dim2 x dim3
+
+        label_b = label_b.astype(np.int64)
+
+        #convert label from (batch_num x 1 x dim1 x dim2 x dim3)
+        #               to  ((batch_num*dim1*dim2*dim3) x num_labels) (one hot)
+        temp = label_b.reshape([-1])
+        label_b = np.zeros([temp.size, num_labels])
+        label_b[np.arange(temp.size),temp] = 1
+        label_b = torch.from_numpy(label_b).float()
+
+        imgs = torch.from_numpy(img_b).float()
+
+        if useGPU:
+            imgs, label_b = Variable(imgs).cuda(gpu0), Variable(label_b).cuda(gpu0)
+        else:
+            imgs, label_b = Variable(imgs), Variable(label_b)
+
+        gif_b = setupGIFVar(gif_b)
+
+        #---------------------------------------------
+        #out1 is the secondary (privileged) head, size (batch, num_labels2, dim1, dim2, dim3)
+        #out2 is the main segmentation head,      size (batch, num_labels, dim1, dim2, dim3)
+        #---------------------------------------------
+        out1, out2 = model(imgs)
+
+        out1 = out1.permute(0,2,3,4,1).contiguous()
+        out1 = out1.view(-1, num_labels2)
+
+        out2 = out2.permute(0,2,3,4,1).contiguous()
+        out2 = out2.view(-1, num_labels)
+        #---------------------------------------------
+        #each head is flattened to (batch * dim1 * dim2 * dim3, num_classes)
+        #---------------------------------------------
+        m2 = nn.Softmax(dim=1)
+        loss2 = lossF.simple_dice_loss3D(m2(out2), label_b)
+        m1 = nn.LogSoftmax(dim=1)
+        loss1 = F.nll_loss(m1(out1), gif_b)
+
+        loss1 /= iter_size
+        loss2 /= iter_size
+
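+        #backward on both losses in one call; gradients from the two heads
+        #accumulate in the shared parameters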
+        torch.autograd.backward([loss1, loss2])
+
+        loss_val1 = float(loss1.data.cpu().numpy())
+        loss_arr1[loss_arr_i] = loss_val1
+
+        loss_val2 = float(loss2.data.cpu().numpy())
+        loss_arr2[loss_arr_i] = loss_val2
+
+        loss_arr_i = (loss_arr_i + 1) % iter_size
+
+        if val_change:
+            print("iter = {:6d}/{:6d}       Loss_main: {:1.6f}      Loss_secondary: {:1.6f}       Val Score main: {:1.6f}      Val Score secondary: {:1.6f}     ".format(iter - 1, max_iter, loss_val2 * iter_size, loss_val1 * iter_size, curr_val2, curr_val1))
+            val_change = False
+        print("iter = {:6d}/{:6d}       Loss_main: {:1.6f}      Loss_secondary: {:1.6f}       Val Score main: {:1.6f}      Val Score secondary: {:1.6f}     ".format(iter, max_iter, loss_val2 * iter_size, loss_val1 * iter_size, curr_val2, curr_val1), end='\r')
+        sys.stdout.flush()
+        if iter % 2000 == 0: #validate every 2000 iterations and keep the best model (by main-task score)
+            val_change = True
+            curr_val1, curr_val2 = EFP.evalModelX(model, num_labels, num_labels2, postfix, main_folder_path, (train_method != 0), gpu0, useGPU, eval_metric = 'iou', patch_size = patch_size, extra_patch = 5, priv_eval = True)
+            if curr_val2 > best_val2:
+                best_val2 = curr_val2
+                torch.save(model.state_dict(), model_file_path)
+                print('\nSaving better model...')
+            logfile.write("iter = {:6d}/{:6d}       Loss_main: {:1.6f}      Loss_secondary: {:1.6f}       Val Score main: {:1.6f}      Val Score secondary: {:1.6f}  \n".format(iter, max_iter, np.sum(loss_arr2), np.sum(loss_arr1), curr_val2, curr_val1))
+            logfile.flush()
+        if iter % iter_size == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        del out1, out2, loss1, loss2
+
+if __name__ == "__main__":
+    model, with_priv = modelInit()
+    if with_priv:
+        trainModelPriv(model)
+    else:
+        trainModel(model)
+    logfile.close()
\ No newline at end of file