Diff of /train.py [000000] .. [f2ca4d]

Switch to unified view

a b/train.py
1
import sys
2
sys.path.append('architectures/deeplab_3D/')
3
sys.path.append('architectures/unet_3D/')
4
sys.path.append('architectures/hrnet_3D/')
5
sys.path.append('architectures/experiment_nets_3D/')
6
sys.path.append('utils/')
7
8
import torch
9
import torch.nn as nn
10
from torch.autograd import Variable
11
import torch.backends.cudnn as cudnn
12
import torch.nn.functional as F
13
import torch.optim as optim
14
15
import numpy as np
16
import scipy.misc
17
import os
18
from tqdm import *
19
import random
20
from random import randint
21
from docopt import docopt
22
23
import deeplab_resnet_3D
24
import unet_3D
25
import highresnet_3D
26
import exp_net_3D
27
28
import lossF
29
import PP
30
import augmentations as AUG
31
32
import nibabel as nib
33
import evalF as EF
34
import evalFP as EFP
35
36
37
docstr = """Write something here
38
39
Usage:
40
    train.py [options]
41
42
Options:
43
    -h, --help                  Print this message
44
    --archId=<int>              Architecture to run, 0 is DeepLab 3D, 1 is U-net3D, 2 is HRNet [default: 2]
45
    --trainMethod=<int>         0 is full image, 1 is by patches (random), 2 is by patches (center pixel) [default: 1]
46
    --lossFunction=<str>        Loss function name. 'dice' option available [default: dice]
47
    --imgSize=<str>             Image size [default: 200x200x100]
48
    --mainFolderPath=<str>      Main folder path [default: ../Data/MS2017b/]
49
    --patchSize=<int>           Size of the patch [default: 60]
50
    --patchSizeStage0=<int>     Size of the patch at stage 0 [default: 41]
51
    --namePostfix=<str>         Postfix of flair. i.e. to use FLAIR_s postfix is _s. This also determines the train file [default: _200x200x100orig]
52
    --modelPath=<str>           Path of model to continue training on [default: none]
53
    --NoLabels=<int>            The number of different labels in training data, including background [default: 2]
54
    --maxIter=<int>             Maximum number of iterations [default: 20000]
55
    --maxIterStage0=<int>       Maximum number of iterations for stage 0 training [default: -1]
56
    -i, --iterSize=<int>        Num iters to accumulate gradients over [default: 1]
57
    --lr=<float>                Learning Rate [default: 0.0001]
58
    --gpu0=<int>                GPU number [default: 0]
59
    --useGPU=<int>              Use GPU or not [default: 0]
60
    --experiment=<str>          Specify experiment instead to run. e.g. 1x1x1x1x1x1_1_0 means 1 dilations all 6 blocks, with priv, no ASPP [default: None]
61
"""
62
args = docopt(docstr, version='v0.1')
63
print(args)
64
65
arch_id = int(args['--archId'])
66
train_method = int(args['--trainMethod'])
67
loss_name = args['--lossFunction']
68
img_dims = np.array(args['--imgSize'].split('x'), dtype=np.int64)
69
main_folder_path = args['--mainFolderPath']
70
patch_size = int(args['--patchSize'])
71
72
postfix = args['--namePostfix']
73
model_path = args['--modelPath']
74
num_labels = int(args['--NoLabels'])
75
max_iter = int(args['--maxIter']) 
76
77
iter_size = int(args['--iterSize']) 
78
base_lr = float(args['--lr'])
79
experiment = str(args['--experiment'])
80
gpu0 = int(args['--gpu0'])
81
useGPU = int(args['--useGPU'])
82
batch_size = 1
83
#img_dims = [197, 233, 189]
84
list_path = main_folder_path + 'train' + postfix + '.txt'
85
print('READING from ', list_path)
86
img_type_path = 'pre/FLAIR' + postfix + '.nii.gz'
87
gt_type_path = 'wmh' + postfix + '.nii.gz'
88
89
90
patch_size_stage0 = int(args['--patchSizeStage0'])
91
max_iter_stage0 = int(args['--maxIterStage0'])
92
93
iter_low = 1
94
iter_high = max_iter + 1
95
96
if model_path != 'none':
97
    iter_low = int(model_path.split('iter_')[-1].replace('.pth','')) + 1
