Segmentation/unet_context_sag.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 29 19:33:11 2018

@author: Josefine
"""

## Import libraries
import numpy as np
import tensorflow as tf
import re
import glob
import keras
from time import time
from sklearn.utils import shuffle
from skimage.transform import resize

# Define parameters:
lr          = 1e-5    # learning-rate
nEpochs     = 30         # Number of epochs

# Other network specific parameters
n_classes = 8
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8

imgDim = 256
labelDim = 256

######################################################################
##                                                                  ##
##                   Setting up the network                         ##
##                                                                  ##
######################################################################

tf.reset_default_graph()

# Define placeholders for input and output
x = tf.placeholder(tf.float32,[None,imgDim,imgDim,1],name = 'x_train')                     # input slice (256x256x1)
x_contextual = tf.placeholder(tf.float32,[None,imgDim,imgDim,9],name = 'x_train_context')  # contextual input (256x256x9)
y = tf.placeholder(tf.float32,[None,labelDim,labelDim,n_classes],name='y_train')           # one-hot ground truth (256x256x8)
drop_rate = tf.placeholder(tf.float32, shape=())                                            # keep probability for dropout
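
# The network segments one sagittal slice at a time: x holds the current slice, while
# x_contextual carries a 9-channel tensor built from the previous slice's image plus its
# 8-class one-hot prediction (see the 'context' tensor defined after the loss below), so
# predictions can propagate from slice to slice through the contextual encoder path.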

######################################################################
##                                                                  ##
##                   Metrics and functions                          ##
##                                                                  ##
######################################################################

def natural_sort(l):
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
    return sorted(l, key = alphanum_key)

#def dice_coef(y, output): #making the loss function smooth
#    y_true_f = tf.contrib.layers.flatten(tf.argmax(y,axis=-1))
#    y_pred_f = tf.contrib.layers.flatten(tf.argmax(output,axis=-1))
#    intersection = tf.reduce_sum(y_true_f * y_pred_f)
#    return (2 * intersection) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f))

######################################################################
##                               Layers                             ##
######################################################################
def conv2d(inputs, filters, kernel, stride, pad, name):
    """ Creates a 2D convolution with the following specs:
    Args:
        inputs:   (Tensor)            Tensor to apply the convolution to
        filters:  (integer)           Number of filters in the kernel
        kernel:   (integer)           Size of the kernel
        stride:   (integer)           Stride
        pad:      ('VALID' or 'SAME') Type of padding
        name:     (string)            Name of the layer
    """
    with tf.name_scope(name):
        conv = tf.layers.conv2d(inputs, filters, kernel_size = kernel, strides = [stride,stride], padding=pad,activation=tf.nn.relu,kernel_initializer=tf.contrib.layers.xavier_initializer())
        return conv

def max_pool(inputs,n,stride,pad):
    maxpool = tf.nn.max_pool(inputs, ksize=[1,n,n,1], strides=[1,stride,stride,1], padding=pad)
    return maxpool

def dropout(input1,keep_prob):
    # Note: the value fed through the drop_rate placeholder is used as the keep
    # probability (0.5 during training, 1.0 at test time), not as a drop rate.
    input_shape = input1.get_shape().as_list()
    noise_shape = tf.constant(value=[1, 1, 1, input_shape[3]])
    drop = tf.nn.dropout(input1, keep_prob=keep_prob, noise_shape=noise_shape)
    return drop
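
# The noise_shape above gives the dropout mask shape [1, 1, 1, channels]: the mask is
# broadcast over the batch and spatial dimensions, so each feature map is kept or
# dropped as a whole (channel-wise "spatial" dropout).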

def crop2d(inputs,dim):
    crop = tf.image.resize_image_with_crop_or_pad(inputs,dim,dim)
    return crop

def concat(input1,input2,axis):
    combined = tf.concat([input1,input2],axis)
    return combined

def transpose(inputs,filters, kernel, stride, pad, name):
    with tf.name_scope(name):
        trans = tf.layers.conv2d_transpose(inputs,filters, kernel_size=[kernel,kernel],strides=[stride,stride],padding=pad,kernel_initializer=tf.contrib.layers.xavier_initializer())
        return trans

######################################################################
##                             Data                                 ##
######################################################################

def create_data(filename_img,direction):
    images = []
    file = np.load(filename_img)
    a = file['images']
    # Reshape:
    im = resize(a,(labelDim,labelDim,labelDim),order=0)
    if direction == 'axial':
        for i in range(im.shape[0]):
            images.append((im[i,:,:]))
    if direction == 'sag':
        for i in range(im.shape[1]):
            images.append((im[:,i,:]))
    if direction == 'cor':
        for i in range(im.shape[2]):
            images.append((im[:,:,i]))
    images = np.asarray(images)
    images = images.reshape(-1, imgDim,imgDim,1)

    # Label creation
    labels = []
    b = file['labels']
    lab = resize(b,(labelDim,labelDim,labelDim),order=0)
    if direction == 'axial':
        for i in range(lab.shape[0]):
            labels.append((lab[i,:,:]))
    if direction == 'sag':
        for i in range(lab.shape[1]):
            labels.append((lab[:,i,:]))
    if direction == 'cor':
        for i in range(lab.shape[2]):
            labels.append((lab[:,:,i]))
    labels = np.asarray(labels)
    labels_onehot = np.stack((labels==0, labels==500, labels==600, labels==420, labels==550, labels==205, labels==820, labels==850), axis=3)

    return images, labels_onehot
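
# The intensity values tested above appear to be the MM-WHS ground-truth label codes
# (0 background, 500 LV blood pool, 600 RV blood pool, 420 left atrium, 550 right atrium,
# 205 LV myocardium, 820 ascending aorta, 850 pulmonary artery). For one volume the
# function returns images of shape (256, 256, 256, 1) and boolean labels_onehot of
# shape (256, 256, 256, 8).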

###############################################################################
##                            Setup of network                               ##
###############################################################################

# -------------------------- Contracting path ---------------------------------
conv1a = conv2d(x,filters=64,kernel=3,stride=1,pad='same',name = 'conv1a')
conv1a.get_shape()
conv1b = conv2d(conv1a,filters=64,kernel=3,stride=1,pad='same',name = 'conv1b')
conv1b.get_shape()
#drop1 = tf.nn.dropout(conv1b, keep_prob=drop_rate)
#drop1.get_shape()
pool1 = max_pool(conv1b,n=2,stride=2,pad='SAME')
pool1.get_shape()

conv2a = conv2d(pool1,filters=128,kernel=3,stride=1,pad='same',name = 'conv2a')
conv2a.get_shape()
conv2b = conv2d(conv2a,filters=128,kernel=3,stride=1,pad='same',name = 'conv2b')
conv2b.get_shape()
drop2 = dropout(conv2b, drop_rate)
drop2.get_shape()
pool2 = max_pool(drop2,n=2,stride=2,pad='SAME')
pool2.get_shape()

conv3a = conv2d(pool2,filters=256,kernel=3,stride=1,pad='same',name = 'conv3a')
conv3a.get_shape()
conv3b = conv2d(conv3a,filters=256,kernel=3,stride=1,pad='same',name = 'conv3b')
conv3b.get_shape()
drop3 = dropout(conv3b, drop_rate)
drop3.get_shape()
pool3 = max_pool(drop3,n=2,stride=2,pad='SAME')
pool3.get_shape()

conv4a = conv2d(pool3,filters=512,kernel=3,stride=1,pad='same',name = 'conv4a')
conv4a.get_shape()
conv4b = conv2d(conv4a,filters=512,kernel=3,stride=1,pad='same',name = 'conv4b')
conv4b.get_shape()
drop4 = dropout(conv4b, drop_rate)
drop4.get_shape()
pool4 = max_pool(drop4,n=2,stride=2,pad='SAME')
pool4.get_shape()

# -------------------------- Contextual input path ----------------------------

conv1a_2 = conv2d(x_contextual,filters=64,kernel=3,stride=1,pad='same',name = 'conv1a2')
conv1b_2 = conv2d(conv1a_2,filters=64,kernel=3,stride=1,pad='same',name = 'conv1b2')
#drop1_2 = tf.nn.dropout(conv1b_2, keep_prob=drop_rate)
pool1_2 = max_pool(conv1b_2,n=2,stride=2,pad='SAME')

conv2a_2 = conv2d(pool1_2,filters=128,kernel=3,stride=1,pad='same',name = 'conv2a2')
conv2b_2 = conv2d(conv2a_2,filters=128,kernel=3,stride=1,pad='same',name = 'conv2b2')
drop2_2 = dropout(conv2b_2, drop_rate)
pool2_2 = max_pool(drop2_2,n=2,stride=2,pad='SAME')

conv3a_2 = conv2d(pool2_2,filters=256,kernel=3,stride=1,pad='same',name = 'conv3a2')
conv3b_2 = conv2d(conv3a_2,filters=256,kernel=3,stride=1,pad='same',name = 'conv3b2')
drop3_2 = dropout(conv3b_2, drop_rate)
pool3_2 = max_pool(drop3_2,n=2,stride=2,pad='SAME')

conv4a_2 = conv2d(pool3_2,filters=512,kernel=3,stride=1,pad='same',name = 'conv4a2')
conv4b_2 = conv2d(conv4a_2,filters=512,kernel=3,stride=1,pad='same',name = 'conv4b2')
drop4_2 = dropout(conv4b_2, drop_rate)
pool4_2 = max_pool(drop4_2,n=2,stride=2,pad='SAME')

# ---------------------------- Expansive path ---------------------------------
combx = concat(pool4,pool4_2,axis=3)
conv5a = conv2d(combx,filters=1024,kernel=3,stride=1,pad='same',name = 'conv5a')
conv5a.get_shape()
conv5b = conv2d(conv5a,filters=1024,kernel=3,stride=1,pad='same',name = 'conv5b')
conv5b.get_shape()
drop5 = dropout(conv5b, drop_rate)
drop5.get_shape()
up6a = transpose(drop5,filters=512,kernel=2,stride=2,pad='same',name='up6a')
up6a.get_shape()
up6b = concat(up6a,conv4b,axis=3)
up6b.get_shape()

conv7a = conv2d(up6b,filters=512,kernel=3,stride=1,pad='same',name = 'conv7a')
conv7a.get_shape()
conv7b = conv2d(conv7a,filters=512,kernel=3,stride=1,pad='same',name = 'conv7b')
conv7b.get_shape()
drop7 = dropout(conv7b, drop_rate)
drop7.get_shape()
up7a = transpose(drop7,filters=256,kernel=2,stride=2,pad='same',name='up7a')
up7a.get_shape()
up7b = concat(up7a,conv3b,axis=3)
up7b.get_shape()

conv8a = conv2d(up7b,filters=256,kernel=3,stride=1,pad='same',name = 'conv8a')
conv8a.get_shape()
conv8b = conv2d(conv8a,filters=256,kernel=3,stride=1,pad='same',name = 'conv8b')
conv8b.get_shape()
drop8 = dropout(conv8b, drop_rate)
drop8.get_shape()
up8a = transpose(drop8,filters=128,kernel=2,stride=2,pad='same',name='up8a')
up8a.get_shape()
up8b = concat(up8a,conv2b,axis=3)
up8b.get_shape()

conv9a = conv2d(up8b,filters=128,kernel=3,stride=1,pad='same',name = 'conv9a')
conv9a.get_shape()
conv9b = conv2d(conv9a,filters=128,kernel=3,stride=1,pad='same',name = 'conv9b')
conv9b.get_shape()
#drop9 = tf.nn.dropout(conv9b, keep_prob=drop_rate)
#drop9.get_shape()
up9a = transpose(conv9b,filters=64,kernel=2,stride=2,pad='same',name='up9a')
up9a.get_shape()
up9b = concat(up9a,conv1b,axis=3)
up9b.get_shape()

conv10a = conv2d(up9b,filters=64,kernel=3,stride=1,pad='same',name = 'conv10a')
conv10a.get_shape()
conv10b = conv2d(conv10a,filters=64,kernel=3,stride=1,pad='same',name = 'conv10b')
conv10b.get_shape()

output = tf.layers.conv2d(conv10b, n_classes, 1, (1,1),padding ='same',activation=tf.nn.softmax, kernel_initializer=tf.contrib.layers.xavier_initializer(), name = 'output')
output.get_shape()
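
# The final 1x1 convolution with softmax activation produces per-pixel class
# probabilities over the 8 classes, so 'output' has shape (batch, 256, 256, 8).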

######################################################################
##                                                                  ##
##                            Loading data                          ##
##                                                                  ##
######################################################################

filelist_train = natural_sort(glob.glob('WHS/Data/train_segments_*.npz')) # list of file names
x_train = {}
y_train = {}
keys = range(len(filelist_train))
for i in keys:
    x_train[i] = np.zeros([imgDim,imgDim,imgDim,1])
    y_train[i] = np.zeros([imgDim,imgDim,imgDim,8])

for i in range(len(filelist_train)):
    img, lab = create_data(filelist_train[i],'sag')
    x_train[i] = img
    y_train[i] = lab

#filelist_val = natural_sort(glob.glob('WHS/Data/validation_segments_*.npz')) # list of file names
#x_val = {}
#y_val = {}
#keys = range(len(filelist_val))
#for i in keys:
#    x_val[i] = np.zeros([imgDim,imgDim,imgDim,1])
#    y_val[i] = np.zeros([imgDim,imgDim,imgDim,8])
#
#for i in range(len(filelist_val)):
#    img, lab = create_data(filelist_val[i],'sag')
#    x_val[i] = img
#    y_val[i] = lab
#
######################################################################
##                                                                  ##
##                   Defining the training                          ##
##                                                                  ##
######################################################################

# Global step counter, incremented by the optimizer on every training step
global_step = tf.Variable(0,trainable=False)

###############################################################################
##                               Loss                                        ##
###############################################################################
# Compare the output of the network (output: tensor) with the ground truth (y: tensor/placeholder)
# using categorical cross-entropy; 'output' already holds softmax probabilities, not logits
loss = tf.reduce_mean(keras.losses.categorical_crossentropy(y_true = y, y_pred = output))
correct_prediction = tf.equal(tf.argmax(output, axis=-1), tf.argmax(y, axis=-1))

# Pixel-wise accuracy, averaged over the one-hot encoded prediction
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#dice = dice_coef(y, output,smooth=1)

# Create contextual output:
pred = tf.argmax(output[0,:,:,:],axis=-1)   # output is already softmax-activated, so argmax is taken directly
predict = tf.one_hot(pred,8)
context = tf.concat([x[0,:,:,:],predict],axis=-1)
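
# 'context' is built from the first (and only) image in the batch concatenated with the
# one-hot prediction for that slice: 1 + 8 = 9 channels. During training it is stored and
# fed back in as x_contextual for the next slice of the same volume.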

###############################################################################
##                               Optimizer                                   ##
###############################################################################
opt = tf.train.AdamOptimizer(lr,beta1,beta2,epsilon)

###############################################################################
##                               Minimizer                                   ##
###############################################################################
train_adam = opt.minimize(loss, global_step)

###############################################################################
##                               Initializer                                 ##
###############################################################################
# Initializes all variables in the graph
init = tf.global_variables_initializer()

######################################################################
##                                                                  ##
##                   Start training                                 ##
##                                                                  ##
######################################################################
# Initialize saving of the network parameters:
saver = tf.train.Saver()

######################## Start training Session ###########################
start_time = time()
valid_loss, valid_accuracy = [], []
train_loss, train_accuracy = [], []

# One buffer of (imgDim+1) context slices (image + one-hot prediction = 9 channels) per
# training volume; each volume gets its own copy so updates do not leak between volumes.
c = np.zeros([imgDim+1,imgDim,imgDim,9])
predictions = {}
keys = range(len(filelist_train))
for i in keys:
    predictions[i] = np.copy(c)

#predictions_val = {}
#keys = range(len(filelist_val))
#for i in keys:
#    predictions_val[i] = np.copy(c)

index_volumeID = np.repeat(range(len(x_train)),imgDim)
index_imageID = np.tile(range(imgDim),len(x_train))
index_comb = np.vstack((index_volumeID,index_imageID)).T

index_shuffle = shuffle(index_comb)
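
# index_comb holds one (volume index, slice index) pair per sagittal slice of every
# training volume; sklearn's shuffle returns a randomly permuted copy, fixed once for
# all epochs, and the training loop below walks through these pairs one slice at a time.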

with tf.Session() as sess:
    # Initialize
    t_start = time()

    sess.run(init)

    # Training loop
    for epoch in range(nEpochs):
        t_epoch_start = time()
        print('========Training Epoch: ', (epoch + 1))
        iter_by_epoch = len(index_shuffle)
        for i in range(iter_by_epoch):
            t_iter_start = time()
            x_batch = np.expand_dims(x_train[index_shuffle[i,0]][index_shuffle[i,1],:,:,:], axis=0)
            x_batch_context = np.expand_dims(predictions[index_shuffle[i,0]][index_shuffle[i,1],:,:,:], axis=0)
            y_batch = np.expand_dims(y_train[index_shuffle[i,0]][index_shuffle[i,1],:,:,:], axis=0)
            _,_loss,_acc,pred_out = sess.run([train_adam, loss, accuracy,context], feed_dict={x: x_batch, x_contextual: x_batch_context, y: y_batch, drop_rate: 0.5})
            predictions[index_shuffle[i,0]][index_shuffle[i,1]+1,:,:,:] = pred_out
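            # The context produced for slice m is stored at slot m+1, so it is picked up
            # when slice m+1 of the same volume is sampled; slot 0 stays all-zero and
            # serves as the context for the first slice of each volume.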
            train_loss.append(_loss)
            train_accuracy.append(_acc)

#            # Validation-step:
#            if i==np.max(range(iter_by_epoch)):
#                for n in range(len(x_val)):
#                    for m in range(imgDim):
#                        x_batch_val = np.expand_dims(x_val[n][m,:,:,:], axis=0)
#                        y_batch_val = np.expand_dims(y_val[n][m,:,:,:], axis=0)
#                        x_context_val = np.expand_dims(predictions_val[n][m,:,:,:], axis=0)
#                        acc_val, loss_val,out_context = sess.run([accuracy,loss,context], feed_dict={x: x_batch_val, x_contextual: x_context_val, y: y_batch_val, drop_rate: 1.0})
#                        predictions_val[n][m+1,:,:,:] = out_context
#                        valid_loss.append(loss_val)
#                        valid_accuracy.append(acc_val)
#
        t_epoch_finish = time()
        # Note: the reported averages run over all iterations so far, not only this epoch
        print("Epoch:", (epoch + 1), '  avg_loss= ', "{:.9f}".format(np.mean(train_loss)), '  avg_acc= ', "{:.9f}".format(np.mean(train_accuracy)),' time_epoch=', str(t_epoch_finish-t_epoch_start))
#        print("Validation:", (epoch + 1), '  avg_loss= ', "{:.9f}".format(np.mean(valid_loss)), '  avg_acc= ', "{:.9f}".format(np.mean(valid_accuracy)))

    t_end = time()

    saver.save(sess,"WHS/Results/segmentation/model_sag/model.ckpt")
    np.save('WHS/Results/train_hist/segmentation/train_loss_sag',train_loss)
    np.save('WHS/Results/train_hist/segmentation/train_acc_sag',train_accuracy)
#    np.save('WHS/Results/train_hist/segmentation/valid_loss_sag',valid_loss)
#    np.save('WHS/Results/train_hist/segmentation/valid_acc_sag',valid_accuracy)
    print('Training Done! Total time: ' + str(t_end - t_start))