import argparse
import numpy as np
import os
import sys
import warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
import tensorflow as tf
import keras
from keras.models import Model
from keras.layers import Input,merge, concatenate, Conv2D, MaxPooling2D, Activation, UpSampling2D,Dropout,Conv2DTranspose,add,multiply
from keras.layers.normalization import BatchNormalization as bn
from keras.optimizers import RMSprop, Adam
from keras import regularizers, losses, backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger, ModelCheckpoint, TensorBoard
# Expose only the GPU with index "0" to CUDA-based libraries in this process.
# NOTE(review): this runs after `import tensorflow`; presumably still effective
# because devices are initialised lazily -- confirm for the TF version in use.
os.environ['CUDA_VISIBLE_DEVICES']="0"
# Additive smoothing constant used in both the numerator and the denominator of
# the Dice coefficient functions in this file, so the ratio stays defined when
# both masks sum to zero.
smooth = 1.
def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient of two backend tensors.

    Both tensors are flattened, so the score is computed over every element
    at once (all samples in the batch together, not per sample).

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor with the same number of elements.

    Returns:
        A scalar backend tensor:
        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth).
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return numerator / denominator
def dice_coef_test(y_true, y_pred):
    """Return the smoothed Dice coefficient of two array-likes using NumPy.

    NumPy counterpart of ``dice_coef`` for evaluating predictions outside
    the Keras graph.

    Args:
        y_true: Ground-truth values (anything ``np.array`` accepts).
        y_pred: Predicted values with the same number of elements.

    Returns:
        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth).
    """
    truth = np.array(y_true).reshape(-1)
    prediction = np.array(y_pred).reshape(-1)
    overlap = (truth * prediction).sum()
    numerator = 2. * overlap + smooth
    denominator = truth.sum() + prediction.sum() + smooth
    return numerator / denominator
def dice_coef_loss(y_true, y_pred):
    """Return the negated Dice coefficient, so minimising it increases overlap."""
    overlap_score = dice_coef(y_true, y_pred)
    return -overlap_score
def neg_dice_coef_loss(y_true, y_pred):
    """Return the Dice coefficient itself, so minimising it decreases overlap.

    This is the sign-flipped counterpart of ``dice_coef_loss``.
    """
    overlap_score = dice_coef(y_true, y_pred)
    return overlap_score
#define the model
def Comp_U_Net(input_shape,learn_rate=1e-3):
    """Build and compile the three-output U-Net style segmentation model.

    Architecture, as wired below:

    * one shared encoder (five conv/conv/pool/dropout levels plus a
      512-filter bottleneck);
    * two decoders that both start from that bottleneck and reuse the
      encoder skip connections, ending in the sigmoid outputs ``final_op``
      and ``xfinal_op``;
    * a second, complete encoder/decoder fed with the concatenation of
      ``final_op``, ``xfinal_op`` and their element-wise sum, ending in the
      sigmoid output ``res_1_final_op``.

    Losses: ``final_op`` uses ``dice_coef_loss``, ``xfinal_op`` uses
    ``neg_dice_coef_loss`` and ``res_1_final_op`` uses ``'mse'``.

    Layers are created in exactly the same order as in the original
    hand-unrolled definition and keep their explicit names, so Keras'
    auto-generated layer names and the weight ordering are unchanged.

    Args:
        input_shape: Shape tuple handed to ``keras.layers.Input`` (without
            the batch axis). Skip connections are concatenated on axis 3, so
            a channels-last layout is expected. There are five 2x2 poolings,
            so both spatial sizes must be multiples of 32 for the skip
            tensors to line up.
        learn_rate: Learning rate for the Adam optimizer. Previously this
            argument was accepted but ignored (the optimizer was hard-coded
            to ``1e-5``); it is now honoured.

    Returns:
        A compiled ``keras.models.Model`` with one input (``ip0``) and the
        outputs ``[final_op, xfinal_op, res_1_final_op]``.
    """
    l2_lambda = 0.0002
    DropP = 0.3
    kernel_size = 3

    def conv_bn(tensor, filters):
        """ReLU convolution ('same' padding, L2-regularised) followed by batch norm."""
        tensor = Conv2D(filters, (kernel_size, kernel_size), activation='relu', padding='same',
                        kernel_regularizer=regularizers.l2(l2_lambda))(tensor)
        return bn()(tensor)

    def double_conv(tensor, filters, second_filters=None):
        """Two consecutive conv+BN stages; the second may use a different width."""
        if second_filters is None:
            second_filters = filters
        return conv_bn(conv_bn(tensor, filters), second_filters)

    def downsample(tensor):
        """2x2 max pooling followed by dropout."""
        pooled = MaxPooling2D(pool_size=(2, 2))(tensor)
        return Dropout(DropP)(pooled)

    def upsample(tensor, skip, filters, name, dropout=True):
        """Stride-2 transposed conv, concatenated with ``skip``, optional dropout."""
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(tensor)
        merged = concatenate([upsampled, skip], name=name, axis=3)
        if dropout:
            merged = Dropout(DropP)(merged)
        return merged

    def encoder(tensor):
        """Contracting path; returns (bottleneck, skips ordered shallow to deep)."""
        skips = []
        for filters in (32, 32, 64, 128, 256):
            tensor = double_conv(tensor, filters)
            skips.append(tensor)
            tensor = downsample(tensor)
        return double_conv(tensor, 512), skips

    def decoder(bottleneck, skips, prefix, late_dropout=True, conv9b_filters=32):
        """Expanding path ending in a 1x1 sigmoid conv named ``<prefix>final_op``.

        ``late_dropout`` controls dropout after the ``up9``/``up10`` merges and
        ``conv9b_filters`` the width of the second conv of stage 9; both exist
        only to reproduce the irregularities of the first decoder.
        """
        tensor = bottleneck
        stages = ((6, 256), (7, 128), (8, 64), (9, 32), (10, 32))
        for (stage, filters), skip in zip(stages, reversed(skips)):
            use_dropout = late_dropout or stage < 9
            tensor = upsample(tensor, skip, filters, '%sup%d' % (prefix, stage),
                              dropout=use_dropout)
            second_filters = conv9b_filters if stage == 9 else filters
            tensor = double_conv(tensor, filters, second_filters)
        return Conv2D(1, (1, 1), activation='sigmoid', name=prefix + 'final_op')(tensor)

    inputs = Input(input_shape, name='ip0')
    bottleneck, skips = encoder(inputs)

    # First branch.
    # NOTE(review): unlike the other two decoders, this one has no dropout
    # after up9/up10 and uses 12 (not 32) filters in the second stage-9 conv.
    # This looks accidental, but it is kept unchanged because altering it
    # would change weight shapes and break previously saved checkpoints.
    final_op = decoder(bottleneck, skips, prefix='', late_dropout=False, conv9b_filters=12)

    #second branch - brain
    xfinal_op = decoder(bottleneck, skips, prefix='x')

    #-----------------------------third branch
    #Concatenation fed to the reconstruction layer of all 3
    x_u_net_op0 = keras.layers.concatenate(
        [final_op, xfinal_op, keras.layers.add([final_op, xfinal_op])], name='res_a')
    res_1_bottleneck, res_1_skips = encoder(x_u_net_op0)
    res_1_final_op = decoder(res_1_bottleneck, res_1_skips, prefix='res_1_')

    model = Model(inputs=[inputs], outputs=[final_op,
                                            xfinal_op,
                                            res_1_final_op,
                                            ])
    # Use the caller-supplied learning rate instead of a hard-coded constant.
    model.compile(optimizer=keras.optimizers.Adam(lr=learn_rate),
                  loss={'final_op': dice_coef_loss,
                        'xfinal_op': neg_dice_coef_loss,
                        'res_1_final_op': 'mse'})
    return model