98
    if iter_low >= iter_high:
99
        print('Model already at ' + str(iter_low) + ' iterations. Change max iter size')
100
        sys.exit()
101
102
num_labels2 = 209
103
#change to 0 to enable stage 0 patch learning
104
105
if num_labels == 2:
106
    onlyLesions = True
107
else:
108
    onlyLesions = False
109
110
if useGPU:
111
    cudnn.enabled = True
112
else:
113
    cudnn.enabled = False
114
115
if experiment != 'None':
116
    snapshot_prefix = 'EXP3D' + '_' + experiment + '_' + loss_name + '_' + str(train_method)
117
else:
118
    if arch_id == 0:
119
        snapshot_prefix = 'DL3D_' + loss_name + '_' + str(train_method) + '_' + PP.getTime()
120
    elif arch_id == 1:
121
        snapshot_prefix = 'UNET3D_' + loss_name + '_' + str(train_method) + '_' + PP.getTime()
122
    elif arch_id == 2:
123
        snapshot_prefix = 'HR3D' + loss_name + '_' + str(train_method) + '_' + PP.getTime()
124
to_center_pixel = False
125
center_pixel_folder_path, locs_lesion, locs_other = (None, None, None)
126
if train_method == 2:
127
    to_center_pixel = True
128
    if not os.path.exists(os.path.join(main_folder_path, 'centerPixelPatches' + postfix + '_' + str(patch_size))):
129
        print('Pixel patch folder does not exist')
130
        sys.exit()
131
#load few files
132
img_list = PP.read_file(list_path)
133
134
results_folder = 'train_results/'
135
log_file_path = os.path.join(results_folder, 'logs', snapshot_prefix + '_log.txt')
136
model_file_path = os.path.join(results_folder, 'models', snapshot_prefix + '_best.pth')
137
138
logfile = open(log_file_path, 'w+')
139
info_run = "arch ID: {:d} | max iters: {:10d} | max iters stage 0 : {:10d} | train method : {} | lr : {}".format(arch_id, max_iter, max_iter_stage0, train_method, base_lr)
140
logfile.write(info_run + '\n')
141
logfile.flush()
142
143
def lr_poly(base_lr, iter,max_iter,power):
144
    return base_lr*((1-float(iter)/max_iter)**(power))
145
146
def modelInit():
147
    isPriv = False
148
    if arch_id > 10:
149
        isPriv = True
150
151
    if experiment != 'None':
152
        dilation_arr, isPriv, withASPP = PP.getExperimentInfo(experiment)
153
        model = exp_net_3D.getExpNet(num_labels, dilation_arr, isPriv, NoLabels2 = num_labels2, withASPP = withASPP)
154
    elif arch_id == 0:
155
        model = deeplab_resnet_3D.Res_Deeplab(num_labels)
156
    elif arch_id == 1:
157
        model = unet_3D.UNet3D(1, num_labels)
158
    elif arch_id == 2:
159
        model = highresnet_3D.getHRNet(num_labels)
160
161
    if model_path != 'none':
162
        if useGPU:
163
            #loading on GPU when model was saved on GPU
164
            saved_state_dict = torch.load(model_path)
165
        else:
166
            #loading on CPU when model was saved on GPU
167
            saved_state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
168
        model.load_state_dict(saved_state_dict)
169
170
    model.float()
171
    model.eval() # use_global_stats = True
172
    return model, isPriv
173
174
def trainModel(model):
175
    if useGPU:
176
        model.cuda(gpu0)
177
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = base_lr)
178
179
    optimizer.zero_grad()
