Diff of /2D/train_main_2d_patch.py [000000] .. [c9b969]

--- a
+++ b/2D/train_main_2d_patch.py
@@ -0,0 +1,296 @@
+
+from __future__ import print_function
+
+# import packages
+from functools import partial
+import os
+import numpy as np
+from keras.models import Model
+from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
+from keras.optimizers import Adam
+from keras import callbacks
+from keras import backend as K
+from keras.utils import plot_model
+
+# import data loading helpers
+from data_handling_2d_patch import load_train_data, load_validatation_data
+
+# import configurations
+import configs
+
+K.set_image_data_format('channels_last')  # TF dimension ordering in this code
+image_type = configs.IMAGE_TYPE
+
+# init configs
+image_rows = configs.VOLUME_ROWS
+image_cols = configs.VOLUME_COLS
+image_depth = configs.VOLUME_DEPS
+num_classes = configs.NUM_CLASSES
+
+# patch extraction and training parameters
+patch_size = configs.PATCH_SIZE
+BASE = configs.BASE
+smooth = configs.SMOOTH
+nb_epochs = configs.NUM_EPOCHS
+batch_size = configs.BATCH_SIZE
+unet_model_type = configs.MODEL
+PATIENCE = configs.PATIENCE
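+# BASE is the number of filters in the first encoder block; the U-Nets below double it
+# at every level. Note that the module-level `smooth` loaded from configs is shadowed
+# by the `smooth=1.` default argument of dice_coef() and is otherwise unused in this file.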
+
+# Dice similarity coefficient (DSC); the smooth term avoids division by zero on empty masks
+def dice_coef(y_true, y_pred, smooth=1.):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
+
+# multi-class Dice loss: sum over all class channels of (1 - DSC)
+def dice_coef_loss(y_true, y_pred):
+    distance = 0
+    for label_index in range(num_classes):
+        dice_coef_class = dice_coef(y_true[:, :, :, label_index], y_pred[:, :, :, label_index])
+        distance = distance + (1 - dice_coef_class)
+    return distance
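+# illustrative sanity check (not executed during training): a perfect prediction gives a
+# per-class DSC of 1, so dice_coef_loss is 0; a completely wrong prediction approaches
+# num_classes, e.g.
+#   y = K.constant(np.eye(num_classes)[np.newaxis, np.newaxis, :, :])
+#   K.eval(dice_coef_loss(y, y))   # ~0.0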
+
+# DSC for a single class channel
+def label_wise_dice_coefficient(y_true, y_pred, label_index):
+    return dice_coef(y_true[:, :, :, label_index], y_pred[:, :, :, label_index])
+
+# build a named per-label DSC metric so Keras reports it as 'label_<i>_dice_coef'
+def get_label_dice_coefficient_function(label_index):
+    f = partial(label_wise_dice_coefficient, label_index=label_index)
+    f.__name__ = 'label_{0}_dice_coef'.format(label_index)
+    return f
+
+# 2D U-net depth=5
+def get_unet_default():
+    metrics = dice_coef
+    include_label_wise_dice_coefficients = True
+    inputs = Input((patch_size, patch_size, 1))
+    conv1 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(inputs)
+    conv1 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(conv1)
+    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
+
+    conv2 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(pool1)
+    conv2 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(conv2)
+    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
+
+    conv3 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(pool2)
+    conv3 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(conv3)
+    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
+
+    conv4 = Conv2D(BASE*8, (3, 3), activation='relu', padding='same')(pool3)
+    conv4 = Conv2D(BASE*8, (3, 3), activation='relu', padding='same')(conv4)
+    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
+
+    conv5 = Conv2D(BASE*16, (3, 3), activation='relu', padding='same')(pool4)
+    conv5 = Conv2D(BASE*16, (3, 3), activation='relu', padding='same')(conv5)
+
+    up6 = concatenate([Conv2DTranspose(BASE*8, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
+    conv6 = Conv2D(BASE*8, (3, 3), activation='relu', padding='same')(up6)
+    conv6 = Conv2D(BASE*8, (3, 3), activation='relu', padding='same')(conv6)
+
+    up7 = concatenate([Conv2DTranspose(BASE*4, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
+    conv7 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(up7)
+    conv7 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(conv7)
+
+    up8 = concatenate([Conv2DTranspose(BASE*2, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
+    conv8 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(up8)
+    conv8 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(conv8)
+
+    up9 = concatenate([Conv2DTranspose(BASE, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
+    conv9 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(up9)
+    conv9 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(conv9)
+
+    conv10 = Conv2D(num_classes, (1, 1), activation='sigmoid')(conv9)
+
+    model = Model(inputs=[inputs], outputs=[conv10])
+    
+    if not isinstance(metrics, list):
+        metrics = [metrics]
+
+    if include_label_wise_dice_coefficients and num_classes > 1:
+        label_wise_dice_metrics = [get_label_dice_coefficient_function(index) for index in range(num_classes)]
+        if metrics:
+            metrics = metrics + label_wise_dice_metrics
+        else:
+            metrics = label_wise_dice_metrics
+            
+    model.compile(optimizer=Adam(lr=1e-4), loss=dice_coef_loss, metrics=metrics)
+
+    return model
+
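+# note: the four pooling/upsampling stages above require PATCH_SIZE to be divisible by 16
+# (by 8 for the reduced model below, by 32 for the extended one) so that each
+# transposed-convolution output lines up with its encoder feature map at the skip connection.
+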
+# 2D U-net depth=4
+def get_unet_reduced():
+    metrics = dice_coef
+    include_label_wise_dice_coefficients = True
+    inputs = Input((patch_size, patch_size, 1))
+    conv1 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(inputs)
+    conv1 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(conv1)
+    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
+
+    conv2 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(pool1)
+    conv2 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(conv2)
+    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
+
+    conv3 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(pool2)
+    conv3 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(conv3)
+    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
+
+    conv4 = Conv2D(BASE*8, (3, 3), activation='relu', padding='same')(pool3)
+    conv4 = Conv2D(BASE*8, (3, 3), activation='relu', padding='same')(conv4)
+
+
+    up7 = concatenate([Conv2DTranspose(BASE*4, (2, 2), strides=(2, 2), padding='same')(conv4), conv3], axis=3)
+    conv7 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(up7)
+    conv7 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(conv7)
+
+    up8 = concatenate([Conv2DTranspose(BASE*2, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
+    conv8 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(up8)
+    conv8 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(conv8)
+
+    up9 = concatenate([Conv2DTranspose(BASE, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
+    conv9 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(up9)
+    conv9 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(conv9)
+
+    conv10 = Conv2D(num_classes, (1, 1), activation='sigmoid')(conv9)
+
+    model = Model(inputs=[inputs], outputs=[conv10])
+    #model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
+    
+    if not isinstance(metrics, list):
+        metrics = [metrics]
+
+    if include_label_wise_dice_coefficients and num_classes > 1:
+        label_wise_dice_metrics = [get_label_dice_coefficient_function(index) for index in range(num_classes)]
+        if metrics:
+            metrics = metrics + label_wise_dice_metrics
+        else:
+            metrics = label_wise_dice_metrics
+
+    model.compile(optimizer=Adam(lr=1e-4), loss=dice_coef_loss, metrics=metrics)
+    return model
+
+# 2D U-net depth=6
+def get_unet_extended():
+    metrics = dice_coef
+    include_label_wise_dice_coefficients = True
+    inputs = Input((patch_size, patch_size, 1))
+    conv1 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(inputs)
+    conv1 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(conv1)
+    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
+
+    conv2 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(pool1)
+    conv2 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(conv2)
+    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
+
+    conv3 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(pool2)
+    conv3 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(conv3)
+    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
+
+    conv4 = Conv2D(BASE*8, (3, 3), activation='relu', padding='same')(pool3)
+    conv4 = Conv2D(BASE*8, (3, 3), activation='relu', padding='same')(conv4)
+    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
+
+    conv5 = Conv2D(BASE*16, (3, 3), activation='relu', padding='same')(pool4)
+    conv5 = Conv2D(BASE*16, (3, 3), activation='relu', padding='same')(conv5)
+    pool5 = MaxPooling2D(pool_size=(2, 2))(conv5)
+    
+    conv5_extend = Conv2D(BASE*32, (3, 3), activation='relu', padding='same')(pool5)
+    conv5_extend = Conv2D(BASE*32, (3, 3), activation='relu', padding='same')(conv5_extend)
+
+    up6_extend = concatenate([Conv2DTranspose(BASE*16, (2, 2), strides=(2, 2), padding='same')(conv5_extend), conv5], axis=3)
+    conv6_extend = Conv2D(BASE*16, (3, 3), activation='relu', padding='same')(up6_extend)
+    conv6_extend = Conv2D(BASE*16, (3, 3), activation='relu', padding='same')(conv6_extend)
+    
+    up6 = concatenate([Conv2DTranspose(BASE*8, (2, 2), strides=(2, 2), padding='same')(conv6_extend), conv4], axis=3)
+    conv6 = Conv2D(BASE*8, (3, 3), activation='relu', padding='same')(up6)
+    conv6 = Conv2D(BASE*8, (3, 3), activation='relu', padding='same')(conv6)
+
+    up7 = concatenate([Conv2DTranspose(BASE*4, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
+    conv7 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(up7)
+    conv7 = Conv2D(BASE*4, (3, 3), activation='relu', padding='same')(conv7)
+
+    up8 = concatenate([Conv2DTranspose(BASE*2, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
+    conv8 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(up8)
+    conv8 = Conv2D(BASE*2, (3, 3), activation='relu', padding='same')(conv8)
+
+    up9 = concatenate([Conv2DTranspose(BASE, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
+    conv9 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(up9)
+    conv9 = Conv2D(BASE, (3, 3), activation='relu', padding='same')(conv9)
+
+    conv10 = Conv2D(num_classes, (1, 1), activation='sigmoid')(conv9)
+
+    model = Model(inputs=[inputs], outputs=[conv10])
+    
+    if not isinstance(metrics, list):
+        metrics = [metrics]
+
+    if include_label_wise_dice_coefficients and num_classes > 1:
+        label_wise_dice_metrics = [get_label_dice_coefficient_function(index) for index in range(num_classes)]
+        if metrics:
+            metrics = metrics + label_wise_dice_metrics
+        else:
+            metrics = label_wise_dice_metrics
+            
+    model.compile(optimizer=Adam(lr=1e-4), loss=dice_coef_loss, metrics=metrics)
+    return model
+
+
+# train
+def train():
+    print('-'*30)
+    print('Loading and preprocessing train data...')
+    print('-'*30)
+    imgs_train, imgs_gtruth_train = load_train_data()
+    
+    print('-'*30)
+    print('Loading and preprocessing validation data...')
+    print('-'*30)   
+    imgs_val, imgs_gtruth_val  = load_validatation_data()
+      
+    print('-'*30)
+    print('Creating and compiling model...')
+    print('-'*30)
+    
+    if unet_model_type == 'default':
+        model = get_unet_default()
+    elif unet_model_type == 'reduced':
+        model = get_unet_reduced()
+    elif unet_model_type == 'extended':
+        model = get_unet_extended()
+    else:
+        raise ValueError('Unknown MODEL in configs: {0}'.format(unet_model_type))
+        
+    model.summary()        
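+    # optional (illustrative sketch, not required for training): the plot_model import at
+    # the top can render the chosen architecture to an image, provided pydot/graphviz are
+    # installed, e.g.
+    #   plot_model(model, to_file='outputs/' + image_type + '_model.png', show_shapes=True)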
+        
+    print('-'*30)
+    print('Fitting model...')
+    print('-'*30)
+    #============================================================================
+    print('training starting..')
+    log_filename = 'outputs/' + image_type +'_model_train.csv' 
+    # CSVLogger: streams per-epoch results to a CSV file
+    
+    csv_log = callbacks.CSVLogger(log_filename, separator=',', append=True)
+    
+    early_stopping = callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=PATIENCE, verbose=0, mode='min')
+    
+    #checkpoint_filepath = 'outputs/' + image_type +"_best_weight_model_{epoch:03d}_{val_loss:.4f}.hdf5"
+    checkpoint_filepath = 'outputs/' + 'weights.h5'
+    
+    checkpoint = callbacks.ModelCheckpoint(checkpoint_filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
+    
+    #callbacks_list = [csv_log, checkpoint]
+    callbacks_list = [csv_log, early_stopping, checkpoint]
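+    # EarlyStopping and ModelCheckpoint both monitor val_loss in 'min' mode: training stops
+    # after PATIENCE epochs without improvement, and only the best weights are written to
+    # outputs/weights.h5.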
+
+    #============================================================================
+    hist = model.fit(imgs_train, imgs_gtruth_train, batch_size=batch_size, epochs=nb_epochs, verbose=1,
+                     validation_data=(imgs_val, imgs_gtruth_val), shuffle=True, callbacks=callbacks_list)
+             
+    model_name = 'outputs/' + image_type + '_model_last'
+    model.save(model_name)  # saves the full model (architecture + weights + optimizer state) in HDF5 format
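+    # note: because the model is compiled with custom objects, reloading it later via
+    # keras.models.load_model requires either compile=False or a custom_objects dict
+    # containing dice_coef_loss, dice_coef and the per-label label_<i>_dice_coef metrics.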
+
+
+# main
+if __name__ == '__main__':
+    # folder to hold outputs
+    if 'outputs' not in os.listdir(os.curdir):
+        os.mkdir('outputs')   
+    train()