#----------------------------------------------------Main--------------------------------------------------#
def train_model(data_params, train_params, common_params):
    """Train the Comp_U_Net model on saved NumPy arrays and write the results to disk.

    Steps: build the model for 256x256x1 inputs, load the data and label
    arrays, append a trailing axis of size 1 to each, fit with a CSV logger
    and a per-epoch weight checkpoint, then save the architecture as JSON
    and the final weights as HDF5.

    Args:
        data_params: Mapping with ``'data_dir'``, ``'train_data_file'`` and
            ``'train_label_file'``.
        train_params: Mapping with ``'learning_rate'``, ``'principal_axis'``
            (used in output file names), ``'validation_split'``,
            ``'train_batch_size'``, ``'num_epochs'`` and ``'shuffle_data'``.
        common_params: Mapping with ``'log_dir'`` and ``'save_model_dir'``;
            both directories are created if missing.

    Returns:
        None.
    """
    data_root = data_params['data_dir'].rstrip('/')
    train_x = data_root + '/' + data_params['train_data_file']
    train_y = data_root + '/' + data_params['train_label_file']

    model = Comp_U_Net(input_shape=(256,256,1), learn_rate=train_params['learning_rate'])

    # Load both arrays, then append a trailing axis of size 1 to each.
    x_train = np.load(train_x)
    y_train = np.load(train_y)
    x_train = x_train.reshape(x_train.shape + (1,))
    y_train = y_train.reshape(y_train.shape + (1,))

    # Log output
    print ("Training dwi volume shape: ", x_train.shape)
    print ("Training dwi mask volume shape: ", y_train.shape)

    view = train_params['principal_axis']

    # Per-epoch metrics are appended to <log_dir>/<view>.csv.
    log_dir = common_params['log_dir']
    os.makedirs(log_dir, exist_ok=True)
    csv_logger = CSVLogger(log_dir + '/' + view + '.csv', append=True, separator=';')

    # Weights are written after every epoch (save_best_only=False), so the
    # monitored quantity does not influence which checkpoints are kept.
    model_dir = common_params['save_model_dir']
    os.makedirs(model_dir, exist_ok=True)
    filepath = model_dir + "/weights-" + view + "-improvement-{epoch:02d}.h5"
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1,
                                 save_best_only=False, save_weights_only=True)

    # The same label array is the target for all three model outputs.
    targets = [y_train, y_train, y_train]
    model.fit([x_train],
              targets,
              validation_split=train_params['validation_split'],
              batch_size=train_params['train_batch_size'],
              epochs=train_params['num_epochs'],
              shuffle=train_params['shuffle_data'],
              verbose=1,
              callbacks=[csv_logger, checkpoint])

    # NOTE(review): h5py is not referenced directly below; presumably it is
    # imported to surface a missing HDF5 dependency before saving -- confirm.
    import h5py

    # serialize model to JSON
    architecture_path = common_params['save_model_dir'] + "/CompNetBasicModel.json"
    with open(architecture_path, "w") as json_file:
        json_file.write(model.to_json())

    # serialize weights to HDF5
    weights_path = common_params['save_model_dir'] + "/" + view + "-compnet_final_weight.h5"
    model.save_weights(weights_path)
    print("Saved model to disk location: ", common_params['save_model_dir'])