180
    print(model)
181
    curr_val = 0
182
    best_val = 0
183
    val_change = False
184
    loss_arr = np.zeros([iter_size])
185
    loss_arr_i = 0
186
    stage = 0
187
    print('---------------')
188
    print('STAGE ' + str(stage))
189
    print('---------------')
190
191
    for iter in range(iter_low, iter_high):
192
        if iter > max_iter_stage0 and stage != 1:
193
            print('---------------')
194
            print('Stage 1')
195
            print('---------------')
196
            stage = 1
197
198
        if train_method == 0:
199
            img_b, label_b, _ = PP.extractImgBatch(batch_size, img_list, img_dims, onlyLesions, 
200
                                                    main_folder_path = '../Data/MS2017b/')
201
        elif train_method == 1 or train_method == 2:
202
            if stage == 0:
203
                batch_size = 1
204
                img_b, label_b, _ = PP.extractPatchBatch(batch_size, patch_size_stage0, img_list, onlyLesions, center_pixel = to_center_pixel, main_folder_path = '../Data/MS2017b/', postfix=postfix)
205
            else:
206
                batch_size = 1
207
                img_b, label_b, _ = PP.extractPatchBatch(batch_size, patch_size, img_list, onlyLesions, center_pixel = to_center_pixel, main_folder_path = '../Data/MS2017b/', postfix=postfix)
208
        else:
209
            print('Invalid training method format')
210
            sys.exit()
211
212
        if stage == 0:
213
            img_b, label_b = AUG.augmentPatchLossLess([img_b, label_b])
214
        img_b, label_b = AUG.augmentPatchLossy([img_b, label_b])
215
        #img_b, label_b = AUG.augmentPatchLossless(img_b, label_b)
216
        #img_b is of shape      (batch_num) x 1 x dim1 x dim2 x dim3
217
        #label_b is of shape    (batch_num) x 1 x dim1 x dim2 x dim3
218
        #batch_num should be 1 since too memory intensive
219
220
        label_b = label_b.astype(np.int64)
221
        #convert label from (batch_num x 1 x dim1 x dim2 x dim3)
222
        #               to  ((batch_numxdim1*dim2*dim3) x 3) (one hot)
223
        temp = label_b.reshape([-1])
224
        label_b = np.zeros([temp.size, num_labels])
225
        label_b[np.arange(temp.size),temp] = 1
226
        label_b = torch.from_numpy(label_b).float()
227
228
        imgs = torch.from_numpy(img_b).float()
229
230
        if useGPU:
231
            imgs, label_b = Variable(imgs).cuda(gpu0), Variable(label_b).cuda(gpu0)
232
        else:
233
            imgs, label_b = Variable(imgs), Variable(label_b)
234
235
        #---------------------------------------------
236
        #out size is      (1, 3, dim1, dim2, dim3)
237
        #---------------------------------------------
238
        out = model(imgs)
239
        out = out.permute(0,2,3,4,1).contiguous()
240
        out = out.view(-1, num_labels)
241
        #---------------------------------------------
242
        #out size is      (1 * dim1 * dim2 * dim3, 3)
243
        #---------------------------------------------
244
245
        #loss function
246
        m = nn.Softmax()
247
        loss = lossF.simple_dice_loss3D(m(out), label_b)
248
249
        loss /= iter_size
250
        loss.backward()
251
252
        loss_val = loss.data.cpu().numpy()
253
        loss_arr[loss_arr_i] = loss_val
254
        loss_arr_i = (loss_arr_i + 1) % iter_size
255
256
        if iter % 1 == 0:
257
            if val_change:
258
                print "iter = {:6d}/{:6d}       Loss: {:1.6f}       Val Score: {:1.6f}     \r".format(iter-1, max_iter, float(loss_val)*iter_size, curr_val),
259
                sys.stdout.flush()
260
                print ""
261
                val_change = False
