# train.py
import tensorflow as tf
import DataGenerator
import time

learning_rate = 1e-5
batch_size = 2
prefetch = 4
no_of_epochs = 10000
smoothing = 0.00001

# placeholder for training mode: switches batch normalization between training
# and inference behaviour; fed through feed_dict in the training loop below
is_training = tf.placeholder(tf.bool)

# input data generator
trainTransforms = [
    DataGenerator.RandomFlip(),
    DataGenerator.HistogramMatching(data_dir='train-data', train_size=40, prob=0.5),
    DataGenerator.RandomSmoothing(prob=0.3),
    DataGenerator.RandomNoise(prob=0.5),
    DataGenerator.Normalization()
    ]

valTransforms = [
    DataGenerator.Normalization()
    ]

TrainDataset = DataGenerator.DataGenerator(
    data_dir='train-data',
    transforms=trainTransforms,
    train=True
    )

ValDataset = DataGenerator.DataGenerator(
    data_dir='val-data',
    transforms=valTransforms,
    train=False
    )
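
# NOTE: DataGenerator is a project-local module (not shown here). The transform
# objects above are assumed to be applied per sample when the dataset is built,
# with `prob` giving the probability that each augmentation is applied.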

trainDataset = TrainDataset.get_dataset()
trainDataset = trainDataset.shuffle(buffer_size=5)
trainDataset = trainDataset.batch(batch_size)
trainDataset = trainDataset.prefetch(prefetch)

valDataset = ValDataset.get_dataset()
valDataset = valDataset.shuffle(buffer_size=5)
valDataset = valDataset.batch(batch_size)
valDataset = valDataset.prefetch(prefetch)

iterator = tf.data.Iterator.from_structure(trainDataset.output_types, trainDataset.output_shapes)

training_init_op = iterator.make_initializer(trainDataset)
validation_init_op = iterator.make_initializer(valDataset)
next_item = iterator.get_next()
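
# The reinitializable iterator above is shared by both datasets: running
# `training_init_op` or `validation_init_op` points `next_item` at the
# corresponding dataset. A minimal sketch of the pattern (assuming an open
# session `sess`):
#
#   sess.run(training_init_op)
#   features, labels = sess.run(next_item)    # next training batch
#   sess.run(validation_init_op)
#   features, labels = sess.run(next_item)    # next validation batch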

# convolution layer
def conv3d(x, no_of_input_channels, no_of_filters, filter_size, strides, padding, name):
    with tf.variable_scope(name) as scope:

        initializer = tf.variance_scaling_initializer()

        # build the 5-D kernel shape [depth, height, width, in_channels, out_channels]
        filter_size.extend([no_of_input_channels, no_of_filters])
        weights = tf.Variable(initializer(filter_size), name='weights')
        biases = tf.Variable(initializer([no_of_filters]), name='biases')
        conv = tf.nn.conv3d(x, weights, strides=strides, padding=padding, name=name)
        conv += biases

        return conv
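
# For example, conv3d(x, 16, 32, [3,3,3], [1,1,1,1,1], 'SAME', 'Conv') creates a
# weight tensor of shape [3, 3, 3, 16, 32] plus 32 biases and returns a feature map
# with 32 channels; the [2,2,2] kernels with strides [1,2,2,2,1] and 'VALID' padding
# used below halve each spatial dimension.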

# transposed convolution layer
def upsamp(x, no_of_kernels, name):
    with tf.variable_scope(name) as scope:
        # a 2x2x2 kernel with stride 2 and 'VALID' padding doubles each spatial dimension
        upsamp = tf.layers.conv3d_transpose(x, no_of_kernels, [2,2,2], 2, padding='VALID', use_bias=True, reuse=tf.AUTO_REUSE)
        return upsamp

# PReLU layer
def prelu(x, scope=None):
    with tf.variable_scope(name_or_scope=scope, default_name="prelu", reuse=tf.AUTO_REUSE):
        # one learnable slope per channel: PReLU(x) = max(0, x) + alpha * min(0, x)
        alpha = tf.get_variable("prelu", shape=x.get_shape()[-1], dtype=x.dtype, initializer=tf.constant_initializer(0.1))
        prelu_out = tf.maximum(0.0, x) + alpha * tf.minimum(0.0, x)
        return prelu_out

# model graph: encoder. Each resolution level ends in a residual block whose
# output is stored in fine_grained_features and reused by the decoder as a
# skip connection.
def graph_encoder(x):

    fine_grained_features = {}

    conv1 = conv3d(x,1,16,[3,3,3],[1,1,1,1,1],'SAME','Conv1_1')
    conv1 = conv3d(conv1,16,16,[3,3,3],[1,1,1,1,1],'SAME','Conv1_2')
    conv1 = tf.layers.batch_normalization(conv1, training=is_training)
    conv1 = prelu(conv1,'prelu1')

    # the single-channel input is broadcast across the 16 feature channels
    res1 = tf.add(x,conv1)
    fine_grained_features['res1'] = res1

    down1 = conv3d(res1,16,32,[2,2,2],[1,2,2,2,1],'VALID','DownSampling1')
    down1 = prelu(down1,'down_prelu1')

    conv2 = conv3d(down1,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv2_1')
    conv2 = conv3d(conv2,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv2_2')
    conv2 = tf.layers.batch_normalization(conv2, training=is_training)
    conv2 = prelu(conv2,'prelu2')

    conv3 = conv3d(conv2,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv3_1')
    conv3 = conv3d(conv3,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv3_2')
    conv3 = tf.layers.batch_normalization(conv3, training=is_training)
    conv3 = prelu(conv3,'prelu3')

    res2 = tf.add(down1,conv3)
    fine_grained_features['res2'] = res2

    down2 = conv3d(res2,32,64,[2,2,2],[1,2,2,2,1],'VALID','DownSampling2')
    down2 = prelu(down2,'down_prelu2')

    conv4 = conv3d(down2,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv4_1')
    conv4 = conv3d(conv4,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv4_2')
    conv4 = tf.layers.batch_normalization(conv4, training=is_training)
    conv4 = prelu(conv4,'prelu4')

    conv5 = conv3d(conv4,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv5_1')
    conv5 = conv3d(conv5,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv5_2')
    conv5 = tf.layers.batch_normalization(conv5, training=is_training)
    conv5 = prelu(conv5,'prelu5')

    conv6 = conv3d(conv5,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv6_1')
    conv6 = conv3d(conv6,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv6_2')
    conv6 = tf.layers.batch_normalization(conv6, training=is_training)
    conv6 = prelu(conv6,'prelu6')

    res3 = tf.add(down2,conv6)
    fine_grained_features['res3'] = res3

    down3 = conv3d(res3,64,128,[2,2,2],[1,2,2,2,1],'VALID','DownSampling3')
    down3 = prelu(down3,'down_prelu3')

    conv7 = conv3d(down3,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv7_1')
    conv7 = conv3d(conv7,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv7_2')
    conv7 = tf.layers.batch_normalization(conv7, training=is_training)
    conv7 = prelu(conv7,'prelu7')

    conv8 = conv3d(conv7,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv8_1')
    conv8 = conv3d(conv8,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv8_2')
    conv8 = tf.layers.batch_normalization(conv8, training=is_training)
    conv8 = prelu(conv8,'prelu8')

    conv9 = conv3d(conv8,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv9_1')
    conv9 = conv3d(conv9,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv9_2')
    conv9 = tf.layers.batch_normalization(conv9, training=is_training)
    conv9 = prelu(conv9,'prelu9')

    res4 = tf.add(down3,conv9)
    fine_grained_features['res4'] = res4

    down4 = conv3d(res4,128,256,[2,2,2],[1,2,2,2,1],'VALID','DownSampling4')
    down4 = prelu(down4,'down_prelu4')

    conv10 = conv3d(down4,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv10_1')
    conv10 = conv3d(conv10,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv10_2')
    conv10 = tf.layers.batch_normalization(conv10, training=is_training)
    conv10 = prelu(conv10,'prelu10')

    conv11 = conv3d(conv10,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv11_1')
    conv11 = conv3d(conv11,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv11_2')
    conv11 = tf.layers.batch_normalization(conv11, training=is_training)
    conv11 = prelu(conv11,'prelu11')

    conv12 = conv3d(conv11,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv12_1')
    conv12 = conv3d(conv12,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv12_2')
    conv12 = tf.layers.batch_normalization(conv12, training=is_training)
    conv12 = prelu(conv12,'prelu12')

    res5 = tf.add(down4,conv12)
    fine_grained_features['res5'] = res5

    return fine_grained_features
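
# With the [-1, 128, 128, 64, 1] inputs used in model_fn below, the stored skip
# connections have the following shapes (N = batch size):
#   res1: [N, 128, 128, 64,  16]
#   res2: [N,  64,  64, 32,  32]
#   res3: [N,  32,  32, 16,  64]
#   res4: [N,  16,  16,  8, 128]
#   res5: [N,   8,   8,  4, 256]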

def graph_decoder(features):

    inp = features['res5']

    upsamp1 = upsamp(inp,128,'Upsampling1')
    upsamp1 = prelu(upsamp1,'prelu_upsamp1')

    concat1 = tf.concat([upsamp1,features['res4']],axis=4)

    conv13 = conv3d(concat1,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv13_1')
    conv13 = conv3d(conv13,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv13_2')
    conv13 = tf.layers.batch_normalization(conv13, training=is_training)
    conv13 = prelu(conv13,'prelu13')

    conv14 = conv3d(conv13,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv14_1')
    conv14 = conv3d(conv14,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv14_2')
    conv14 = tf.layers.batch_normalization(conv14, training=is_training)
    conv14 = prelu(conv14,'prelu14')

    conv15 = conv3d(conv14,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv15_1')
    conv15 = conv3d(conv15,256,256,[3,3,3],[1,1,1,1,1],'SAME','Conv15_2')
    conv15 = tf.layers.batch_normalization(conv15, training=is_training)
    conv15 = prelu(conv15,'prelu15')

    res6 = tf.add(concat1,conv15)

    upsamp2 = upsamp(res6,64,'Upsampling2')
    upsamp2 = prelu(upsamp2,'prelu_upsamp2')

    concat2 = tf.concat([upsamp2,features['res3']],axis=4)

    conv16 = conv3d(concat2,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv16_1')
    conv16 = conv3d(conv16,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv16_2')
    conv16 = tf.layers.batch_normalization(conv16, training=is_training)
    conv16 = prelu(conv16,'prelu16')

    conv17 = conv3d(conv16,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv17_1')
    conv17 = conv3d(conv17,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv17_2')
    conv17 = tf.layers.batch_normalization(conv17, training=is_training)
    conv17 = prelu(conv17,'prelu17')

    conv18 = conv3d(conv17,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv18_1')
    conv18 = conv3d(conv18,128,128,[3,3,3],[1,1,1,1,1],'SAME','Conv18_2')
    conv18 = tf.layers.batch_normalization(conv18, training=is_training)
    conv18 = prelu(conv18,'prelu18')

    res7 = tf.add(concat2,conv18)

    upsamp3 = upsamp(res7,32,'Upsampling3')
    upsamp3 = prelu(upsamp3,'prelu_upsamp3')

    concat3 = tf.concat([upsamp3,features['res2']],axis=4)

    conv19 = conv3d(concat3,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv19_1')
    conv19 = conv3d(conv19,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv19_2')
    conv19 = tf.layers.batch_normalization(conv19, training=is_training)
    conv19 = prelu(conv19,'prelu19')

    conv20 = conv3d(conv19,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv20_1')
    conv20 = conv3d(conv20,64,64,[3,3,3],[1,1,1,1,1],'SAME','Conv20_2')
    conv20 = tf.layers.batch_normalization(conv20, training=is_training)
    conv20 = prelu(conv20,'prelu20')

    res8 = tf.add(concat3,conv20)

    upsamp4 = upsamp(res8,16,'Upsampling4')
    upsamp4 = prelu(upsamp4,'prelu_upsamp4')

    concat4 = tf.concat([upsamp4,features['res1']],axis=4)

    conv21 = conv3d(concat4,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv21_1')
    conv21 = conv3d(conv21,32,32,[3,3,3],[1,1,1,1,1],'SAME','Conv21_2')
    conv21 = tf.layers.batch_normalization(conv21, training=is_training)
    conv21 = prelu(conv21,'prelu21')

    res9 = tf.add(concat4,conv21)

    conv22 = conv3d(res9,32,1,[1,1,1],[1,1,1,1,1],'SAME','Conv22')
    conv22 = tf.nn.sigmoid(conv22,'sigmoid')

    return conv22
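
# The decoder mirrors the encoder: each stage doubles the spatial resolution with
# a 2x2x2 transposed convolution, concatenates the matching encoder skip
# connection along the channel axis, and applies convolution/batch-norm/PReLU
# blocks with a residual add. The final 1x1x1 convolution plus sigmoid yields a
# single-channel probability map of shape [N, 128, 128, 64, 1].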

# loss and optimizer
def model_fn():

    features, labels = next_item

    features = tf.reshape(features, [-1, 128, 128, 64, 1])
    labels = tf.cast(tf.reshape(labels, [-1, 128, 128, 64, 1]), dtype=tf.float32)

    # writing summaries to tensorboard (middle slice of each volume)
    tf.summary.image('features', features[:, :, :, 32:33, 0], max_outputs=2, collections=['val'])
    tf.summary.image('labels', labels[:, :, :, 32:33, 0], max_outputs=2, collections=['val'])

    labels = tf.reshape(labels, [-1, 128*128*64])

    encoded = graph_encoder(features)
    decoded = graph_decoder(encoded)

    decoded = tf.reshape(decoded, [-1, 128, 128, 64])
    tf.summary.image('segmentation', decoded[:, :, :, 32:33], max_outputs=2, collections=['val'])

    output = tf.reshape(decoded, [-1, 128*128*64])

    # soft Dice coefficient (the optimizer below minimizes 1 - cost); smoothing is
    # added to both numerator and denominator to keep the ratio stable when a
    # volume contains no foreground
    cost = tf.reduce_mean(tf.divide(smoothing + tf.multiply(2.0, tf.reduce_sum(output * labels, axis=-1)),
                 smoothing + tf.add(tf.reduce_sum(output, axis=-1), tf.reduce_sum(labels, axis=-1))))
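
    # The expression above is the soft Dice coefficient, averaged over the batch:
    #
    #   dice = (2 * sum(p * g) + smoothing) / (sum(p) + sum(g) + smoothing)
    #
    # where p is the predicted probability map and g the ground-truth mask.
    # Values close to 1 mean good overlap; training minimizes 1 - dice below.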

    # scalar summaries (these track the Dice coefficient defined above)
    tf.summary.scalar('training_loss', cost)
    tf.summary.scalar('val_loss', cost, collections=['val'])

    # for batchnorm
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        opt = tf.train.AdamOptimizer(learning_rate)

        grads = tf.gradients(1-cost, tf.trainable_variables())
        grads = list(zip(grads, tf.trainable_variables()))

        training_operation = opt.apply_gradients(grads_and_vars=grads)
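
    # apply_gradients is wrapped in control_dependencies(UPDATE_OPS) so the moving
    # mean/variance of every batch-normalization layer is refreshed on each training
    # step; the gradients of (1 - cost) are applied, i.e. training maximizes the
    # Dice coefficient.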

    # histogram summaries for every trainable variable and its gradient
    for grad, var in grads:
        tf.summary.histogram(var.name.replace(':', '_') + '/gradient', grad)
        tf.summary.histogram(var.name.replace(':', '_'), var)

    return cost, training_operation

# running the session
def train():
    with tf.Session() as sess:

        cost, opt = model_fn()
        sess.run(tf.global_variables_initializer())

        # merging tensorflow summaries
        merged = tf.summary.merge_all()
        merged_val = tf.summary.merge_all(key='val')

        train_writer = tf.summary.FileWriter('event/train', sess.graph)
        val_writer = tf.summary.FileWriter('event/val')

        saver = tf.train.Saver()

        for epoch in range(1, no_of_epochs+1):
            start_time = time.time()
            train_loss = []
            examples = 0

            # initializing iterator with training dataset
            sess.run([training_init_op])

            while True:
                try:
                    # training procedure
                    examples += 1
                    loss, _, summary = sess.run([cost, opt, merged], feed_dict={is_training: True})
                    train_writer.add_summary(summary, epoch)
                    train_loss.append(loss)
                    print('Epoch: {} - ex: {} - loss: {:.6f}'.format(epoch, examples*batch_size, sum(train_loss)/len(train_loss)), end="\r")
                except tf.errors.OutOfRangeError:
                    val_loss = []
                    val_example = 0

                    # initializing iterator with validation dataset
                    sess.run([validation_init_op])

                    while True:
                        try:
                            val_example += 1
                            loss, summary_l = sess.run([cost, merged_val], feed_dict={is_training: False})
                            val_writer.add_summary(summary_l, epoch)
                            val_loss.append(loss)
                            print('Epoch: {} - ex: {} - val_loss: {:.6f}'.format(epoch, val_example*batch_size, sum(val_loss)/len(val_loss)), end="\r")

                        except tf.errors.OutOfRangeError:
                            break
                    break

            print('Epoch: {}/{} - loss: {:.6f} - val_loss: {:.6f} - time: {:.4f}'.format(epoch, no_of_epochs,
                sum(train_loss)/len(train_loss), sum(val_loss)/len(val_loss), time.time()-start_time))

            # saving weights
            if epoch % 20 == 0:
                saver.save(sess, '/temp/weights_epoch_{0}.ckpt'.format(epoch))

if __name__ == '__main__':
    train()
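
# To reuse a saved checkpoint, the same graph can be rebuilt (e.g. by importing
# this module) and the weights restored with a saver. A minimal sketch, assuming
# the epoch-20 checkpoint written by saver.save above:
#
#   with tf.Session() as sess:
#       saver = tf.train.Saver()
#       saver.restore(sess, '/temp/weights_epoch_20.ckpt')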