|
a |
|
b/UNET/Code/LUNA_unet.py |
|
|
1 |
from __future__ import print_function |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import keras |
|
|
5 |
from keras.models import Model |
|
|
6 |
from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D |
|
|
7 |
from keras.optimizers import Adam |
|
|
8 |
from keras.optimizers import SGD |
|
|
9 |
from keras.callbacks import ModelCheckpoint, LearningRateScheduler |
|
|
10 |
from keras import backend as K |
|
|
11 |
from keras.layers import Dropout |
|
|
12 |
|
|
|
13 |
from sklearn.externals import joblib |
|
|
14 |
import argparse |
|
|
15 |
from keras.callbacks import * |
|
|
16 |
import sys |
|
|
17 |
import theano |
|
|
18 |
import theano.tensor as T |
|
|
19 |
from keras import initializations |
|
|
20 |
from keras.layers import BatchNormalization |
|
|
21 |
import copy |
|
|
22 |
# Select Theano ("th") image dimension ordering, i.e. (channels, rows, cols)
# per sample. get_unet_small builds Input((1, 512, 512)) and concatenates
# feature maps with concat_axis=1, both of which rely on this channels-first
# setting.
K.set_image_dim_ordering('th') # Theano dimension ordering in this code

'''
DEFAULT CONFIGURATIONS
'''
|
|
27 |
def get_options(): |
|
|
28 |
|
|
|
29 |
parser = argparse.ArgumentParser(description='UNET for Lung Nodule Detection') |
|
|
30 |
|
|
|
31 |
parser.add_argument('-out_dir', action="store", default='/scratch/cse/dual/cs5130287/Luna2016/output_final/', |
|
|
32 |
dest="out_dir", type=str) |
|
|
33 |
|
|
|
34 |
parser.add_argument('-epochs', action="store", default=500, dest="epochs", type=int) |
|
|
35 |
|
|
|
36 |
parser.add_argument('-batch_size', action="store", default=2, dest="batch_size", type=int) |
|
|
37 |
|
|
|
38 |
parser.add_argument('-lr', action="store", default=0.001, dest="lr", type=float) |
|
|
39 |
parser.add_argument('-load_weights', action="store", default=False, dest="load_weights", type=bool) |
|
|
40 |
parser.add_argument('-filter_width', action="store", default=3, dest="filter_width",type=int) |
|
|
41 |
parser.add_argument('-stride', action="store", default=3, dest="stride",type=int) |
|
|
42 |
parser.add_argument('-model_file', action="store", default="", dest="model_file",type=str) #TODO |
|
|
43 |
parser.add_argument('-save_prefix', action="store", default="model_", |
|
|
44 |
dest="save_prefix",type=str) |
|
|
45 |
opts = parser.parse_args(sys.argv[1:]) |
|
|
46 |
|
|
|
47 |
|
|
|
48 |
return opts |
|
|
49 |
|
|
|
50 |
|
|
|
51 |
|
|
|
52 |
def dice_coef(y_true, y_pred, smooth=0.):
    """Soft Dice coefficient between a target mask and a prediction.

    Both tensors are flattened before summing. The score is therefore pooled
    over the whole batch rather than averaged per sample.

    Args:
        y_true: ground-truth mask tensor (train() binarizes masks to 0/1).
        y_pred: predicted probabilities (the sigmoid output of the model).
        smooth: additive smoothing term applied to numerator and denominator.
            The default 0. reproduces the original hard-coded behaviour. With
            smooth == 0 the result is 0/0 (NaN) whenever y_true and y_pred
            both sum to zero; a positive value (e.g. 1.) avoids that.

    Returns:
        Scalar tensor (2 * sum(y_true * y_pred) + smooth)
        / (sum(y_true) + sum(y_pred) + smooth).
    """
    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)
    intersection = K.sum(y_true * y_pred)
    return (2. * intersection + smooth) / (K.sum(y_true) + K.sum(y_pred) + smooth)
|
|
60 |
|
|
|
61 |
|
|
|
62 |
|
|
|
63 |
def dice_coef_loss(y_true, y_pred):
    """Training loss: one minus the Dice coefficient (0 means perfect overlap)."""
    score = dice_coef(y_true, y_pred)
    return 1. - score
