Diff of /train.py [000000] .. [1b6491]


--- a
+++ b/train.py
@@ -0,0 +1,479 @@
+# ==============================================================================
+# Copyright (C) 2020 Vladimir Juras, Ravinder Regatte and Cem M. Deniz
+#
+# This file is part of 2019_IWOAI_Challenge
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+# ==============================================================================
+import tensorflow as tf
+import tf_utilities as tfut
+import models
+
+import numpy as np
+import re
+import time
+import os
+from functools import partial
+
+import h5py
+from pathlib import Path
+
+
+tf.app.flags.DEFINE_boolean('restore', False, 'Whether to restore from previous model.')
+tf.app.flags.DEFINE_float('lr', 0.00005, 'Initial learning rate.')
+tf.app.flags.DEFINE_integer('feature', 16, 'Number of root features.')
+tf.app.flags.DEFINE_string('model', '4atrous248', 'Model name.')
+tf.app.flags.DEFINE_boolean('val', True, 'Whether to use validation.')
+tf.app.flags.DEFINE_boolean('full_data', True, 'Whether to use full data set.')
+tf.app.flags.DEFINE_float('dr', 1.0, 'Learning rate decay rate.')
+tf.app.flags.DEFINE_integer('reso', 384, 'Image size.')
+tf.app.flags.DEFINE_integer('slices', 160, 'Number of slices.')
+tf.app.flags.DEFINE_string('loss', 'wce', 'Loss name.')
+tf.app.flags.DEFINE_integer('epoch', 400, 'Number of epochs.')
+tf.app.flags.DEFINE_boolean('staircase', False, 'If True decay the learning rate at discrete intervals.')
+tf.app.flags.DEFINE_integer('seed', 1234, 'Graph-level random seed.')
+tf.app.flags.DEFINE_float('dropout', 1.0, 'Keep probability for dropout during training (1.0 disables dropout).')
+tf.app.flags.DEFINE_string('output_path', None, 'Name of output folder.')
+tf.app.flags.DEFINE_boolean('resnet', False, 'Whether to use resnet shortcut.')
+tf.app.flags.DEFINE_boolean('early_stopping', True, 'Whether to use early stopping.')
+tf.app.flags.DEFINE_string('folder', './data', 'Data folder.')
+tf.app.flags.DEFINE_integer('noImages', -1, 'How many images to use for training and validation (-1 for all).')
+tf.app.flags.DEFINE_float('switchAccuracy', 0.88, 'Training accuracy at which to switch to the Dice loss.')
+tf.app.flags.DEFINE_string('info', ' ', 'Additional info to record for this run.')
+
+FLAGS = tf.app.flags.FLAGS
+
+switchAccuracy = FLAGS.switchAccuracy
+
+num_classes = 7
+num_channels = 1
+
+def _get_cost(logits, batch_y, cost_name='dice', add_regularizers=None, class_weights=None):
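+    """Build the training loss over flattened voxel predictions.
+
+    cost_name selects among 'cross_entropy' (optionally weighted per class),
+    'dice' (binary soft Dice on class 1), 'dice_multi' (soft Dice summed
+    over all classes) and 'dice_multi_noBG' (the same without background).
+    """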
+    flat_logits = tf.reshape(logits, [-1, num_classes])
+    flat_labels = tf.reshape(batch_y, [-1, num_classes])
+    
+    if cost_name == 'cross_entropy':
+        if class_weights is not None:
+            weight_map = tf.multiply(flat_labels, class_weights)
+            weight_map = tf.reduce_sum(weight_map, axis=1)
+            loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
+                                                               labels=flat_labels)
+            weighted_loss = tf.multiply(loss_map, weight_map)
+            loss = tf.reduce_mean(weighted_loss)
+        else:
+            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels))
+
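+    # binary soft Dice on the foreground channel (index 1):
+    #   dice = 2*sum(p*g) / (sum(p^2) + sum(g^2))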
+    elif cost_name == 'dice':
+        flat_probs = tf.nn.softmax(flat_logits)[:, 1]
+        flat_labels = flat_labels[:, 1]
+
+        inse = tf.reduce_sum(flat_probs * flat_labels)
+        l = tf.reduce_sum(flat_probs * flat_probs)
+        r = tf.reduce_sum(flat_labels * flat_labels)
+        dice = 2 * inse / (l + r)
+        loss = 1.0 - tf.clip_by_value(dice, 0, 1 - 1e-10)
+
+    elif cost_name == 'dice_multi':
+        probs = tf.nn.softmax(flat_logits)
+        dice_multi = 0
+        for index in range(num_classes):
+            flat_probs_ = probs[:, index]
+            flat_labels_ = flat_labels[:, index]
+
+            inse = tf.reduce_sum(flat_probs_ * flat_labels_)
+            l = tf.reduce_sum(flat_probs_ * flat_probs_)
+            r = tf.reduce_sum(flat_labels_ * flat_labels_)
+            dice = 2 * inse / (l + r)
+            dice_multi += tf.clip_by_value(dice, 0, 1 - 1e-10)
+
+        loss = num_classes * 1.0 - dice_multi
+
+
+    elif cost_name == 'dice_multi_noBG':
+        probs = tf.nn.softmax(flat_logits)
+        dice_multi = 0
+        for index in range(1, num_classes):  # skip the background class (index 0)
+            flat_probs_ = probs[:, index]
+            flat_labels_ = flat_labels[:, index]
+
+            inse = tf.reduce_sum(flat_probs_ * flat_labels_)
+            l = tf.reduce_sum(flat_probs_ * flat_probs_)
+            r = tf.reduce_sum(flat_labels_ * flat_labels_)
+            dice = 2 * inse / (l + r)
+            dice_multi += tf.clip_by_value(dice, 0, 1 - 1e-10)
+
+        loss = (num_classes - 1) * 1.0 - dice_multi
+
+    else:
+        raise ValueError('Unknown cost name: ' + cost_name)
+
+    return loss
+
+def _get_acc(logits, batch_y):
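+    """Voxel-wise accuracy over voxels whose ground truth is not background."""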
+    flat_logits = tf.reshape(logits, [-1, num_classes])
+    flat_labels = tf.reshape(batch_y, [-1, num_classes])
+
+    correct_prediction = tf.equal(tf.argmax(flat_logits,1), tf.argmax(flat_labels,1))
+    correct_prediction = tf.boolean_mask(correct_prediction, tf.equal(flat_labels[:,0],0))
+    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+
+    return accuracy
+
+def _get_optimizer(start_learning_rate=0.0001, global_step=0, decay_steps=25, decay_rate=0.9):
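+    # exponentially decayed learning rate:
+    #   lr = start_learning_rate * decay_rate ** (global_step / decay_steps)
+    # (the exponent is floored to an integer when FLAGS.staircase is True)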
+    learning_rate = tf.train.exponential_decay(start_learning_rate,
+                                               global_step,
+                                               decay_steps,
+                                               decay_rate,
+                                               staircase=FLAGS.staircase)
+    tf.summary.scalar('learning_rate', learning_rate)
+    optimizer=tf.train.RMSPropOptimizer(learning_rate=learning_rate, decay=0.995)
+    return optimizer
+    
+def main(argv=None):
+    # if no output path is given, create a new folder using flags
+    res = 'res' if FLAGS.resnet else 'nores'
+    if FLAGS.output_path is None:
+        FLAGS.output_path = 'TrainedModels/' + '_'.join([time.strftime('%m%d_%H%M'),
+                                    FLAGS.model,'wceSwitch%.2fDice_AccVal'%(switchAccuracy),
+                                    res,
+                                    FLAGS.loss, 
+                                    'no' + str(FLAGS.noImages),
+                                    'reso' + str(FLAGS.reso), 
+                                    'features' + str(FLAGS.feature),
+                                    'lr' + '{:.1e}'.format(FLAGS.lr), 
+                                    'dr' + str(FLAGS.dropout)])
+
+    if not os.path.exists(FLAGS.output_path):
+        os.makedirs(FLAGS.output_path)
+        
+    # save flags into file
+    with open(FLAGS.output_path + '/flags.txt', 'a') as f:
+        f.write(str(FLAGS.flag_values_dict()))
+
+    # set seeds for tensorflow and numpy
+    tf.set_random_seed(FLAGS.seed)
+    np.random.seed(FLAGS.seed)
+    
+    # placeholders
+    batch_x = tf.placeholder(tf.float32, shape=(None, FLAGS.reso, FLAGS.reso, FLAGS.slices, 1), name='batch_x')
+    batch_y = tf.placeholder(tf.float32, shape=(None, None, None, None, num_classes))
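+    # the label placeholder leaves its spatial dims open because the labels
+    # are center-cropped later to match the network output shape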
+    
+    keep_prob = tf.placeholder(tf.float32, shape=[], name='keep_prob')
+    global_step = tf.placeholder(tf.int32, shape=[])
+    class_weights = tf.placeholder(tf.float32, shape=(num_classes))
+
+    # choose the model
+    inference_raw = {'4unet': models.inference_unet4,  # the original U-Net architecture with 4 levels
+                     '4atrous248': partial(models.inference_atrous4, dilation_rates=[2,4,8])}[FLAGS.model]
+
+    inference = partial(inference_raw, resnet=FLAGS.resnet)
+
+    # network output scores and softmax probability maps
+    score = inference(batch_x, features_root=FLAGS.feature, keep_prob=keep_prob, n_class=num_classes)
+    probs = tf.nn.softmax(score)  # computed for completeness; not used below
+
+    # get losses
+    dice_cost = _get_cost(score, batch_y, cost_name='dice_multi')
+    tf.summary.scalar('dice_loss', dice_cost)  
+    dice_cost_noBG = _get_cost(score, batch_y, cost_name='dice_multi_noBG')
+    tf.summary.scalar('dice_loss_noBG', dice_cost_noBG)
+
+    cross_entropy = _get_cost(score, batch_y, cost_name='cross_entropy')
+    tf.summary.scalar('cross_entropy', cross_entropy) 
+
+    weighted_cross_entropy = _get_cost(score, batch_y, cost_name='cross_entropy', class_weights=class_weights)
+    tf.summary.scalar('weighted_cross_entropy',  weighted_cross_entropy)     
+
+    if FLAGS.loss == 'wce': # weighted cross entropy
+        cost = weighted_cross_entropy
+    elif FLAGS.loss == 'dice': # dice
+        cost = dice_cost
+    elif FLAGS.loss == 'ce': # cross entropy
+        cost = cross_entropy
+    else:
+        cost = dice_cost
+
+    # get accuracy
+    accuracy = _get_acc(score, batch_y)
+
+    # set optimizer with learning rate and decay rate
+    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+    with tf.control_dependencies(update_ops):
+        with tf.name_scope('rms_optimizer'):
+            optimizer = _get_optimizer(FLAGS.lr, global_step, decay_rate=FLAGS.dr)
+            optimizer_dice = _get_optimizer(FLAGS.lr, global_step, decay_rate=FLAGS.dr)
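+            # using a separate optimizer for each loss keeps independent
+            # RMSProp moving averages for the WCE phase and the Dice phase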
+        
+            grads = optimizer.compute_gradients(cost)
+            grads_dice = optimizer_dice.compute_gradients(dice_cost)
+
+            train = optimizer.apply_gradients(grads)
+            train_dice = optimizer_dice.apply_gradients(grads_dice)	  
+
+    # get merged summaries
+    merged = tf.summary.merge_all()
+
+    # get losses & acc for training
+    dice_cost_train = tf.placeholder(tf.float32, shape=[])
+    dice_loss_train_summary = tf.summary.scalar('dice_loss_train', dice_cost_train)    
+
+    cross_entropy_train = tf.placeholder(tf.float32, shape=[])
+    cross_entropy_train_summary = tf.summary.scalar('cross_entropy_train', cross_entropy_train) 
+
+    weighted_cross_entropy_train = tf.placeholder(tf.float32, shape=[])
+    weighted_cross_entropy_train_summary = tf.summary.scalar('weighted_cross_entropy_train',  weighted_cross_entropy_train)    
+
+    accuracy_train = tf.placeholder(tf.float32, shape=[])
+    accuracy_train_summary = tf.summary.scalar('accuracy_train',  accuracy_train)  
+
+    # get losses & acc for validation
+    dice_cost_val = tf.placeholder(tf.float32, shape=[])
+    dice_loss_val_summary = tf.summary.scalar('dice_loss_val', dice_cost_val)    
+
+    cross_entropy_val = tf.placeholder(tf.float32, shape=[])
+    cross_entropy_val_summary = tf.summary.scalar('cross_entropy_val', cross_entropy_val) 
+
+    weighted_cross_entropy_val = tf.placeholder(tf.float32, shape=[])
+    weighted_cross_entropy_val_summary = tf.summary.scalar('weighted_cross_entropy_val',  weighted_cross_entropy_val)   
+
+    accuracy_val = tf.placeholder(tf.float32, shape=[])
+    accuracy_val_summary = tf.summary.scalar('accuracy_val',  accuracy_val)  
+
+    # load data: gather image (.im) and segmentation (.seg) file lists
+    dataFolder = FLAGS.folder + '/train'
+    pathNifti = Path(dataFolder)
+
+    X = sorted(pathNifti.glob('**/*.im'))
+    y = sorted(pathNifti.glob('**/*.seg'))
+
+    pathNifti = Path(FLAGS.folder + '/valid')
+
+    X_v = sorted(pathNifti.glob('**/*.im'))
+    y_v = sorted(pathNifti.glob('**/*.seg'))
+
+    saver = tf.train.Saver(max_to_keep=0)
+
+    # load mri data and segmentation maps for training
+    if FLAGS.noImages == -1:
+        noOfFiles = len(X)
+    else:
+        noOfFiles = FLAGS.noImages
+    list_X = X[:noOfFiles]
+    list_y = y[:noOfFiles]
+
+    X_train, y_train, train_info = tfut.loadData_list_h5(list_X,list_y,num_channels)
+    print('Data loading is done')
+    X_train = tfut.zeroMeanUnitVariance(X_train)
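+    # per-class weights for the weighted cross entropy, computed from the
+    # training label distribution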
+    weights_cross_entropy = tfut.compute_weights_multiClass(y_train,num_classes)
+    del list_X, list_y
+
+    # load mri data and segmentation maps for validation
+    if FLAGS.noImages == -1:
+        noOfFiles = len(X_v)
+    else:
+        noOfFiles = FLAGS.noImages
+
+    list_X = X_v[:noOfFiles]
+    list_y = y_v[:noOfFiles]
+    X_val, y_val, val_info = tfut.loadData_list_h5(list_X, list_y,num_channels)
+    X_val = tfut.zeroMeanUnitVariance(X_val)
+    del list_X, list_y
+
+    X_train = X_train[...,np.newaxis]
+    X_val = X_val[...,np.newaxis]
+
+    # resize data if a resolution other than the native 384 is requested
+    if FLAGS.reso != 384:
+        input_size = X_train.shape[2]
+        X_train = tfut.batch_resize(X_train, input_size=input_size, output_size=FLAGS.reso, order=3)
+        y_train = tfut.batch_resize(y_train, input_size=input_size, output_size=FLAGS.reso, order=0)
+
+        X_val = tfut.batch_resize(X_val, input_size=input_size, output_size=FLAGS.reso, order=3)
+        y_val = tfut.batch_resize(y_val, input_size=input_size, output_size=FLAGS.reso, order=0)
+
+    sample_size = X_train.shape[0]
+    val_size = X_val.shape[0]
+
+    # the loss starts as weighted cross entropy (switchFlag = 1) and flips
+    # to the Dice loss once training accuracy passes FLAGS.switchAccuracy;
+    # these must be initialized even when early stopping is disabled
+    switchFlag = 1
+    best_acc = 0
+    wait = 0
+    patience = 500
+
+    config = tf.ConfigProto()
+    config.log_device_placement = False
+    config.allow_soft_placement = True
+
+    with tf.Session(config=config) as sess:
+        sess.run(tf.global_variables_initializer())
+
+        modelNo = 0
+        if FLAGS.restore:
+            ckpt = tf.train.get_checkpoint_state(FLAGS.output_path)
+            model_path = ckpt.model_checkpoint_path
+            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.output_path))
+            print('Model restored from file: %s' % model_path)
+            tmp = re.findall(r'\d+', model_path)
+            modelNo = int(tmp[-1])
+
+        train_writer = tf.summary.FileWriter(FLAGS.output_path, sess.graph)
+
+        start = time.time()  # time.clock() was removed in Python 3.8
+
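+        # dry run on one sample to discover the network output shape
+        # (valid-padding convolutions shrink the volume)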
+        prediction = sess.run(score, feed_dict={batch_x: X_train[0:1], 
+                                        batch_y: y_train[0:1],
+                                        global_step:0,
+                                        keep_prob:FLAGS.dropout,
+                                        class_weights:weights_cross_entropy})
+        pred_shape = prediction.shape
+
+        offset0 = (y_train.shape[1] - pred_shape[1]) // 2
+        offset1 = (y_train.shape[2] - pred_shape[2]) // 2
+        offset2 = (y_train.shape[3] - pred_shape[3]) // 2
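+        # nonzero offsets mean the output is smaller than the input (valid
+        # padding), so the labels must be center-cropped to match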
+
+        if offset0 == 0 and offset1 == 0 and offset2 == 0:
+            print('SAME padding')
+        else:
+            y_train = y_train[:, offset0:(-offset0), offset1:(-offset1),offset2:(-offset2),:]
+            y_val = y_val[:, offset0:(-offset0), offset1:(-offset1),offset2:(-offset2),:]
+
+        for epoch in range(modelNo+1, FLAGS.epoch+1):
+            print('train epoch', epoch, 'sample_size', sample_size) 
+
+            # shuffle data at the beginning of every epoch
+            shuffled_idx = np.random.permutation(sample_size)
+            wce_train, dice_train, ce_train, acc_train = [], [], [], []
+            for j in range(sample_size):
+                idx = shuffled_idx[j]
+                i = (epoch - 1) * sample_size + j + 1
+
+                # Whether to do left-right mirroring
+                step = np.random.choice([1,-1]) 
+
+                # pick the optimizer that matches the current loss phase
+                train_op = train if switchFlag else train_dice
+                _, loss, dice_loss, cross_entropy_loss, acc = sess.run(
+                    [train_op, weighted_cross_entropy, dice_cost, cross_entropy, accuracy],
+                    feed_dict={batch_x: X_train[idx:idx+1, :, :, ::step, :],
+                               batch_y: y_train[idx:idx+1, :, :, ::step, :],
+                               global_step: epoch-1,
+                               keep_prob: FLAGS.dropout,
+                               class_weights: weights_cross_entropy})
+
+                wce_train.append(loss)
+                dice_train.append(dice_loss)
+                ce_train.append(cross_entropy_loss)
+                acc_train.append(acc)
+
+            # switch to the Dice loss once cross-entropy training accuracy is good enough
+            if np.mean(acc_train) > switchAccuracy:
+                switchFlag = 0
+                print('@@@@ switch to Dice loss in epoch #:', epoch)
+
+            print('training weighted loss:', np.mean(wce_train), \
+                    ', cross entropy loss:', np.mean(ce_train), \
+                    ', dice loss:', np.mean(dice_train), \
+                    ', accuracy:', np.mean(acc_train))
+            summary = sess.run(weighted_cross_entropy_train_summary, feed_dict={weighted_cross_entropy_train:np.mean(wce_train)})
+            train_writer.add_summary(summary, epoch)
+            summary = sess.run(dice_loss_train_summary, feed_dict={dice_cost_train:np.mean(dice_train)})
+            train_writer.add_summary(summary, epoch)
+            summary = sess.run(cross_entropy_train_summary, feed_dict={cross_entropy_train:np.mean(ce_train)})
+            train_writer.add_summary(summary, epoch)
+            summary = sess.run(accuracy_train_summary , feed_dict={accuracy_train:np.mean(acc_train)})
+            train_writer.add_summary(summary, epoch)
+
+            if FLAGS.val:
+                summary = sess.run(merged, 
+                                    feed_dict={batch_x: X_train[:1],
+                                                batch_y: y_train[:1],
+                                                global_step:epoch-1,
+                                                keep_prob:1.0,
+                                                class_weights:weights_cross_entropy})
+                train_writer.add_summary(summary, epoch)
+
+                wce_val, dice_val, ce_val, acc_val = [], [], [], []
+                for j in range(val_size):
+                    loss, dice_loss, cross_entropy_loss, acc = sess.run([weighted_cross_entropy, dice_cost, cross_entropy, accuracy], 
+                                                                            feed_dict={batch_x: X_val[j:j+1], 
+                                                                                        batch_y: y_val[j:j+1],
+                                                                                        global_step:epoch-1,
+                                                                                        keep_prob:1.0,
+                                                                                        class_weights:weights_cross_entropy})
+                    wce_val.append(loss)
+                    dice_val.append(dice_loss)
+                    ce_val.append(cross_entropy_loss)
+                    acc_val.append(acc)
+
+                summary = sess.run(weighted_cross_entropy_val_summary, feed_dict={weighted_cross_entropy_val:np.mean(wce_val)})
+                train_writer.add_summary(summary, epoch)
+                summary = sess.run(dice_loss_val_summary, feed_dict={dice_cost_val:np.mean(dice_val)})
+                train_writer.add_summary(summary, epoch)
+                summary = sess.run(cross_entropy_val_summary, feed_dict={cross_entropy_val:np.mean(ce_val)})
+                train_writer.add_summary(summary, epoch)
+                summary = sess.run(accuracy_val_summary, feed_dict={accuracy_val:np.mean(acc_val)})
+                train_writer.add_summary(summary, epoch)  
+
+                print('validation weighted loss:', np.mean(wce_val), \
+                    ', cross entropy loss:', np.mean(ce_val), \
+                    ', dice loss:', np.mean(dice_val), \
+                    ', accuracy:', np.mean(acc_val))
+
+                # save the model whenever validation accuracy improves
+                acc = np.mean(acc_val)
+                if acc - 1e-18 > best_acc:
+                    best_acc, wait = acc, 0
+                    saver.save(sess, FLAGS.output_path+'/model')
+                    with open(FLAGS.output_path + '/SavedEpochNo.txt', 'w') as f:
+                        f.write(str(epoch))
+                else:
+                    saver.save(sess, FLAGS.output_path+'/model_lastEpoch')
+                    with open(FLAGS.output_path + '/SavedEpochNoLastEpoch.txt', 'w') as f:
+                        f.write(str(epoch))
+                    wait += 1
+                    if FLAGS.early_stopping and wait > patience:
+                        print('!!!! Early stopping on epoch %d !!!!' % epoch)
+                        break
+                print('!!!! BEST: %f, wait %d !!!!' % (best_acc, wait))
+
+
+if __name__ == '__main__':
+    tf.app.run()