#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 29 16:47:12 2018
@author: Josefine
"""
## Import libraries
import numpy as np
import tensorflow as tf
import re
import glob
import keras
from time import time
from sklearn.utils import shuffle
from skimage.transform import resize
# Define parameters:
lr = 1e-5 # learning-rate
nEpochs = 30 # Number of epochs
# Other network specific parameters
n_classes = 8
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
imgDim = 256
labelDim = 256
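# beta1, beta2 and epsilon are the standard Adam defaults; imgDim/labelDim is the in-plane
# resolution every volume is resampled to before it is sliced into 2D images.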
######################################################################
## ##
## Setting up the network ##
## ##
######################################################################
tf.reset_default_graph()
#Define placeholders for input and output
x = tf.placeholder(tf.float32,[None,imgDim,imgDim,1],name = 'x_train') # input slice (256x256x1)
x_contextual = tf.placeholder(tf.float32,[None,imgDim,imgDim,9],name = 'x_train_context') # contextual input: previous slice's image plus its one-hot prediction (256x256x(1+8))
y = tf.placeholder(tf.float32,[None,labelDim,labelDim,n_classes],name='y_train') # ground truth (256x256x8 one-hot labels)
drop_rate = tf.placeholder(tf.float32, shape=()) # fed to tf.nn.dropout as the keep probability (1.0 = no dropout)
######################################################################
## ##
## Metrics and functions ##
## ##
######################################################################
def natural_sort(l):
convert = lambda text: int(text) if text.isdigit() else text.lower()
alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
return sorted(l, key = alphanum_key)
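# Example (hypothetical filenames): natural_sort(['seg_10.npz','seg_2.npz']) returns
# ['seg_2.npz','seg_10.npz'], i.e. numeric parts are compared as numbers rather than strings.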
#def dice_coef(y, output): #making the loss function smooth
# y_true_f = tf.contrib.layers.flatten(tf.argmax(y,axis=-1))
# y_pred_f = tf.contrib.layers.flatten(tf.argmax(output,axis=-1))
# intersection = tf.reduce_sum(y_true_f * y_pred_f)
# return (2 * intersection) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f))
######################################################################
## Layers ##
######################################################################
def conv2d(inputs, filters, kernel, stride, pad, name):
""" Creates a 2D convolution with following specs:
Args:
inputs: (Tensor) Tensor which you want to apply convolution to
filters: (integer) Number of filters in kernel
kernel_size: (integer) Size of kernel
Strides: (integer) Stride
pad: ('VALID' or 'SAME') Type of padding
name: (string) Name of layer
"""
with tf.name_scope(name):
conv = tf.layers.conv2d(inputs, filters, kernel_size = kernel, strides = [stride,stride], padding=pad,activation=tf.nn.relu,kernel_initializer=tf.contrib.layers.xavier_initializer())
return conv
def max_pool(inputs,n,stride,pad):
maxpool = tf.nn.max_pool(inputs, ksize=[1,n,n,1], strides=[1,stride,stride,1], padding=pad)
return maxpool
def dropout(input1,keep_prob):
    # Channel-wise (spatial) dropout: the mask has shape [1,1,1,channels], so entire feature
    # maps are dropped and the mask is shared across the spatial dimensions (and the batch).
    input_shape = input1.get_shape().as_list()
    noise_shape = tf.constant(value=[1, 1, 1, input_shape[3]])
    drop = tf.nn.dropout(input1, keep_prob=keep_prob, noise_shape=noise_shape)
    return drop
def crop2d(inputs,dim):
crop = tf.image.resize_image_with_crop_or_pad(inputs,dim,dim)
return crop
def concat(input1,input2,axis):
combined = tf.concat([input1,input2],axis)
return combined
def transpose(inputs,filters, kernel, stride, pad, name):
with tf.name_scope(name):
trans = tf.layers.conv2d_transpose(inputs,filters, kernel_size=[kernel,kernel],strides=[stride,stride],padding=pad,kernel_initializer=tf.contrib.layers.xavier_initializer())
return trans
######################################################################
## Data ##
######################################################################
def create_data(filename_img,direction):
images = []
file = np.load(filename_img)
a = file['images']
# Reshape:
im = resize(a,(labelDim,labelDim,labelDim),order=0)
if direction == 'axial':
for i in range(im.shape[0]):
images.append((im[i,:,:]))
if direction == 'sag':
for i in range(im.shape[1]):
images.append((im[:,i,:]))
if direction == 'cor':
for i in range(im.shape[2]):
images.append((im[:,:,i]))
images = np.asarray(images)
images = images.reshape(-1, imgDim,imgDim,1)
# Label creation
labels = []
b = file['labels']
    lab = resize(b,(labelDim,labelDim,labelDim),order=0,preserve_range=True) # nearest-neighbour; keep the original label values so the equality tests below still match
if direction == 'axial':
for i in range(lab.shape[0]):
labels.append((lab[i,:,:]))
if direction == 'sag':
for i in range(lab.shape[1]):
labels.append((lab[:,i,:]))
if direction == 'cor':
for i in range(lab.shape[2]):
labels.append((lab[:,:,i]))
labels = np.asarray(labels)
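    # One-hot encode the 8 classes; the intensity values below appear to follow the MM-WHS
    # label convention (background plus seven cardiac structures) and are used exactly as stored in the npz files.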
labels_onehot = np.stack((labels==0, labels==500, labels==600, labels==420, labels ==550, labels==205, labels ==820, labels==850), axis=3)
return images, labels_onehot
###############################################################################
## Setup of network ##
###############################################################################
# -------------------------- Contracting path ---------------------------------
conv1a = conv2d(x,filters=64,kernel=3,stride=1,pad='same',name = 'conv1a')
conv1a.get_shape()
conv1b = conv2d(conv1a,filters=64,kernel=3,stride=1,pad='same',name = 'conv1b')
conv1b.get_shape()
#drop1 = tf.nn.dropout(conv1b, keep_prob=drop_rate)
#drop1.get_shape()
pool1 = max_pool(conv1b,n=2,stride=2,pad='SAME')
pool1.get_shape()
conv2a = conv2d(pool1,filters=128,kernel=3,stride=1,pad='same',name = 'conv2a')
conv2a.get_shape()
conv2b = conv2d(conv2a,filters=128,kernel=3,stride=1,pad='same',name = 'conv2b')
conv2b.get_shape()
drop2 = dropout(conv2b, drop_rate)
drop2.get_shape()
pool2 = max_pool(drop2,n=2,stride=2,pad='SAME')
pool2.get_shape()
conv3a = conv2d(pool2,filters=256,kernel=3,stride=1,pad='same',name = 'conv3a')
conv3a.get_shape()
conv3b = conv2d(conv3a,filters=256,kernel=3,stride=1,pad='same',name = 'conv3b')
conv3b.get_shape()
drop3 = dropout(conv3b, drop_rate)
drop3.get_shape()
pool3 = max_pool(drop3,n=2,stride=2,pad='SAME')
pool3.get_shape()
conv4a = conv2d(pool3,filters=512,kernel=3,stride=1,pad='same',name = 'conv4a')
conv4a.get_shape()
conv4b = conv2d(conv4a,filters=512,kernel=3,stride=1,pad='same',name = 'conv4b')
conv4b.get_shape()
drop4 = dropout(conv4b, drop_rate)
drop4.get_shape()
pool4 = max_pool(drop4,n=2,stride=2,pad='SAME')
pool4.get_shape()
# -------------------------- Contextual input path ----------------------------
conv1a_2 = conv2d(x_contextual,filters=64,kernel=3,stride=1,pad='same',name = 'conv1a2')
conv1b_2 = conv2d(conv1a_2,filters=64,kernel=3,stride=1,pad='same',name = 'conv1b2')
#drop1_2 = tf.nn.dropout(conv1b_2, keep_prob=drop_rate)
pool1_2 = max_pool(conv1b_2,n=2,stride=2,pad='SAME')
conv2a_2 = conv2d(pool1_2,filters=128,kernel=3,stride=1,pad='same',name = 'conv2a2')
conv2b_2 = conv2d(conv2a_2,filters=128,kernel=3,stride=1,pad='same',name = 'conv2b2')
drop2_2 = dropout(conv2b_2, drop_rate)
pool2_2 = max_pool(drop2_2,n=2,stride=2,pad='SAME')
conv3a_2 = conv2d(pool2_2,filters=256,kernel=3,stride=1,pad='same',name = 'conv3a2')
conv3b_2 = conv2d(conv3a_2,filters=256,kernel=3,stride=1,pad='same',name = 'conv3b2')
drop3_2 = dropout(conv3b_2, drop_rate)
pool3_2 = max_pool(drop3_2,n=2,stride=2,pad='SAME')
conv4a_2 = conv2d(pool3_2,filters=512,kernel=3,stride=1,pad='same',name = 'conv4a2')
conv4b_2 = conv2d(conv4a_2,filters=512,kernel=3,stride=1,pad='same',name = 'conv4b2')
drop4_2 = dropout(conv4b_2, drop_rate)
pool4_2 = max_pool(drop4_2,n=2,stride=2,pad='SAME')
# ---------------------------- Expansive path ---------------------------------
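# The bottleneck fuses the main encoder (pool4) with the contextual encoder (pool4_2) by
# channel-wise concatenation before the decoder upsamples back to full resolution.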
combx = concat(pool4,pool4_2,axis=3)
conv5a = conv2d(combx,filters=1024,kernel=3,stride=1,pad='same',name = 'conv5a')
conv5a.get_shape()
conv5b = conv2d(conv5a,filters=1024,kernel=3,stride=1,pad='same',name = 'conv5b')
conv5b.get_shape()
drop5 = dropout(conv5b, drop_rate)
drop5.get_shape()
up6a = transpose(drop5,filters=512,kernel=2,stride=2,pad='same',name='up6a')
up6a.get_shape()
up6b = concat(up6a,conv4b,axis=3)
up6b.get_shape()
conv7a = conv2d(up6b,filters=512,kernel=3,stride=1,pad='same',name = 'conv7a')
conv7a.get_shape()
conv7b = conv2d(conv7a,filters=512,kernel=3,stride=1,pad='same',name = 'conv7b')
conv7b.get_shape()
drop7 = dropout(conv7b, drop_rate)
drop7.get_shape()
up7a = transpose(drop7,filters=256,kernel=2,stride=2,pad='same',name='up7a')
up7a.get_shape()
up7b = concat(up7a,conv3b,axis=3)
up7b.get_shape()
conv8a = conv2d(up7b,filters=256,kernel=3,stride=1,pad='same',name = 'conv8a')
conv8a.get_shape()
conv8b = conv2d(conv8a,filters=256,kernel=3,stride=1,pad='same',name = 'conv8b')
conv8b.get_shape()
drop8 = dropout(conv8b, drop_rate)
drop8.get_shape()
up8a = transpose(drop8,filters=128,kernel=2,stride=2,pad='same',name='up8a')
up8a.get_shape()
up8b = concat(up8a,conv2b,axis=3)
up8b.get_shape()
conv9a = conv2d(up8b,filters=128,kernel=3,stride=1,pad='same',name = 'conv9a')
conv9a.get_shape()
conv9b = conv2d(conv9a,filters=128,kernel=3,stride=1,pad='same',name = 'conv9b')
conv9b.get_shape()
#drop9 = tf.nn.dropout(conv9b, keep_prob=drop_rate)
#drop9.get_shape()
up9a = transpose(conv9b,filters=64,kernel=2,stride=2,pad='same',name='up9a')
up9a.get_shape()
up9b = concat(up9a,conv1b,axis=3)
up9b.get_shape()
conv10a = conv2d(up9b,filters=64,kernel=3,stride=1,pad='same',name = 'conv10a')
conv10a.get_shape()
conv10b = conv2d(conv10a,filters=64,kernel=3,stride=1,pad='same',name = 'conv10b')
conv10b.get_shape()
output = tf.layers.conv2d(conv10b, n_classes, 1, (1,1),padding ='same',activation=tf.nn.softmax, kernel_initializer=tf.contrib.layers.xavier_initializer(), name = 'output')
output.get_shape()
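# The final 1x1 convolution with softmax yields per-pixel class probabilities over the
# n_classes = 8 classes at the full 256x256 resolution.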
######################################################################
## ##
## Loading data ##
## ##
######################################################################
filelist_train = natural_sort(glob.glob('WHS/Data/train_segments_*.npz')) # list of file names
x_train = {}
y_train = {}
for i in range(len(filelist_train)):
    img, lab = create_data(filelist_train[i],'axial')
    x_train[i] = img
    y_train[i] = lab
#filelist_val = natural_sort(glob.glob('WHS/Data/validation_segments_*.npz')) # list of file names
#x_val = {}
#y_val = {}
#keys = range(len(filelist_val))
#for i in keys:
# x_val[i] = np.zeros([imgDim,imgDim,imgDim,1])
# y_val[i] = np.zeros([imgDim,imgDim,imgDim,8])
#
#for i in range(len(filelist_val)):
# img, lab = create_data(filelist_val[i],'axial')
# x_val[i] = img
# y_val[i] = lab
#
######################################################################
## ##
## Defining the training ##
## ##
######################################################################
# Global step counter: incremented once per training iteration by opt.minimize()
global_step = tf.Variable(0,trainable=False)
###############################################################################
## Loss ##
###############################################################################
# Compare the output of the network (output: tensor) with the ground truth (y: tensor/placeholder)
# In this case we use categorical cross-entropy on the softmax probabilities produced by the output layer
loss = tf.reduce_mean(keras.losses.categorical_crossentropy(y_true = y, y_pred = output))
correct_prediction = tf.equal(tf.argmax(output, axis=-1), tf.argmax(y, axis=-1))
# Pixel-wise accuracy: fraction of pixels whose predicted class matches the ground truth
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#dice = dice_coef(y, output,smooth=1)
# Create contextual output (the output layer already applies softmax, so take the argmax directly):
pred = tf.argmax(output[0,:,:,:],axis=-1)
predict = tf.one_hot(pred,8)
context = tf.concat([x[0,:,:,:],predict],axis=-1)
###############################################################################
## Optimizer ##
###############################################################################
opt = tf.train.AdamOptimizer(lr,beta1,beta2,epsilon)
###############################################################################
## Minimizer ##
###############################################################################
train_adam = opt.minimize(loss, global_step)
###############################################################################
## Initializer ##
###############################################################################
# Initializes all variables in the graph
init = tf.global_variables_initializer()
######################################################################
## ##
## Start training ##
## ##
######################################################################
# Initialize saving of the network parameters:
saver = tf.train.Saver()
######################## Start training Session ###########################
start_time = time()
#valid_loss, valid_accuracy = [], []
train_loss, train_accuracy = [], []
c = np.zeros([imgDim+1,imgDim,imgDim,9])
predictions = {}
keys = range(len(filelist_train))
for i in keys:
    predictions[i] = c.copy() # give each volume its own context buffer instead of sharing a single array
#predictions_val = {}
#keys = range(len(filelist_val))
#for i in keys:
# predictions_val[i] = c
index_volumeID = np.repeat(range(len(x_train)),imgDim)
index_imageID = np.tile(range(imgDim),len(x_train))
index_comb = np.vstack((index_volumeID,index_imageID)).T
index_shuffle = shuffle(index_comb)
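# Each row of index_shuffle is a (volume index, slice index) pair, so training visits the
# individual 2D slices in a random order across all training volumes.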
with tf.Session() as sess:
# Initialize
t_start = time()
sess.run(init)
# Trainingsloop
for epoch in range(nEpochs):
t_epoch_start = time()
print('========Training Epoch: ', (epoch + 1))
iter_by_epoch = len(index_shuffle)
for i in range(iter_by_epoch):
t_iter_start = time()
x_batch = np.expand_dims(x_train[index_shuffle[i,0]][index_shuffle[i,1],:,:,:], axis=0)
x_batch_context = np.expand_dims(predictions[index_shuffle[i,0]][index_shuffle[i,1],:,:,:], axis=0)
y_batch = np.expand_dims(y_train[index_shuffle[i,0]][index_shuffle[i,1],:,:,:], axis=0)
_,_loss,_acc,pred_out = sess.run([train_adam, loss, accuracy,context], feed_dict={x: x_batch, x_contextual: x_batch_context, y: y_batch, drop_rate: 0.5})
predictions[index_shuffle[i,0]][index_shuffle[i,1]+1,:,:,:] = pred_out
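            # The context produced while segmenting slice i is stored at slot i+1, so it is consumed
            # when slice i+1 of the same volume is sampled (possibly only in a later epoch, since the
            # slice order is shuffled); slot 0 stays all-zero and serves as context for the first slice.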
train_loss.append(_loss)
train_accuracy.append(_acc)
# # Validation-step:
# if i==np.max(range(iter_by_epoch)):
# for n in range(len(x_val)):
# for m in range(imgDim):
# x_batch_val = np.expand_dims(x_val[n][m,:,:,:], axis=0)
# y_batch_val = np.expand_dims(y_val[n][m,:,:,:], axis=0)
# x_context_val = np.expand_dims(predictions_val[n][m,:,:,:], axis=0)
# acc_val, loss_val,out_context = sess.run([accuracy,loss,context], feed_dict={x: x_batch_val, x_contextual: x_context_val, y: y_batch_val, drop_rate: 1.0})
# predictions_val[n][m+1,:,:,:] = pred_out
# valid_loss.append(loss_val)
# valid_accuracy.append(acc_val)
#
t_epoch_finish = time()
print("Epoch:", (epoch + 1), ' avg_loss= ', "{:.9f}".format(np.mean(train_loss)), ' avg_acc= ', "{:.9f}".format(np.mean(train_accuracy)),' time_epoch=', str(t_epoch_finish-t_epoch_start))
# print("Validation:", (epoch + 1), ' avg_loss= ', "{:.9f}".format(np.mean(valid_loss)), ' avg_acc= ', "{:.9f}".format(np.mean(valid_accuracy)))
t_end = time()
saver.save(sess,"WHS/Results/segmentation/model_axial/model.ckpt")
np.save('WHS/Results/train_hist/segmentation/train_loss_axial',train_loss)
np.save('WHS/Results/train_hist/segmentation/train_acc_axial',train_accuracy)
# np.save('WHS/Results/train_hist/segmentation/valid_loss_axial',valid_loss)
# np.save('WHS/Results/train_hist/segmentation/valid_acc_axial',valid_accuracy)
print('Training Done! Total time:' + str(t_end - t_start))
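# A minimal restore/inference sketch (commented out; it assumes the graph above has already been
# built in the current process and reuses the tensors and checkpoint path defined in this script):
#with tf.Session() as sess:
#    saver.restore(sess, "WHS/Results/segmentation/model_axial/model.ckpt")
#    seg = sess.run(output, feed_dict={x: x_batch, x_contextual: x_batch_context, drop_rate: 1.0})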