|
|
65 |
|
|
|
66 |
|
|
|
67 |
def gaussian_init(shape, name=None, dim_ordering=None):
    """Keras 1 weight initializer drawing from a zero-mean normal with scale 0.001.

    Thin wrapper over ``initializations.normal`` that pins ``scale`` (the
    standard deviation in Keras 1). get_unet_small does not use it.
    """
    normal = initializations.normal
    return normal(shape, name=name, dim_ordering=dim_ordering, scale=0.001)
|
|
69 |
|
|
|
70 |
def get_unet_small(options):
    """Build and compile a three-level U-Net for single-channel 512x512 inputs.

    Encoder: three levels of (conv -> dropout -> conv) with 32/64/128 filters,
    each followed by 2x2 max-pooling and batch normalization.
    Bottleneck: a 256-filter double conv with batch normalization.
    Decoder: three levels of (upsample + concat with the matching encoder
    output -> double conv -> batch normalization).
    Output: a 1x1 sigmoid convolution.

    All 3x3-style convolutions use ELU, 'same' padding, and kernel size
    (options.filter_width, options.stride). The model is compiled with Adam
    (lr=options.lr, clipvalue=1, clipnorm=1) and the Dice loss, and
    dice_coef is reported as a metric.

    Args:
        options: namespace from get_options(); reads filter_width, stride
            and lr.

    Returns:
        The compiled keras Model.
    """
    rows, cols = options.filter_width, options.stride

    def conv(nb_filter, tensor, name=None):
        # Single ELU convolution; only pass `name` when one is requested so
        # unnamed layers keep Keras' automatic naming.
        layer_kwargs = {'activation': 'elu', 'border_mode': 'same'}
        if name is not None:
            layer_kwargs['name'] = name
        return Convolution2D(nb_filter, rows, cols, **layer_kwargs)(tensor)

    def double_conv(tensor, nb_filter, name):
        # The repeated unit of every level: conv -> dropout(0.2) -> named conv.
        hidden = conv(nb_filter, tensor)
        hidden = Dropout(0.2)(hidden)
        return conv(nb_filter, hidden, name=name)

    def down(tensor, nb_filter, conv_name, pool_name):
        # Encoder level: returns (skip tensor, pooled + normalized tensor).
        skip = double_conv(tensor, nb_filter, conv_name)
        pooled = MaxPooling2D(pool_size=(2, 2), name=pool_name)(skip)
        pooled = BatchNormalization()(pooled)
        return skip, pooled

    def up(tensor, skip, nb_filter, conv_name):
        # Decoder level: 2x upsample, concatenate with the skip on the channel axis.
        joined = merge([UpSampling2D(size=(2, 2))(tensor), skip], mode='concat', concat_axis=1)
        hidden = double_conv(joined, nb_filter, conv_name)
        return BatchNormalization()(hidden)

    inputs = Input((1, 512, 512))

    skip1, level1 = down(inputs, 32, 'conv_1', 'pool_1')
    skip2, level2 = down(level1, 64, 'conv_2', 'pool_2')
    skip3, level3 = down(level2, 128, 'conv_3', 'pool_3')

    bottleneck = double_conv(level3, 256, 'conv_4')
    bottleneck = BatchNormalization()(bottleneck)

    decoded = up(bottleneck, skip3, 128, 'conv_7')
    decoded = up(decoded, skip2, 64, 'conv_8')
    decoded = up(decoded, skip1, 32, 'conv_9')

    probabilities = Convolution2D(1, 1, 1, activation='sigmoid', name='sigmoid')(decoded)

    model = Model(input=inputs, output=probabilities)
    model.summary()
    model.compile(optimizer=Adam(lr=options.lr, clipvalue=1., clipnorm=1.),
                  loss=dice_coef_loss, metrics=[dice_coef])
    return model
