Diff of /train.py [000000] .. [1b6491]

# ==============================================================================
# Copyright (C) 2020 Vladimir Juras, Ravinder Regatte and Cem M. Deniz
#
# This file is part of 2019_IWOAI_Challenge
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# ==============================================================================
import tensorflow as tf
import tf_utilities as tfut
import tf_layers as tflay
import models
import sys

import numpy as np
import re
import time
import os
from functools import partial

import h5py
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import label_binarize
from keras.utils import to_categorical
from pathlib import Path

tf.app.flags.DEFINE_boolean('restore', False, 'Whether to restore from a previous model.')
tf.app.flags.DEFINE_float('lr', 0.00005, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('feature', 16, 'Number of root features.')
tf.app.flags.DEFINE_string('model', '4atrous248', 'Model name.')
tf.app.flags.DEFINE_boolean('val', True, 'Whether to use validation.')
tf.app.flags.DEFINE_boolean('full_data', True, 'Whether to use the full data set.')
tf.app.flags.DEFINE_float('dr', 1.0, 'Learning rate decay rate.')
tf.app.flags.DEFINE_integer('reso', 384, 'Image size.')
tf.app.flags.DEFINE_integer('slices', 160, 'Number of slices.')
tf.app.flags.DEFINE_string('loss', 'wce', 'Loss name.')
tf.app.flags.DEFINE_integer('epoch', 400, 'Number of epochs.')
tf.app.flags.DEFINE_boolean('staircase', False, 'If True, decay the learning rate at discrete intervals.')
tf.app.flags.DEFINE_integer('seed', 1234, 'Graph-level random seed.')
tf.app.flags.DEFINE_float('dropout', 1.0, 'Dropout keep probability when training.')
tf.app.flags.DEFINE_string('output_path', None, 'Name of the output folder.')
tf.app.flags.DEFINE_boolean('resnet', False, 'Whether to use resnet shortcuts.')
tf.app.flags.DEFINE_boolean('early_stopping', True, 'Whether to use early stopping.')
tf.app.flags.DEFINE_string('folder', './data', 'Data folder.')
tf.app.flags.DEFINE_integer('noImages', -1, 'How many images to train and validate on (-1 for all).')
tf.app.flags.DEFINE_float('switchAccuracy', 0.88, 'Training accuracy at which to switch to the Dice loss.')
tf.app.flags.DEFINE_string('info', ' ', 'Free-text note to attach to this run.')

FLAGS = tf.app.flags.FLAGS
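
# Example invocation (illustrative values; flags as defined above):
#   python train.py --model 4atrous248 --loss wce --lr 5e-5 --feature 16 \
#       --reso 384 --folder ./data --epoch 400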

switchAccuracy = FLAGS.switchAccuracy

num_classes = 7
num_channels = 1

def _get_cost(logits, batch_y, cost_name='dice', add_regularizers=None, class_weights=None):
    # flatten to (num_voxels, num_classes) so every loss works on voxel rows
    flat_logits = tf.reshape(logits, [-1, num_classes])
    flat_labels = tf.reshape(batch_y, [-1, num_classes])

    if cost_name == 'cross_entropy':
        if class_weights is not None:
            # per-voxel weight = weight of that voxel's true class
            weight_map = tf.multiply(flat_labels, class_weights)
            weight_map = tf.reduce_sum(weight_map, axis=1)
            loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
                                                               labels=flat_labels)
            weighted_loss = tf.multiply(loss_map, weight_map)
            loss = tf.reduce_mean(weighted_loss)
        else:
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels))

    elif cost_name == 'dice':
        # binary soft Dice on class index 1 only
        flat_logits = tf.nn.softmax(flat_logits)[:, 1]
        flat_labels = flat_labels[:, 1]

        inse = tf.reduce_sum(flat_logits * flat_labels)
        l = tf.reduce_sum(flat_logits * flat_logits)
        r = tf.reduce_sum(flat_labels * flat_labels)
        dice = 2 * inse / (l + r)
        loss = 1.0 - tf.clip_by_value(dice, 0, 1 - 1e-10)

    elif cost_name == 'dice_multi':
        # multi-class soft Dice: sum of per-class Dice scores
        dice_multi = 0
        n_classes = num_classes
        for index in range(n_classes):
            flat_logits_ = tf.nn.softmax(flat_logits)[:, index]
            flat_labels_ = flat_labels[:, index]

            inse = tf.reduce_sum(flat_logits_ * flat_labels_)
            l = tf.reduce_sum(flat_logits_ * flat_logits_)
            r = tf.reduce_sum(flat_labels_ * flat_labels_)
            dice = 2 * inse / (l + r)
            dice = tf.clip_by_value(dice, 0, 1 - 1e-10)

            dice_multi += dice

        loss = n_classes * 1.0 - dice_multi

    elif cost_name == 'dice_multi_noBG':
        # same as dice_multi but skips the background class (index 0)
        dice_multi = 0
        n_classes = num_classes
        for index in range(1, n_classes):
            flat_logits_ = tf.nn.softmax(flat_logits)[:, index]
            flat_labels_ = flat_labels[:, index]

            inse = tf.reduce_sum(flat_logits_ * flat_labels_)
            l = tf.reduce_sum(flat_logits_ * flat_logits_)
            r = tf.reduce_sum(flat_labels_ * flat_labels_)
            dice = 2 * inse / (l + r)
            dice = tf.clip_by_value(dice, 0, 1 - 1e-10)

            dice_multi += dice

        loss = (n_classes - 1) * 1.0 - dice_multi

    return loss
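
# Illustrative check of the soft Dice used above (comment only, not graph code):
# for probabilities p and one-hot labels g of a single class,
#   dice = 2 * sum(p * g) / (sum(p * p) + sum(g * g))
# e.g. p = [0.9, 0.1], g = [1, 0] gives dice = 1.8 / (0.82 + 1.0) ~= 0.989,
# so 'dice_multi' sums one such score per class and the returned loss
# (n_classes - dice_multi) approaches 0 for a perfect prediction.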

def _get_acc(logits, batch_y, cost_name='dice', add_regularizers=None, class_weights=None):
    flat_logits = tf.reshape(logits, [-1, num_classes])
    flat_labels = tf.reshape(batch_y, [-1, num_classes])

    # voxel-wise accuracy over non-background voxels only: rows whose true
    # class is background (column 0) are masked out before averaging
    correct_prediction = tf.equal(tf.argmax(flat_logits, 1), tf.argmax(flat_labels, 1))
    correct_prediction = tf.boolean_mask(correct_prediction, tf.equal(flat_labels[:, 0], 0))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    return accuracy

def _get_optimizer(start_learning_rate=0.0001, global_step=0, decay_steps=25, decay_rate=0.9):
    learning_rate = tf.train.exponential_decay(start_learning_rate,
                                               global_step,
                                               decay_steps,
                                               decay_rate,
                                               staircase=FLAGS.staircase)
    tf.summary.scalar('learning rate', learning_rate)
    optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate, decay=0.995)
    return optimizer
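
# With TF's exponential_decay the effective rate is
#   lr(step) = start_learning_rate * decay_rate ** (global_step / decay_steps)
# (integer division when FLAGS.staircase is True). global_step is fed as
# epoch - 1 below, and the default dr of 1.0 keeps the rate constant.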

def main(argv=None):
    # if no output path is given, create a new folder using flags
    res = 'res' if FLAGS.resnet else 'nores'
    if FLAGS.output_path is None:
        FLAGS.output_path = 'TrainedModels/' + '_'.join([time.strftime('%m%d_%H%M'),
                                    FLAGS.model, 'wceSwitch%.2fDice_AccVal' % switchAccuracy,
                                    res,
                                    FLAGS.loss,
                                    'no' + str(FLAGS.noImages),
                                    'reso' + str(FLAGS.reso),
                                    'features' + str(FLAGS.feature),
                                    'lr' + '{:.1e}'.format(FLAGS.lr),
                                    'dr' + str(FLAGS.dropout)])

    if not os.path.exists(FLAGS.output_path):
        os.makedirs(FLAGS.output_path)

    # save flags into file
    with open(FLAGS.output_path + '/flags.txt', 'a') as f:
        f.write(str(FLAGS.flag_values_dict()))

    # set seeds for tensorflow and numpy
    tf.set_random_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    # placeholders
    batch_x = tf.placeholder(tf.float32, shape=(None, FLAGS.reso, FLAGS.reso, FLAGS.slices, 1), name='batch_x')
    batch_y = tf.placeholder(tf.float32, shape=(None, None, None, None, num_classes))

    keep_prob = tf.placeholder(tf.float32, shape=[], name='keep_prob')
    global_step = tf.placeholder(tf.int32, shape=[])
    class_weights = tf.placeholder(tf.float32, shape=(num_classes))

    # choose the model
    inference_raw = {'4unet': models.inference_unet4,  # the original architecture, using 4 layers
                     '4atrous248': partial(models.inference_atrous4, dilation_rates=[2, 4, 8])}[FLAGS.model]
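
    # A further architecture could be registered here and selected via
    # --model, e.g. (hypothetical) '4atrous24': partial(models.inference_atrous4,
    # dilation_rates=[2, 4]).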

    inference = partial(inference_raw, resnet=FLAGS.resnet)

    # get score and probability, add to summary
    score = inference(batch_x, features_root=FLAGS.feature, keep_prob=keep_prob, n_class=num_classes)
    logits = tf.nn.softmax(score)  # note: despite the name, these are softmax probabilities

    # get losses
    dice_cost = _get_cost(score, batch_y, cost_name='dice_multi')
    tf.summary.scalar('dice_loss', dice_cost)
    dice_cost_noBG = _get_cost(score, batch_y, cost_name='dice_multi_noBG')
    tf.summary.scalar('dice_loss noBG', dice_cost_noBG)

    cross_entropy = _get_cost(score, batch_y, cost_name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    weighted_cross_entropy = _get_cost(score, batch_y, cost_name='cross_entropy', class_weights=class_weights)
    tf.summary.scalar('weighted_cross_entropy', weighted_cross_entropy)

    if FLAGS.loss == 'wce':  # weighted cross entropy
        cost = weighted_cross_entropy
    elif FLAGS.loss == 'dice':  # dice
        cost = dice_cost
    elif FLAGS.loss == 'ce':  # cross entropy
        cost = cross_entropy
    else:
        cost = dice_cost

    # get accuracy
    accuracy = _get_acc(score, batch_y)

    # set optimizer with learning rate and decay rate
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        with tf.name_scope('rms_optimizer'):
            # two optimizers: one for the selected cost and one for the Dice
            # cost used after the accuracy switch
            optimizer = _get_optimizer(FLAGS.lr, global_step, decay_rate=FLAGS.dr)
            optimizer_dice = _get_optimizer(FLAGS.lr, global_step, decay_rate=FLAGS.dr)

            grads = optimizer.compute_gradients(cost)
            grads_dice = optimizer_dice.compute_gradients(dice_cost)

            train = optimizer.apply_gradients(grads)
            train_dice = optimizer_dice.apply_gradients(grads_dice)
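
    # Training scheme: optimize the configured cost (weighted cross entropy by
    # default) first; once mean training accuracy exceeds FLAGS.switchAccuracy
    # the loop below switches permanently to train_dice (see switchFlag).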

    # get merged summaries
    merged = tf.summary.merge_all()

    # get losses & acc for training
    dice_cost_train = tf.placeholder(tf.float32, shape=[])
    dice_loss_train_summary = tf.summary.scalar('dice_loss_train', dice_cost_train)

    cross_entropy_train = tf.placeholder(tf.float32, shape=[])
    cross_entropy_train_summary = tf.summary.scalar('cross_entropy_train', cross_entropy_train)

    weighted_cross_entropy_train = tf.placeholder(tf.float32, shape=[])
    weighted_cross_entropy_train_summary = tf.summary.scalar('weighted_cross_entropy_train', weighted_cross_entropy_train)

    accuracy_train = tf.placeholder(tf.float32, shape=[])
    accuracy_train_summary = tf.summary.scalar('accuracy_train', accuracy_train)

    # get losses & acc for validation
    dice_cost_val = tf.placeholder(tf.float32, shape=[])
    dice_loss_val_summary = tf.summary.scalar('dice_loss_val', dice_cost_val)

    cross_entropy_val = tf.placeholder(tf.float32, shape=[])
    cross_entropy_val_summary = tf.summary.scalar('cross_entropy_val', cross_entropy_val)

    weighted_cross_entropy_val = tf.placeholder(tf.float32, shape=[])
    weighted_cross_entropy_val_summary = tf.summary.scalar('weighted_cross_entropy_val', weighted_cross_entropy_val)

    accuracy_val = tf.placeholder(tf.float32, shape=[])
    accuracy_val_summary = tf.summary.scalar('accuracy_val', accuracy_val)

    # load data: collect image (.im) and segmentation (.seg) file lists
    dataFolder = FLAGS.folder + '/train'
    pathNifti = Path(dataFolder)

    X = sorted(pathNifti.glob('**/*.im'))
    y = sorted(pathNifti.glob('**/*.seg'))

    pathNifti = Path(FLAGS.folder + '/valid')

    X_v = sorted(pathNifti.glob('**/*.im'))
    y_v = sorted(pathNifti.glob('**/*.seg'))
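
    # Note: images and segmentations are paired by sorted filename order,
    # which assumes each case's .im and .seg files sort identically.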

    saver = tf.train.Saver(max_to_keep=0)

    # load mri data and segmentation maps for training
    if FLAGS.noImages == -1:
        noOfFiles = len(X)
    else:
        noOfFiles = FLAGS.noImages
    list_X = [X[i] for i in range(noOfFiles)]
    list_y = [y[i] for i in range(noOfFiles)]

    X_train, y_train, train_info = tfut.loadData_list_h5(list_X, list_y, num_channels)
    print('Dataload is done')
    X_train = tfut.zeroMeanUnitVariance(X_train)
    weights_cross_entropy = tfut.compute_weights_multiClass(y_train, num_classes)
    del list_X, list_y
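
    # weights_cross_entropy holds one scalar per class, computed by the helper
    # above from the training labels (presumably inverse-frequency style, so
    # sparse tissue classes are not swamped by background voxels).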

    # load mri data and segmentation maps for validation
    if FLAGS.noImages == -1:
        noOfFiles = len(X_v)
    else:
        noOfFiles = FLAGS.noImages

    list_X = [X_v[i] for i in range(noOfFiles)]
    list_y = [y_v[i] for i in range(noOfFiles)]
    X_val, y_val, val_info = tfut.loadData_list_h5(list_X, list_y, num_channels)
    X_val = tfut.zeroMeanUnitVariance(X_val)
    del list_X, list_y

    X_train = X_train[..., np.newaxis]
    X_val = X_val[..., np.newaxis]

    # resize data if a non-native resolution was requested
    # (order=3: smooth interpolation for images; order=0: label-preserving for masks)
    if FLAGS.reso != 384:
        input_size = X_train.shape[2]
        X_train = tfut.batch_resize(X_train, input_size=input_size, output_size=FLAGS.reso, order=3)
        y_train = tfut.batch_resize(y_train, input_size=input_size, output_size=FLAGS.reso, order=0)

        X_val = tfut.batch_resize(X_val, input_size=input_size, output_size=FLAGS.reso, order=3)
        y_val = tfut.batch_resize(y_val, input_size=input_size, output_size=FLAGS.reso, order=0)

    sample_size = X_train.shape[0]
    val_size = X_val.shape[0]

    # initialization for early stopping and the WCE->Dice loss switch;
    # these are read later regardless of FLAGS.early_stopping, so they are
    # initialized unconditionally (the original gated them on the flag,
    # which raised a NameError when early stopping was disabled)
    best_acc = 0
    wait = 0
    patience = 500
    switchFlag = 1

    config = tf.ConfigProto()
    config.log_device_placement = False
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        modelNo = 0
        if FLAGS.restore:
            ckpt = tf.train.get_checkpoint_state(FLAGS.output_path)
            model_path = ckpt.model_checkpoint_path
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.output_path))
            print('Model restored from file: %s' % model_path)
            tmp = re.findall(r'\d+', model_path)
            modelNo = int(tmp[-1])

        train_writer = tf.summary.FileWriter(FLAGS.output_path, sess.graph)

        start = time.time()  # time.clock() was removed in Python 3.8

        # dry-run one forward pass to discover the network's output shape
        prediction = sess.run(score, feed_dict={batch_x: X_train[0:1],
                                        batch_y: y_train[0:1],
                                        global_step: 0,
                                        keep_prob: FLAGS.dropout,
                                        class_weights: weights_cross_entropy})
        pred_shape = prediction.shape
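
        # the network output can be smaller than the input when VALID-style
        # convolutions are used; the offsets below centre-crop the label
        # volumes so losses are computed only where predictions exist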

        offset0 = (y_train.shape[1] - pred_shape[1]) // 2
        offset1 = (y_train.shape[2] - pred_shape[2]) // 2
        offset2 = (y_train.shape[3] - pred_shape[3]) // 2

        if offset0 == 0 and offset1 == 0 and offset2 == 0:
            print('SAME padding')
        else:
            y_train = y_train[:, offset0:(-offset0), offset1:(-offset1), offset2:(-offset2), :]
            y_val = y_val[:, offset0:(-offset0), offset1:(-offset1), offset2:(-offset2), :]

        for epoch in range(modelNo + 1, FLAGS.epoch + 1):
            print('train epoch', epoch, 'sample_size', sample_size)

            # shuffle data at the beginning of every epoch
            shuffled_idx = np.random.permutation(sample_size)
            wce_train, dice_train, ce_train, acc_train = [], [], [], []
            for j in range(sample_size):
                idx = shuffled_idx[j]
                i = (epoch - 1) * sample_size + j + 1  # running step counter (not used below)

                # whether to do left-right mirroring (step == -1 reverses the slice axis)
                step = np.random.choice([1, -1])

                if switchFlag:
                    _, loss, dice_loss, cross_entropy_loss, acc = sess.run([train, weighted_cross_entropy, dice_cost, cross_entropy, accuracy],
                                                                        feed_dict={batch_x: X_train[idx:idx+1, :, :, ::step, :],
                                                                                    batch_y: y_train[idx:idx+1, :, :, ::step, :],
                                                                                    global_step: epoch - 1,
                                                                                    keep_prob: FLAGS.dropout,
                                                                                    class_weights: weights_cross_entropy})
                else:
                    _, loss, dice_loss, cross_entropy_loss, acc = sess.run([train_dice, weighted_cross_entropy, dice_cost, cross_entropy, accuracy],
                                                                        feed_dict={batch_x: X_train[idx:idx+1, :, :, ::step, :],
                                                                                    batch_y: y_train[idx:idx+1, :, :, ::step, :],
                                                                                    global_step: epoch - 1,
                                                                                    keep_prob: FLAGS.dropout,
                                                                                    class_weights: weights_cross_entropy})

                wce_train.append(loss)
                dice_train.append(dice_loss)
                ce_train.append(cross_entropy_loss)
                acc_train.append(acc)

            # switch to Dice loss once cross-entropy training accuracy is good enough
            if np.mean(acc_train) > switchAccuracy:
                switchFlag = 0
                print('@@@@ switching to Dice loss in epoch #', epoch)

            print('training weighted loss:', np.mean(wce_train),
                    ', cross entropy loss:', np.mean(ce_train),
                    ', dice loss:', np.mean(dice_train),
                    ', accuracy:', np.mean(acc_train))
            summary = sess.run(weighted_cross_entropy_train_summary, feed_dict={weighted_cross_entropy_train: np.mean(wce_train)})
            train_writer.add_summary(summary, epoch)
            summary = sess.run(dice_loss_train_summary, feed_dict={dice_cost_train: np.mean(dice_train)})
            train_writer.add_summary(summary, epoch)
            summary = sess.run(cross_entropy_train_summary, feed_dict={cross_entropy_train: np.mean(ce_train)})
            train_writer.add_summary(summary, epoch)
            summary = sess.run(accuracy_train_summary, feed_dict={accuracy_train: np.mean(acc_train)})
            train_writer.add_summary(summary, epoch)

            if FLAGS.val:
                summary = sess.run(merged,
                                    feed_dict={batch_x: X_train[:1],
                                                batch_y: y_train[:1],
                                                global_step: epoch - 1,
                                                keep_prob: 1.0,
                                                class_weights: weights_cross_entropy})
                train_writer.add_summary(summary, epoch)

                wce_val, dice_val, ce_val, acc_val = [], [], [], []
                for j in range(val_size):
                    loss, dice_loss, cross_entropy_loss, acc = sess.run([weighted_cross_entropy, dice_cost, cross_entropy, accuracy],
                                                                            feed_dict={batch_x: X_val[j:j+1],
                                                                                        batch_y: y_val[j:j+1],
                                                                                        global_step: epoch - 1,
                                                                                        keep_prob: 1.0,
                                                                                        class_weights: weights_cross_entropy})
                    wce_val.append(loss)
                    dice_val.append(dice_loss)
                    ce_val.append(cross_entropy_loss)
                    acc_val.append(acc)

                summary = sess.run(weighted_cross_entropy_val_summary, feed_dict={weighted_cross_entropy_val: np.mean(wce_val)})
                train_writer.add_summary(summary, epoch)
                summary = sess.run(dice_loss_val_summary, feed_dict={dice_cost_val: np.mean(dice_val)})
                train_writer.add_summary(summary, epoch)
                summary = sess.run(cross_entropy_val_summary, feed_dict={cross_entropy_val: np.mean(ce_val)})
                train_writer.add_summary(summary, epoch)
                summary = sess.run(accuracy_val_summary, feed_dict={accuracy_val: np.mean(acc_val)})
                train_writer.add_summary(summary, epoch)

                print('validation weighted loss:', np.mean(wce_val),
                    ', cross entropy loss:', np.mean(ce_val),
                    ', dice loss:', np.mean(dice_val),
                    ', accuracy:', np.mean(acc_val))

                # keep the best model by validation accuracy; the tiny epsilon
                # guards against saving on a float round-off "improvement"
                acc = np.mean(acc_val)
                if acc - 1e-18 > best_acc:
                    best_acc, wait = acc, 0
                    saver.save(sess, FLAGS.output_path + '/model')
                    with open(FLAGS.output_path + '/SavedEpochNo.txt', 'w') as f:
                        f.write(str(epoch))
                else:
                    saver.save(sess, FLAGS.output_path + '/model_lastEpoch')
                    with open(FLAGS.output_path + '/SavedEpochNoLastEpoch.txt', 'w') as f:
                        f.write(str(epoch))
                    wait += 1
                    if FLAGS.early_stopping and wait > patience:
                        print("!!!!Early Stopping on EPOCH %d!!!!" % epoch)
                        break
                print("!!!!BEST: %f, wait %d !!!" % (best_acc, wait))


if __name__ == '__main__':
    tf.app.run()
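
# Note: this script targets the TensorFlow 1.x API (tf.app.flags,
# tf.placeholder, tf.Session); under TensorFlow 2.x it would need the
# tf.compat.v1 equivalents plus tf.compat.v1.disable_eager_execution().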