a b/UNET/Code/LUNA_unet.py
1
from __future__ import print_function
2
3
import numpy as np
4
import keras
5
from keras.models import Model
6
from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D
7
from keras.optimizers import Adam
8
from keras.optimizers import SGD
9
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
10
from keras import backend as K
11
from keras.layers import Dropout
12
13
from sklearn.externals import joblib
14
import argparse
15
from keras.callbacks import *
16
import sys
17
import theano
18
import theano.tensor as T
19
from keras import initializations
20
from keras.layers import BatchNormalization
21
import copy
22
# All tensors in this file use Theano ('th') dimension ordering:
# channels live on axis 1.
K.set_image_dim_ordering('th')

# ---------------------------------------------------------------------------
# DEFAULT CONFIGURATIONS
# ---------------------------------------------------------------------------
27
def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    # '' stays False, matching what bool('') returned before.
    if lowered in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def get_options(argv=None):
    """Parse the command-line options for training.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, as the script always did.

    Returns:
        argparse.Namespace with attributes out_dir, epochs, batch_size, lr,
        load_weights, filter_width, stride, model_file and save_prefix.

    Note:
        ``stride`` is not a convolution stride: the model builder passes it
        as Convolution2D's column-count argument, so kernels are
        ``filter_width`` x ``stride``.
    """
    parser = argparse.ArgumentParser(description='UNET for Lung Nodule Detection')

    parser.add_argument('-out_dir', action="store", default='/scratch/cse/dual/cs5130287/Luna2016/output_final/',
                        dest="out_dir", type=str)
    parser.add_argument('-epochs', action="store", default=500, dest="epochs", type=int)
    parser.add_argument('-batch_size', action="store", default=2, dest="batch_size", type=int)
    parser.add_argument('-lr', action="store", default=0.001, dest="lr", type=float)
    # Explicit parser: type=bool would turn "-load_weights False" into True.
    parser.add_argument('-load_weights', action="store", default=False, dest="load_weights", type=_str2bool)
    parser.add_argument('-filter_width', action="store", default=3, dest="filter_width", type=int)
    parser.add_argument('-stride', action="store", default=3, dest="stride", type=int)
    parser.add_argument('-model_file', action="store", default="", dest="model_file", type=str)  # TODO
    parser.add_argument('-save_prefix', action="store", default="model_",
                        dest="save_prefix", type=str)

    if argv is None:
        argv = sys.argv[1:]
    return parser.parse_args(argv)