|
|
131 |
|
|
|
132 |
|
|
|
133 |
|
|
|
134 |
class WeightSave(Callback):
    """Keras callback that optionally restores weights and saves them every epoch.

    Weights are stored with joblib as the plain list returned by
    ``Model.get_weights()``, so they are tied to this exact architecture.
    """

    def __init__(self, options):
        # Initialise the Callback base so any attributes it defines exist
        # (the original skipped this call).
        super(WeightSave, self).__init__()
        self.options = options  # namespace from get_options()

    def on_train_begin(self, logs=None):
        """Load weights from options.model_file when options.load_weights is set."""
        if not self.options.load_weights:
            return
        if not self.options.model_file:
            raise ValueError('-load_weights requires -model_file to point at a .weights file')
        print('LOADING WEIGHTS FROM : ' + self.options.model_file)
        # NOTE: joblib.load unpickles the file; only load weight files you trust.
        weights = joblib.load(self.options.model_file)
        self.model.set_weights(weights)

    def on_epoch_end(self, epochs, logs=None):
        """Dump the current weights to a file whose name encodes epoch and hyper-parameters.

        Args:
            epochs: index of the epoch that just finished (passed by Keras).
            logs: metric dict from Keras (unused).
        """
        filename = (self.options.save_prefix
                    + '_script_on_epoch_' + str(epochs)
                    + '_lr_' + str(self.options.lr)
                    + '_WITH_STRIDES_' + str(self.options.stride)
                    + '_FILTER_WIDTH_' + str(self.options.filter_width)
                    + '.weights')
        joblib.dump(self.model.get_weights(), filename)
|
|
146 |
|
|
|
147 |
class Accuracy(Callback):
    """Keras callback that prints the Dice coefficient on held-out data after each epoch.

    Despite the class name, the reported number is dice_coef, not pixel accuracy.
    """

    def __init__(self, test_data_x, test_data_y):
        # Initialise the Callback base (the original skipped this call).
        super(Accuracy, self).__init__()
        self.test_data_x = test_data_x  # inputs fed to model.predict
        self.test_data_y = test_data_y  # ground-truth masks compared against predictions
        # Compile the Dice expression into a Theano function once, so each
        # epoch only pays for evaluation.
        test = T.tensor4('test')
        pred = T.tensor4('pred')
        self.dc = theano.function([test, pred], dice_coef(test, pred))

    def on_epoch_end(self, epochs, logs=None):
        """Predict on the held-out inputs and print their pooled Dice score."""
        predicted = self.model.predict(self.test_data_x)
        print("Validation : %f" % self.dc(self.test_data_y, predicted))
        # Flush so the score appears promptly in buffered (e.g. batch-job) logs.
        sys.stdout.flush()
|
|
159 |
|
|
|
160 |
def _load_images_and_masks(out_dir, split):
    """Load one data split and binarize its masks.

    Reads ``<split>Images.npy`` and ``<split>Masks.npy`` from ``out_dir``,
    casts both to float32 and sets every positive mask value to 1.0.

    Args:
        out_dir: directory containing the .npy files. A trailing slash is
            optional.
        split: file-name prefix, e.g. "train" or "test".

    Returns:
        (images, masks) tuple of float32 numpy arrays.
    """
    import os  # local import keeps this block self-contained

    images = np.load(os.path.join(out_dir, split + "Images.npy")).astype(np.float32)
    masks = np.load(os.path.join(out_dir, split + "Masks.npy")).astype(np.float32)
    # Renormalizing the masks: Dice is computed against binary {0, 1} targets.
    masks[masks > 0.] = 1.0
    return images, masks


def train(use_existing):
    """Parse options, load the LUNA train/test arrays, build the U-Net and fit it.

    Args:
        use_existing: currently unused. Weight restoring is controlled by
            the -load_weights / -model_file options, which WeightSave handles.

    Returns:
        The trained keras Model.
    """
    print("Loading the options ....")
    options = get_options()
    print("epochs: %d" % options.epochs)
    print("batch_size: %d" % options.batch_size)
    print("filter_width: %d" % options.filter_width)
    print("stride: %d" % options.stride)
    print("learning rate: %f" % options.lr)
    sys.stdout.flush()

    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)
    imgs_train, imgs_mask_train = _load_images_and_masks(options.out_dir, "train")
    # Now the Test Data
    imgs_test, imgs_mask_test_true = _load_images_and_masks(options.out_dir, "test")

    print('-' * 30)
    print('Creating and compiling model...')
    print('-' * 30)
    model = get_unet_small(options)
    weight_save = WeightSave(options)
    # No defensive copy needed: Accuracy only reads these arrays and this
    # function never modifies them afterwards. The deep copies doubled
    # test-set memory for nothing.
    accuracy = Accuracy(imgs_test, imgs_mask_test_true)

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)
    model.fit(x=imgs_train, y=imgs_mask_train, batch_size=options.batch_size,
              nb_epoch=options.epochs, verbose=1, shuffle=True,
              callbacks=[weight_save, accuracy])
    return model
|
|
199 |
|
|
|
200 |
if __name__ == '__main__':
    # Script entry point: train from scratch; CLI flags are parsed inside train().
    model = train(use_existing=False)