262
            print "iter = {:6d}/{:6d}       Loss: {:1.6f}       Val Score: {:1.6f}     \r".format(iter, max_iter, float(loss_val)*iter_size, curr_val),
263
            sys.stdout.flush()
264
        if iter % 1000 == 0:
265
            val_change = True
266
            curr_val = EF.evalModelX(model, num_labels, postfix, main_folder_path, (train_method != 0), gpu0, useGPU, eval_metric = 'iou', patch_size = patch_size, extra_patch = 5)
267
            if curr_val > best_val:
268
                best_val = curr_val
269
                print('\nSaving better model...')
270
                torch.save(model.state_dict(), model_file_path)
271
            logfile.write("iter = {:6d}/{:6d}       Loss: {:1.6f}       Val Score: {:1.6f}     \n".format(iter, max_iter, np.sum(loss_arr), curr_val))
272
            logfile.flush()
273
        if iter % iter_size == 0:
274
            optimizer.step()
275
            optimizer.zero_grad()
276
277
        del out, loss
278
279
def setupGIFVar(gif_b):
280
    gif_b = gif_b.astype(np.int64)
281
    gif_b = gif_b.reshape([-1])
282
    gif_b = torch.from_numpy(gif_b).long()
283
284
    if useGPU:
285
        gif_b = Variable(gif_b).cuda(gpu0)
286
    else:
287
        gif_b = Variable(gif_b)
288
    return gif_b
289
290
def trainModelPriv(model):
291
    if useGPU:
292
        model.cuda(gpu0)
293
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = base_lr)
294
    optimizer.zero_grad()
295
    print(model)
296
    curr_val1 = 0
297
    curr_val2 = 0
298
    best_val2 = 0
299
    val_change = False
300
    loss_arr1 = np.zeros([iter_size])
301
    loss_arr2 = np.zeros([iter_size])
302
    loss_arr_i = 0
303
304
    stage = 0
305
    print('---------------')
306
    print('STAGE ' + str(stage))
307
    print('---------------')
308
309
    for iter in range(iter_low, iter_high):
310
        if iter > max_iter_stage0 and stage != 1:
311
            print('---------------')
312
            print('Stage 1')
313
            print('---------------')
314
            stage = 1
315
316
        if train_method == 0:
317
            img_b, label_b, gif_b = PP.extractImgBatch(batch_size, img_list, img_dims, onlyLesions, 
318
                                                    main_folder_path = '../Data/MS2017b/', with_priv = True)
319
        elif train_method == 1 or train_method == 2:
320
            if stage == 0:
321
                batch_size = 5
322
                img_b, label_b, gif_b = PP.extractPatchBatch(batch_size, patch_size_stage0, img_list, onlyLesions,
323
                                                            center_pixel = to_center_pixel, 
324
                                                            main_folder_path = '../Data/MS2017b/', 
325
                                                            postfix=postfix, with_priv= True)
326
            else:
327
                batch_size = 1
328
                img_b, label_b, gif_b = PP.extractPatchBatch(batch_size, patch_size, img_list, onlyLesions, 
329
                                                    center_pixel = to_center_pixel, 
330
                                                    main_folder_path = '../Data/MS2017b/', 
331
                                                    postfix=postfix, with_priv= True)
332
        else:
333
            print('Invalid training method format')
334
            sys.exit()
335
336
        img_b, label_b, gif_b = AUG.augmentPatchLossy([img_b, label_b, gif_b])
337
338
        #img_b is of shape      (batch_num) x 1 x dim1 x dim2 x dim3
339
        #label_b is of shape    (batch_num) x 1 x dim1 x dim2 x dim3
340
341
        label_b = label_b.astype(np.int64)
342
343
        #convert label from (batch_num x 1 x dim1 x dim2 x dim3)
344
        #               to  ((batch_numxdim1*dim2*dim3) x 3) (one hot)
345
        temp = label_b.reshape([-1])
346
        label_b = np.zeros([temp.size, num_labels])