49
50
51
52
def dice_coef(y_true, y_pred, smooth=0.):
    """Dice similarity coefficient between two symbolic tensors.

    Both tensors are flattened first, so the score is computed over the whole
    batch at once rather than averaged per sample.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor with the same number of elements.
        smooth: constant added to both numerator and denominator. The default
            0. keeps the original behaviour; with 0. the result is NaN (0/0)
            when both ``y_true`` and ``y_pred`` sum to zero, so pass a
            positive value (e.g. 1.) if empty masks can occur.

    Returns:
        Scalar tensor ``(2 * sum(y_true * y_pred) + smooth) /
        (sum(y_true) + sum(y_pred) + smooth)``.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
60
61
62
63
def dice_coef_loss(y_true, y_pred):
    """Training loss: one minus the Dice coefficient of the two tensors."""
    score = dice_coef(y_true, y_pred)
    return 1. - score
65
66
67
def gaussian_init(shape, name=None, dim_ordering=None):
    """Normal-distribution weight initialiser with scale 0.001.

    Thin wrapper around keras ``initializations.normal`` that fixes the
    scale and forwards ``shape``, ``name`` and ``dim_ordering`` unchanged.
    Note: get_unet_small does not pass this as an ``init``.
    """
    return initializations.normal(
        shape,
        scale=0.001,
        name=name,
        dim_ordering=dim_ordering,
    )
69
70
def _conv_block(x, nb_filter, options, name):
    """Apply conv -> Dropout(0.2) -> conv; the second conv layer is named ``name``.

    Both convolutions have ``nb_filter`` filters, ELU activation and 'same'
    padding, with kernels of ``options.filter_width`` rows by
    ``options.stride`` columns.
    """
    x = Convolution2D(nb_filter, options.filter_width, options.stride,
                      activation='elu', border_mode='same')(x)
    x = Dropout(0.2)(x)
    return Convolution2D(nb_filter, options.filter_width, options.stride,
                         activation='elu', border_mode='same', name=name)(x)


def _upsample_concat(x, skip):
    """Upsample ``x`` 2x and concatenate it with ``skip`` along axis 1 (channels)."""
    return merge([UpSampling2D(size=(2, 2))(x), skip], mode='concat', concat_axis=1)


def get_unet_small(options, input_shape=(1, 512, 512)):
    """Build and compile the reduced-depth U-Net.

    Architecture:
        * Encoder: three levels (32/64/128 filters).
        * Bottleneck: 256 filters.
        * Decoder: three levels with skip connections to the encoder outputs.
        * Output: a 1x1 sigmoid convolution.

    Every conv pair is conv -> Dropout(0.2) -> conv with ELU activation and
    'same' padding.

    Args:
        options: parsed options; reads ``filter_width``, ``stride`` and ``lr``.
            Kernels are ``filter_width`` rows by ``stride`` columns:
            ``stride`` is passed as Convolution2D's column-count argument,
            not as an actual stride.
        input_shape: shape of one input sample, channels first. Defaults to
            the original (1, 512, 512). Both spatial dimensions must be
            divisible by 8 so the three 2x2 poolings and the matching
            upsamplings produce shapes that can be concatenated.

    Returns:
        A compiled keras Model (Adam with value/norm clipping, Dice loss,
        Dice metric). The model summary is printed as a side effect.
    """
    inputs = Input(input_shape)

    # Encoder: conv block, then 2x2 max-pooling and batch normalisation.
    conv1 = _conv_block(inputs, 32, options, 'conv_1')
    pool1 = MaxPooling2D(pool_size=(2, 2), name='pool_1')(conv1)
    pool1 = BatchNormalization()(pool1)

    conv2 = _conv_block(pool1, 64, options, 'conv_2')
    pool2 = MaxPooling2D(pool_size=(2, 2), name='pool_2')(conv2)
    pool2 = BatchNormalization()(pool2)

    conv3 = _conv_block(pool2, 128, options, 'conv_3')
    pool3 = MaxPooling2D(pool_size=(2, 2), name='pool_3')(conv3)
    pool3 = BatchNormalization()(pool3)

    # Bottleneck. This "small" variant omits the full U-Net's fourth pooling
    # level and its 512-filter stage (conv_5 / conv_6).
    conv4 = _conv_block(pool3, 256, options, 'conv_4')
    conv4 = BatchNormalization()(conv4)

    # Decoder: upsample, concatenate with the matching encoder output,
    # conv block, batch normalisation.
    conv7 = _conv_block(_upsample_concat(conv4, conv3), 128, options, 'conv_7')
    conv7 = BatchNormalization()(conv7)

    conv8 = _conv_block(_upsample_concat(conv7, conv2), 64, options, 'conv_8')
    conv8 = BatchNormalization()(conv8)

    conv9 = _conv_block(_upsample_concat(conv8, conv1), 32, options, 'conv_9')
    conv9 = BatchNormalization()(conv9)

    # Per-pixel probability map.
    conv10 = Convolution2D(1, 1, 1, activation='sigmoid', name='sigmoid')(conv9)

    model = Model(input=inputs, output=conv10)
    model.summary()
    model.compile(optimizer=Adam(lr=options.lr, clipvalue=1., clipnorm=1.),
                  loss=dice_coef_loss, metrics=[dice_coef])
    return model
131
132
133
134
class WeightSave(Callback):
    """Keras callback that restores and checkpoints model weights with joblib.

    At train begin it restores weights from ``options.model_file``, but only
    when ``options.load_weights`` is set. At every epoch end it dumps
    ``model.get_weights()`` to a file.

    The dump file name is built from ``options.save_prefix``, the epoch index,
    and the lr, stride and filter_width options. It is relative to the
    current working directory unless ``save_prefix`` contains a directory.
    """

    def __init__(self, options):
        super(WeightSave, self).__init__()
        self.options = options

    def _weights_path(self, epoch):
        """Return the dump file name for the given epoch index."""
        opts = self.options
        return (opts.save_prefix + '_script_on_epoch_' + str(epoch)
                + '_lr_' + str(opts.lr)
                + '_WITH_STRIDES_' + str(opts.stride)
                + '_FILTER_WIDTH_' + str(opts.filter_width) + '.weights')

    def on_train_begin(self, logs=None):
        if self.options.load_weights:
            print('LOADING WEIGHTS FROM : ' + self.options.model_file)
            # joblib.load unpickles the file: only load checkpoints you trust.
            weights = joblib.load(self.options.model_file)
            self.model.set_weights(weights)

    def on_epoch_end(self, epochs, logs=None):
        cur_weights = self.model.get_weights()
        joblib.dump(cur_weights, self._weights_path(epochs))
146
147
class Accuracy(Callback):
    """Keras callback that prints the Dice coefficient on a held-out set each epoch.

    Despite the class name, the printed value is ``dice_coef`` between
    ``test_data_y`` and the model's predictions for ``test_data_x``. It is
    computed over the whole held-out set at once, not as pixel accuracy.
    """

    def __init__(self, test_data_x, test_data_y):
        super(Accuracy, self).__init__()
        self.test_data_x = test_data_x
        self.test_data_y = test_data_y
        # Compile dice_coef once into a Theano function taking
        # (ground truth, prediction) 4-D arrays, so each epoch only runs it.
        truth = T.tensor4('test')
        prediction = T.tensor4('pred')
        score = dice_coef(truth, prediction)
        self.dc = theano.function([truth, prediction], score)

    def on_epoch_end(self, epochs, logs=None):
        predicted = self.model.predict(self.test_data_x)
        print("Validation : %f" % self.dc(self.test_data_y, predicted))
159
160
def _load_images_and_masks(out_dir, split):
    """Load the ``<split>Images.npy`` / ``<split>Masks.npy`` pair from ``out_dir``.

    Both arrays are cast to float32. The masks are renormalised: every
    strictly positive value is set to 1.0, and other values are left as is.

    Returns:
        (images, masks) numpy arrays.
    """
    import os  # local import keeps this block self-contained

    # os.path.join works whether or not out_dir ends with a separator.
    images = np.load(os.path.join(out_dir, split + "Images.npy")).astype(np.float32)
    masks = np.load(os.path.join(out_dir, split + "Masks.npy")).astype(np.float32)
    masks[masks > 0.] = 1.0
    return images, masks


def train(use_existing):
    """Load the preprocessed arrays, build the small U-Net and fit it.

    Options come from the command line via get_options. Train and test
    arrays are read from ``options.out_dir``. The test arrays are used only
    by the Accuracy callback, which prints the held-out Dice score after
    each epoch. WeightSave restores and checkpoints weights.

    Args:
        use_existing: unused; kept for call compatibility. Weight restoring is
            controlled by the ``-load_weights`` / ``-model_file`` options.

    Returns:
        The trained keras Model.
    """
    print("Loading the options ....")
    options = get_options()
    print("epochs: %d" % options.epochs)
    print("batch_size: %d" % options.batch_size)
    print("filter_width: %d" % options.filter_width)
    print("stride: %d" % options.stride)
    print("learning rate: %f" % options.lr)
    sys.stdout.flush()

    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)
    imgs_train, imgs_mask_train = _load_images_and_masks(options.out_dir, "train")
    imgs_test, imgs_mask_test_true = _load_images_and_masks(options.out_dir, "test")

    print('-' * 30)
    print('Creating and compiling model...')
    print('-' * 30)
    model = get_unet_small(options)
    weight_save = WeightSave(options)
    # Accuracy only reads these arrays (predict + Dice), so pass them as-is;
    # a deep copy would duplicate the whole held-out set in memory.
    accuracy = Accuracy(imgs_test, imgs_mask_test_true)

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)
    model.fit(x=imgs_train, y=imgs_mask_train, batch_size=options.batch_size,
              nb_epoch=options.epochs, verbose=1, shuffle=True,
              callbacks=[weight_save, accuracy])
    return model
199
200
if __name__ == '__main__':
    # Script entry point: train with options taken from the command line.
    model = train(use_existing=False)