Diff of /UNET/Code/LUNA_unet.py [000000] .. [8a73bc]

--- a
+++ b/UNET/Code/LUNA_unet.py
@@ -0,0 +1,202 @@
+from __future__ import print_function
+
+import argparse
+import copy
+import sys
+
+import numpy as np
+import theano
+import theano.tensor as T
+
+from keras.models import Model
+from keras.layers import (Input, merge, Convolution2D, MaxPooling2D,
+                          UpSampling2D, Dropout, BatchNormalization)
+from keras.optimizers import Adam
+from keras.callbacks import Callback
+from keras import backend as K
+from keras import initializations
+
+from sklearn.externals import joblib
+
+K.set_image_dim_ordering('th')  # Theano dimension ordering (channels first)
+
+'''
+    DEFAULT CONFIGURATIONS
+'''
+def get_options():
+
+    parser = argparse.ArgumentParser(description='UNET for Lung Nodule Detection')
+
+    parser.add_argument('-out_dir', action="store", default='/scratch/cse/dual/cs5130287/Luna2016/output_final/',
+                        dest="out_dir", type=str)
+    parser.add_argument('-epochs', action="store", default=500, dest="epochs", type=int)
+    parser.add_argument('-batch_size', action="store", default=2, dest="batch_size", type=int)
+    parser.add_argument('-lr', action="store", default=0.001, dest="lr", type=float)
+    # NOTE: argparse's type=bool treats any non-empty string (including
+    # "False") as True, so load_weights is exposed as a flag instead.
+    parser.add_argument('-load_weights', action="store_true", default=False, dest="load_weights")
+    parser.add_argument('-filter_width', action="store", default=3, dest="filter_width", type=int)
+    # NOTE: despite its name, 'stride' is passed to Convolution2D as the
+    # filter's second spatial dimension (nb_col), not as a stride.
+    parser.add_argument('-stride', action="store", default=3, dest="stride", type=int)
+    parser.add_argument('-model_file', action="store", default="", dest="model_file", type=str) #TODO
+    parser.add_argument('-save_prefix', action="store", default="model_",
+                        dest="save_prefix", type=str)
+    opts = parser.parse_args(sys.argv[1:])
+
+    return opts
+
+
+
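+# Soft Dice coefficient on flattened masks:
+#   dice = 2*sum(y_true*y_pred) / (sum(y_true) + sum(y_pred))
+# Training minimizes 1 - dice (dice_coef_loss below).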
+def dice_coef(y_true, y_pred):
+    y_true = K.flatten(y_true)
+    y_pred = K.flatten(y_pred)
+    # smooth > 0 avoids a 0/0 division when both masks are empty;
+    # 1.0 is the value commonly used with this loss.
+    smooth = 1.
+    intersection = K.sum(y_true * y_pred)
+    return (2. * intersection + smooth) / (K.sum(y_true) + K.sum(y_pred) + smooth)
+
+
+
+def dice_coef_loss(y_true, y_pred):
+    return 1. - dice_coef(y_true, y_pred)
+
+
+def gaussian_init(shape, name=None, dim_ordering=None):
+    # Unused below; kept for reference. Keras 1.x initializations.normal
+    # takes (shape, scale, name), so dim_ordering is not forwarded.
+    return initializations.normal(shape, scale=0.001, name=name)
+
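+# A reduced U-Net (Keras 1.x API: Convolution2D(nb_filter, nb_row, nb_col),
+# merge(..., mode='concat')). Three 2x2 max-pool downsamplings take the
+# 512x512 input to 64x64, a 256-filter block sits at the bottom, and three
+# UpSampling2D steps restore resolution, concatenating encoder features
+# along the channel axis (concat_axis=1 under channels-first ordering).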
+def get_unet_small(options):
+    inputs = Input((1, 512, 512))
+    conv1 = Convolution2D(32, options.filter_width, options.stride, activation='elu',border_mode='same')(inputs)
+    conv1 = Dropout(0.2)(conv1)
+    conv1 = Convolution2D(32, options.filter_width, options.stride, activation='elu',border_mode='same', name='conv_1')(conv1)
+    pool1 = MaxPooling2D(pool_size=(2, 2), name='pool_1')(conv1)
+    pool1 = BatchNormalization()(pool1)
+
+    conv2 = Convolution2D(64, options.filter_width, options.stride, activation='elu',border_mode='same')(pool1)
+    conv2 = Dropout(0.2)(conv2)
+    conv2 = Convolution2D(64, options.filter_width, options.stride, activation='elu',border_mode='same', name='conv_2')(conv2)
+    pool2 = MaxPooling2D(pool_size=(2, 2), name='pool_2')(conv2)
+    pool2 = BatchNormalization()(pool2)
+
+    conv3 = Convolution2D(128, options.filter_width, options.stride, activation='elu',border_mode='same')(pool2)
+    conv3 = Dropout(0.2)(conv3)
+    conv3 = Convolution2D(128, options.filter_width, options.stride, activation='elu',border_mode='same', name='conv_3')(conv3)
+    pool3 = MaxPooling2D(pool_size=(2, 2), name='pool_3')(conv3)
+    pool3 = BatchNormalization()(pool3)
+
+    conv4 = Convolution2D(256, options.filter_width, options.stride, activation='elu',border_mode='same')(pool3)
+    conv4 = Dropout(0.2)(conv4)
+    conv4 = Convolution2D(256, options.filter_width, options.stride, activation='elu',border_mode='same', name='conv_4')(conv4)
+    conv4 = BatchNormalization()(conv4)
+    # pool4 = MaxPooling2D(pool_size=(2, 2), name='pool_4')(conv4)
+
+    # conv5 = Convolution2D(512, options.filter_width, options.stride, activation='elu',border_mode='same')(pool4)
+    # conv5 = Dropout(0.2)(conv5)
+    # conv5 = Convolution2D(512, options.filter_width, options.stride, activation='elu',border_mode='same', name='conv_5')(conv5)
+
+    # up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1)
+    # conv6 = Convolution2D(256, options.filter_width, options.stride, activation='elu',border_mode='same')(up6)
+    # conv6 = Dropout(0.2)(conv6)
+    # conv6 = Convolution2D(256, options.filter_width, options.stride, activation='elu',border_mode='same', name='conv_6')(conv6)
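+    # The fourth pooling level and 512-filter bottleneck (conv5/conv6/up6)
+    # are disabled above; this "small" variant stops at conv4 (256 filters).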
+
+    up7 = merge([UpSampling2D(size=(2, 2))(conv4), conv3], mode='concat', concat_axis=1)
+    conv7 = Convolution2D(128, options.filter_width, options.stride, activation='elu',border_mode='same')(up7)
+    conv7 = Dropout(0.2)(conv7)
+    conv7 = Convolution2D(128, options.filter_width, options.stride, activation='elu',border_mode='same', name='conv_7')(conv7)
+    conv7 = BatchNormalization()(conv7)
+
+    up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1)
+    conv8 = Convolution2D(64, options.filter_width, options.stride, activation='elu',border_mode='same')(up8)
+    conv8 = Dropout(0.2)(conv8)
+    conv8 = Convolution2D(64, options.filter_width, options.stride, activation='elu',border_mode='same', name='conv_8')(conv8)
+    conv8 = BatchNormalization()(conv8)
+
+    up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1)
+    conv9 = Convolution2D(32, options.filter_width, options.stride, activation='elu',border_mode='same')(up9)
+    conv9 = Dropout(0.2)(conv9)
+    conv9 = Convolution2D(32, options.filter_width, options.stride, activation='elu',border_mode='same', name='conv_9')(conv9)
+    conv9 = BatchNormalization()(conv9)
+
+    conv10 = Convolution2D(1, 1, 1, activation='sigmoid', name='sigmoid')(conv9)
+
+    model = Model(input=inputs, output=conv10)
+    model.summary()
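+    # Adam with both clipvalue and clipnorm set to guard against exploding
+    # gradients; the soft Dice loss defined above is optimized directly.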
+    model.compile(optimizer=Adam(lr=options.lr, clipvalue=1., clipnorm=1.), loss=dice_coef_loss, metrics=[dice_coef])
+
+    return model
+
+
+
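+# Callback that optionally restores weights from a joblib file before
+# training and dumps the full weight list after every epoch.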
+class WeightSave(Callback):
+    def __init__(self, options):
+        self.options = options
+
+    def on_train_begin(self, logs={}):
+        if self.options.load_weights:
+            print('LOADING WEIGHTS FROM : ' + self.options.model_file)
+            weights = joblib.load(self.options.model_file)
+            self.model.set_weights(weights)
+
+    def on_epoch_end(self, epoch, logs={}):
+        cur_weights = self.model.get_weights()
+        joblib.dump(cur_weights,
+                    self.options.save_prefix + '_script_on_epoch_' + str(epoch)
+                    + '_lr_' + str(self.options.lr)
+                    + '_WITH_STRIDES_' + str(self.options.stride)
+                    + '_FILTER_WIDTH_' + str(self.options.filter_width)
+                    + '.weights')
+
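+# Despite the name, this callback reports the Dice coefficient on the
+# held-out test arrays at the end of each epoch.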
+class Accuracy(Callback):
+    def __init__(self, test_data_x, test_data_y):
+        self.test_data_x = test_data_x
+        self.test_data_y = test_data_y
+        # Compile a standalone Theano function so the Dice score can be
+        # computed on numpy arrays outside the Keras graph.
+        test = T.tensor4('test')
+        pred = T.tensor4('pred')
+        dc = dice_coef(test, pred)
+        self.dc = theano.function([test, pred], dc)
+
+    def on_epoch_end(self, epoch, logs={}):
+        predicted = self.model.predict(self.test_data_x)
+        print("Validation dice: %f" % self.dc(self.test_data_y, predicted))
+
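+# Expects preprocessed LUNA16 slices saved as .npy files in out_dir
+# (trainImages/trainMasks/testImages/testMasks), shaped channels-first as
+# (N, 1, 512, 512) to match the model's Input((1, 512, 512)).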
+def train(use_existing):
+    # NOTE: use_existing is currently unused; weight loading is controlled
+    # by the -load_weights option instead.
+    print("Loading the options ....")
+    options = get_options()
+    print("epochs: %d" % options.epochs)
+    print("batch_size: %d" % options.batch_size)
+    print("filter_width: %d" % options.filter_width)
+    print("stride: %d" % options.stride)
+    print("learning rate: %f" % options.lr)
+    sys.stdout.flush()
+
+    print('-'*30)
+    print('Loading and preprocessing train data...')
+    print('-'*30)
+    imgs_train = np.load(options.out_dir+"trainImages.npy").astype(np.float32)
+    imgs_mask_train = np.load(options.out_dir+"trainMasks.npy").astype(np.float32)
+
+    # Renormalizing the masks
+    imgs_mask_train[imgs_mask_train > 0.] = 1.0
+    
+    # Now the Test Data
+    imgs_test = np.load(options.out_dir+"testImages.npy").astype(np.float32)
+    imgs_mask_test_true = np.load(options.out_dir+"testMasks.npy").astype(np.float32)
+    # Renormalizing the test masks
+    imgs_mask_test_true[imgs_mask_test_true > 0] = 1.0    
+
+    print('-'*30)
+    print('Creating and compiling model...')
+    print('-'*30)
+    model = get_unet_small(options)
+    weight_save = WeightSave(options)
+    accuracy = Accuracy(copy.deepcopy(imgs_test), copy.deepcopy(imgs_mask_test_true))
+    print('-'*30)
+    print('Fitting model...')
+    print('-'*30)
+    model.fit(x=imgs_train, y=imgs_mask_train, batch_size=options.batch_size,
+              nb_epoch=options.epochs, verbose=1, shuffle=True,
+              callbacks=[weight_save, accuracy])
+    return model
+
+if __name__ == '__main__':
+    model = train(False)
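+    # Example invocation (values illustrative):
+    #   python LUNA_unet.py -epochs 500 -batch_size 2 -lr 0.001 \
+    #                       -out_dir /path/to/preprocessed/arrays/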