# train.py

import tensorflow as tf
import DataGenerator
import time
# hyperparameters
learning_rate = 1e-5
batch_size = 2
prefetch = 4
no_of_epochs = 10000
smoothing = 0.00001  # smoothing term of the Dice coefficient

# placeholder that switches batch normalization between training and inference
is_training = tf.placeholder(tf.bool)

# input data generator: augmentation for training, plain normalization for validation
trainTransforms = [
    DataGenerator.RandomFlip(),
    DataGenerator.HistogramMatching(data_dir='train-data', train_size=40, prob=0.5),
    DataGenerator.RandomSmoothing(prob=0.3),
    DataGenerator.RandomNoise(prob=0.5),
    DataGenerator.Normalization()
]

valTransforms = [
    DataGenerator.Normalization()
]

TrainDataset = DataGenerator.DataGenerator(
    data_dir='train-data',
    transforms=trainTransforms,
    train=True
)

ValDataset = DataGenerator.DataGenerator(
    data_dir='val-data',
    transforms=valTransforms,
    train=False
)

trainDataset = TrainDataset.get_dataset()
trainDataset = trainDataset.shuffle(buffer_size=5)
trainDataset = trainDataset.batch(batch_size)
trainDataset = trainDataset.prefetch(prefetch)

valDataset = ValDataset.get_dataset()
valDataset = valDataset.shuffle(buffer_size=5)
valDataset = valDataset.batch(batch_size)
valDataset = valDataset.prefetch(prefetch)

iterator = tf.data.Iterator.from_structure(trainDataset.output_types, trainDataset.output_shapes)

training_init_op = iterator.make_initializer(trainDataset)
validation_init_op = iterator.make_initializer(valDataset)
next_item = iterator.get_next()
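
# Reinitializable-iterator pattern: running training_init_op points the shared
# iterator at the training set and validation_init_op points it at the
# validation set; each pass ends with tf.errors.OutOfRangeError, which the
# training loop below catches to switch between the two phases.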
# convolution layer
def conv3d(x, no_of_input_channels, no_of_filters, filter_size, strides, padding, name):
    with tf.variable_scope(name):

        initializer = tf.variance_scaling_initializer()

        # build the full filter shape without mutating the caller's list
        filter_shape = filter_size + [no_of_input_channels, no_of_filters]
        weights = tf.Variable(initializer(filter_shape), name='weights')
        biases = tf.Variable(initializer([no_of_filters]), name='biases')
        conv = tf.nn.conv3d(x, weights, strides=strides, padding=padding, name=name)
        conv += biases

        return conv
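
# Shape conventions for tf.nn.conv3d: filters are
# [depth, height, width, in_channels, out_channels] and strides are
# [batch, depth, height, width, channels], which is why the calls below pass
# five-element stride lists that start and end with 1.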
# transposed convolution layer
def upsamp(x, no_of_kernels, name):
    with tf.variable_scope(name):
        upsamp = tf.layers.conv3d_transpose(x, no_of_kernels, [2, 2, 2], 2, padding='VALID', use_bias=True, reuse=tf.AUTO_REUSE)
        return upsamp
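
# A [2, 2, 2] kernel with stride 2 and VALID padding exactly doubles every
# spatial dimension, mirroring the 2x2x2 stride-2 convolutions used for
# down-sampling in the encoder.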
# PReLU layer
def prelu(x, scope=None):
    with tf.variable_scope(name_or_scope=scope, default_name="prelu", reuse=tf.AUTO_REUSE):
        alpha = tf.get_variable("alpha", shape=x.get_shape()[-1], dtype=x.dtype, initializer=tf.constant_initializer(0.1))
        prelu_out = tf.maximum(0.0, x) + alpha * tf.minimum(0.0, x)
        return prelu_out
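
# PReLU: f(x) = max(0, x) + alpha * min(0, x), with one learnable slope alpha
# per channel (initialized to 0.1), so negative activations are scaled rather
# than zeroed as in a plain ReLU.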
# model graph: encoder with skip ("fine grained") feature outputs
def graph_encoder(x):

    fine_grained_features = {}

    conv1 = conv3d(x, 1, 16, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv1_1')
    conv1 = conv3d(conv1, 16, 16, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv1_2')
    conv1 = tf.layers.batch_normalization(conv1, training=is_training)
    conv1 = prelu(conv1, 'prelu1')

    # residual connection (the single input channel broadcasts over 16 filters)
    res1 = tf.add(x, conv1)
    fine_grained_features['res1'] = res1

    down1 = conv3d(res1, 16, 32, [2, 2, 2], [1, 2, 2, 2, 1], 'VALID', 'DownSampling1')
    down1 = prelu(down1, 'down_prelu1')

    conv2 = conv3d(down1, 32, 32, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv2_1')
    conv2 = conv3d(conv2, 32, 32, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv2_2')
    conv2 = tf.layers.batch_normalization(conv2, training=is_training)
    conv2 = prelu(conv2, 'prelu2')

    conv3 = conv3d(conv2, 32, 32, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv3_1')
    conv3 = conv3d(conv3, 32, 32, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv3_2')
    conv3 = tf.layers.batch_normalization(conv3, training=is_training)
    conv3 = prelu(conv3, 'prelu3')

    res2 = tf.add(down1, conv3)
    fine_grained_features['res2'] = res2

    down2 = conv3d(res2, 32, 64, [2, 2, 2], [1, 2, 2, 2, 1], 'VALID', 'DownSampling2')
    down2 = prelu(down2, 'down_prelu2')

    conv4 = conv3d(down2, 64, 64, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv4_1')
    conv4 = conv3d(conv4, 64, 64, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv4_2')
    conv4 = tf.layers.batch_normalization(conv4, training=is_training)
    conv4 = prelu(conv4, 'prelu4')

    conv5 = conv3d(conv4, 64, 64, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv5_1')
    conv5 = conv3d(conv5, 64, 64, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv5_2')
    conv5 = tf.layers.batch_normalization(conv5, training=is_training)
    conv5 = prelu(conv5, 'prelu5')

    conv6 = conv3d(conv5, 64, 64, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv6_1')
    conv6 = conv3d(conv6, 64, 64, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv6_2')
    conv6 = tf.layers.batch_normalization(conv6, training=is_training)
    conv6 = prelu(conv6, 'prelu6')

    res3 = tf.add(down2, conv6)
    fine_grained_features['res3'] = res3

    down3 = conv3d(res3, 64, 128, [2, 2, 2], [1, 2, 2, 2, 1], 'VALID', 'DownSampling3')
    down3 = prelu(down3, 'down_prelu3')

    conv7 = conv3d(down3, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv7_1')
    conv7 = conv3d(conv7, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv7_2')
    conv7 = tf.layers.batch_normalization(conv7, training=is_training)
    conv7 = prelu(conv7, 'prelu7')

    conv8 = conv3d(conv7, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv8_1')
    conv8 = conv3d(conv8, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv8_2')
    conv8 = tf.layers.batch_normalization(conv8, training=is_training)
    conv8 = prelu(conv8, 'prelu8')

    conv9 = conv3d(conv8, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv9_1')
    conv9 = conv3d(conv9, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv9_2')
    conv9 = tf.layers.batch_normalization(conv9, training=is_training)
    conv9 = prelu(conv9, 'prelu9')

    res4 = tf.add(down3, conv9)
    fine_grained_features['res4'] = res4

    down4 = conv3d(res4, 128, 256, [2, 2, 2], [1, 2, 2, 2, 1], 'VALID', 'DownSampling4')
    down4 = prelu(down4, 'down_prelu4')

    conv10 = conv3d(down4, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv10_1')
    conv10 = conv3d(conv10, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv10_2')
    conv10 = tf.layers.batch_normalization(conv10, training=is_training)
    conv10 = prelu(conv10, 'prelu10')

    conv11 = conv3d(conv10, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv11_1')
    conv11 = conv3d(conv11, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv11_2')
    conv11 = tf.layers.batch_normalization(conv11, training=is_training)
    conv11 = prelu(conv11, 'prelu11')

    conv12 = conv3d(conv11, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv12_1')
    conv12 = conv3d(conv12, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv12_2')
    conv12 = tf.layers.batch_normalization(conv12, training=is_training)
    conv12 = prelu(conv12, 'prelu12')

    res5 = tf.add(down4, conv12)
    fine_grained_features['res5'] = res5

    return fine_grained_features
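
# The encoder above and the decoder below (residual convolution stages, PReLU
# activations, 2x2x2 stride-2 down-convolutions, transposed-convolution
# up-sampling with skip connections, and a Dice objective) closely follow the
# V-Net architecture for volumetric segmentation (Milletari et al., 2016).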
def graph_decoder(features):

    inp = features['res5']

    upsamp1 = upsamp(inp, 128, 'Upsampling1')
    upsamp1 = prelu(upsamp1, 'prelu_upsamp1')

    # concatenate with the matching encoder feature map along the channel axis
    concat1 = tf.concat([upsamp1, features['res4']], axis=4)

    conv13 = conv3d(concat1, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv13_1')
    conv13 = conv3d(conv13, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv13_2')
    conv13 = tf.layers.batch_normalization(conv13, training=is_training)
    conv13 = prelu(conv13, 'prelu13')

    conv14 = conv3d(conv13, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv14_1')
    conv14 = conv3d(conv14, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv14_2')
    conv14 = tf.layers.batch_normalization(conv14, training=is_training)
    conv14 = prelu(conv14, 'prelu14')

    conv15 = conv3d(conv14, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv15_1')
    conv15 = conv3d(conv15, 256, 256, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv15_2')
    conv15 = tf.layers.batch_normalization(conv15, training=is_training)
    conv15 = prelu(conv15, 'prelu15')

    res6 = tf.add(concat1, conv15)

    upsamp2 = upsamp(res6, 64, 'Upsampling2')
    upsamp2 = prelu(upsamp2, 'prelu_upsamp2')

    concat2 = tf.concat([upsamp2, features['res3']], axis=4)

    conv16 = conv3d(concat2, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv16_1')
    conv16 = conv3d(conv16, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv16_2')
    conv16 = tf.layers.batch_normalization(conv16, training=is_training)
    conv16 = prelu(conv16, 'prelu16')

    conv17 = conv3d(conv16, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv17_1')
    conv17 = conv3d(conv17, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv17_2')
    conv17 = tf.layers.batch_normalization(conv17, training=is_training)
    conv17 = prelu(conv17, 'prelu17')

    conv18 = conv3d(conv17, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv18_1')
    conv18 = conv3d(conv18, 128, 128, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv18_2')
    conv18 = tf.layers.batch_normalization(conv18, training=is_training)
    conv18 = prelu(conv18, 'prelu18')

    res7 = tf.add(concat2, conv18)

    upsamp3 = upsamp(res7, 32, 'Upsampling3')
    upsamp3 = prelu(upsamp3, 'prelu_upsamp3')

    concat3 = tf.concat([upsamp3, features['res2']], axis=4)

    conv19 = conv3d(concat3, 64, 64, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv19_1')
    conv19 = conv3d(conv19, 64, 64, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv19_2')
    conv19 = tf.layers.batch_normalization(conv19, training=is_training)
    conv19 = prelu(conv19, 'prelu19')

    conv20 = conv3d(conv19, 64, 64, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv20_1')
    conv20 = conv3d(conv20, 64, 64, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv20_2')
    conv20 = tf.layers.batch_normalization(conv20, training=is_training)
    conv20 = prelu(conv20, 'prelu20')

    res8 = tf.add(concat3, conv20)

    upsamp4 = upsamp(res8, 16, 'Upsampling4')
    upsamp4 = prelu(upsamp4, 'prelu_upsamp4')

    concat4 = tf.concat([upsamp4, features['res1']], axis=4)

    conv21 = conv3d(concat4, 32, 32, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv21_1')
    conv21 = conv3d(conv21, 32, 32, [3, 3, 3], [1, 1, 1, 1, 1], 'SAME', 'Conv21_2')
    conv21 = tf.layers.batch_normalization(conv21, training=is_training)
    conv21 = prelu(conv21, 'prelu21')

    res9 = tf.add(concat4, conv21)

    # 1x1x1 convolution to a single channel, squashed to per-voxel probabilities
    conv22 = conv3d(res9, 32, 1, [1, 1, 1], [1, 1, 1, 1, 1], 'SAME', 'Conv22')
    conv22 = tf.nn.sigmoid(conv22, 'sigmoid')

    return conv22
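
# Each decoder stage doubles the spatial resolution, concatenates the
# fine-grained encoder features saved at the matching scale, and refines the
# result with a residual block, so the sigmoid output is a per-voxel
# foreground probability map at the full input resolution.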
# loss and optimizer
def model_fn():

    features, labels = next_item

    features = tf.reshape(features, [-1, 128, 128, 64, 1])
    labels = tf.cast(tf.reshape(labels, [-1, 128, 128, 64, 1]), dtype=tf.float32)

    # writing summaries to tensorboard (the central slice of each volume)
    tf.summary.image('features', features[:, :, :, 32:33, 0], max_outputs=2, collections=['val'])
    tf.summary.image('labels', labels[:, :, :, 32:33, 0], max_outputs=2, collections=['val'])

    labels = tf.reshape(labels, [-1, 128 * 128 * 64])

    encoded = graph_encoder(features)
    decoded = graph_decoder(encoded)

    decoded = tf.reshape(decoded, [-1, 128, 128, 64])
    tf.summary.image('segmentation', decoded[:, :, :, 32:33], max_outputs=2, collections=['val'])

    output = tf.reshape(decoded, [-1, 128 * 128 * 64])

    # smoothed Dice coefficient; the smoothing term in both numerator and
    # denominator keeps the ratio finite when a volume contains no foreground
    intersection = tf.reduce_sum(output * labels, axis=-1)
    cost = tf.reduce_mean((2.0 * intersection + smoothing) /
                          (tf.reduce_sum(output, axis=-1) + tf.reduce_sum(labels, axis=-1) + smoothing))
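    # The quantity above is the (soft) Dice coefficient:
    #     Dice = (2 * sum(p * g) + s) / (sum(p) + sum(g) + s)
    # for predictions p, ground-truth voxels g and smoothing term s. It equals 1
    # for perfect overlap, so training maximizes it; the optimizer below takes
    # gradients of 1 - Dice, the usual Dice loss.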
    tf.summary.scalar('training_loss', cost)
    tf.summary.scalar('val_loss', cost, collections=['val'])

    # run the batch-norm moving-average updates (UPDATE_OPS) with every step
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        opt = tf.train.AdamOptimizer(learning_rate)

        # descend on 1 - Dice, i.e. maximize the Dice coefficient
        grads = tf.gradients(1 - cost, tf.trainable_variables())
        grads = list(zip(grads, tf.trainable_variables()))

        training_operation = opt.apply_gradients(grads_and_vars=grads)

        for grad, var in grads:
            if grad is not None:
                tf.summary.histogram(var.name.replace(':', '_') + '/gradient', grad)
            tf.summary.histogram(var.name.replace(':', '_'), var)

    return cost, training_operation
# running the session
def train():
    with tf.Session() as sess:

        cost, train_op = model_fn()
        sess.run(tf.global_variables_initializer())

        # merging tensorflow summaries
        merged = tf.summary.merge_all()
        merged_val = tf.summary.merge_all(key='val')

        train_writer = tf.summary.FileWriter('event/train', sess.graph)
        val_writer = tf.summary.FileWriter('event/val')

        saver = tf.train.Saver()

        for epoch in range(1, no_of_epochs + 1):
            start_time = time.time()
            train_loss = []
            examples = 0

            # initializing the iterator with the training dataset
            sess.run([training_init_op])

            while True:
                try:
                    # one training step
                    examples += 1
                    loss, _, summary = sess.run([cost, train_op, merged], feed_dict={is_training: True})
                    train_writer.add_summary(summary, epoch)
                    train_loss.append(loss)
                    print('Epoch: {} - ex: {} - loss: {:.6f}'.format(epoch, examples * batch_size, sum(train_loss) / len(train_loss)), end="\r")
                except tf.errors.OutOfRangeError:
                    # training data exhausted: run a full validation pass
                    val_loss = []
                    val_example = 0

                    # initializing the iterator with the validation dataset
                    sess.run([validation_init_op])

                    while True:
                        try:
                            val_example += 1
                            loss, summary_l = sess.run([cost, merged_val], feed_dict={is_training: False})
                            val_writer.add_summary(summary_l, epoch)
                            val_loss.append(loss)
                            print('Epoch: {} - ex: {} - val_loss: {:.6f}'.format(epoch, val_example * batch_size, sum(val_loss) / len(val_loss)), end="\r")

                        except tf.errors.OutOfRangeError:
                            break
                    break

            print('Epoch: {}/{} - loss: {:.6f} - val_loss: {:.6f} - time: {:.4f}'.format(epoch, no_of_epochs,
                  sum(train_loss) / len(train_loss), sum(val_loss) / len(val_loss), time.time() - start_time))

            # saving weights every 20 epochs
            if epoch % 20 == 0:
                saver.save(sess, '/temp/weights_epoch_{0}.ckpt'.format(epoch))


if __name__ == '__main__':
    train()