347
        label_b[np.arange(temp.size),temp] = 1
348
        label_b = torch.from_numpy(label_b).float()
349
350
        imgs = torch.from_numpy(img_b).float()
351
352
        if useGPU:
353
            imgs, label_b = Variable(imgs).cuda(gpu0), Variable(label_b).cuda(gpu0)
354
        else:
355
            imgs, label_b = Variable(imgs), Variable(label_b)
356
357
        gif_b = setupGIFVar(gif_b)
358
359
        #---------------------------------------------
360
        #out size is      (1, 3, dim1, dim2, dim3)
361
        #---------------------------------------------
362
        #out1 is extra info
363
        out1, out2 = model(imgs)
364
365
        out1 = out1.permute(0,2,3,4,1).contiguous()
366
        out1 = out1.view(-1, num_labels2)
367
368
        out2 = out2.permute(0,2,3,4,1).contiguous()
369
        out2 = out2.view(-1, num_labels)
370
        #---------------------------------------------
371
        #out size is      (1 * dim1 * dim2 * dim3, 3)
372
        #---------------------------------------------
373
        m2 = nn.Softmax()
374
        loss2 = lossF.simple_dice_loss3D(m2(out2), label_b)
375
        m1 = nn.LogSoftmax()
376
        loss1 = F.nll_loss(m1(out1), gif_b)
377
378
        loss1 /= iter_size
379
        loss2 /= iter_size
380
381
        torch.autograd.backward([loss1, loss2])
382
383
        loss_val1 = float(loss1.data.cpu().numpy())
384
        loss_arr1[loss_arr_i] = loss_val1
385
386
        loss_val2 = float(loss2.data.cpu().numpy())
387
        loss_arr2[loss_arr_i] = loss_val2
388
389
        loss_arr_i = (loss_arr_i + 1) % iter_size
390
391
        if iter % 1 == 0:
392
            if val_change:
393
                print "iter = {:6d}/{:6d}       Loss_main: {:1.6f}    Loss_secondary: {:1.6f}       Val Score: {:1.6f}      Val Score secondary: {:1.6f}     \r".format(iter-1, max_iter, loss_val2*iter_size, loss_val1*iter_size, curr_val2, curr_val1),
394
                sys.stdout.flush()
395
                print ""
396
                val_change = False
397
            print "iter = {:6d}/{:6d}       Loss_main: {:1.6f}      Loss_secondary: {:1.6f}       Val Score main: {:1.6f}      Val Score secondary: {:1.6f}     \r".format(iter, max_iter, loss_val2*iter_size, loss_val1*iter_size, curr_val2, curr_val1),
398
            sys.stdout.flush()
399
        if iter % 2000 == 0:
400
            val_change = True
401
            curr_val1, curr_val2 = EFP.evalModelX(model, num_labels, num_labels2, postfix, main_folder_path, (train_method != 0), gpu0, useGPU, eval_metric = 'iou', patch_size = patch_size, extra_patch = 5, priv_eval = True)
402
            if curr_val2 > best_val2:
403
                best_val2 = curr_val2
404
                torch.save(model.state_dict(), model_file_path)
405
                print('\nSaving better model...')
406
            logfile.write("iter = {:6d}/{:6d}       Loss_main: {:1.6f}      Loss_secondary: {:1.6f}       Val Score main: {:1.6f}      Val Score secondary: {:1.6f}  \n".format(iter, max_iter, np.sum(loss_arr2), np.sum(loss_arr1), curr_val2, curr_val1))
407
            logfile.flush()
408
        if iter % iter_size == 0:
409
            optimizer.step()
410
            optimizer.zero_grad()
411
412
        del out1, out2, loss1, loss2
413
414
if __name__ == "__main__":
415
    model, with_priv = modelInit()
416
    if with_priv:
417
        trainModelPriv(model)
418
    else:
419
        trainModel(model)
420
    logfile.close()