--- a
+++ b/unet.py
@@ -0,0 +1,166 @@
+"""
+This code builds and trains a 2D U-Net.
+"""
+import sys
+
+from keras.models import Model
+from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
+from keras.optimizers import Adam
+from keras.callbacks import ModelCheckpoint, CSVLogger
+from keras import backend as K
+
+from utils import *
+from data import load_train_data
+
+K.set_image_data_format('channels_last')  # TensorFlow dimension ordering
+
+# ----- paths setting -----
+data_path = sys.argv[1] + "/"
+model_path = data_path + "models/"
+log_path = data_path + "logs/"
+
+
+# ----- params for training and testing -----
+batch_size = 1
+cur_fold = sys.argv[2]
+plane = sys.argv[3]
+epoch = int(sys.argv[4])
+init_lr = float(sys.argv[5])
+
+
+# ----- Dice coefficient and loss function for training -----
+smooth = 1.  # smoothing term; keeps the Dice ratio defined for empty masks
+
+def dice_coef(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
+
+def dice_coef_loss(y_true, y_pred):
+    return -dice_coef(y_true, y_pred)
+
+
+def get_unet(img_shape, flt=64, pool_size=(2, 2), init_lr=1.0e-5):
+    """Build and compile the 2D U-Net."""
+
+    print("start building NN")
+    img_rows, img_cols = img_shape
+    inputs = Input((img_rows, img_cols, 1))
+
+    # contracting path
+    conv1 = Conv2D(flt, (3, 3), activation='relu', padding='same')(inputs)
+    conv1 = Conv2D(flt, (3, 3), activation='relu', padding='same')(conv1)
+    pool1 = MaxPooling2D(pool_size=pool_size)(conv1)
+
+    conv2 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(pool1)
+    conv2 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(conv2)
+    pool2 = MaxPooling2D(pool_size=pool_size)(conv2)
+
+    conv3 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(pool2)
+    conv3 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(conv3)
+    pool3 = MaxPooling2D(pool_size=pool_size)(conv3)
+
+    conv4 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(pool3)
+    conv4 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(conv4)
+    pool4 = MaxPooling2D(pool_size=pool_size)(conv4)
+
+    conv5 = Conv2D(flt*16, (3, 3), activation='relu', padding='same')(pool4)
+    conv5 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(conv5)
+
+    # expansive path with skip connections to the contracting path
+    up6 = concatenate([Conv2DTranspose(flt*8, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
+    conv6 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(up6)
+    conv6 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(conv6)
+
+    up7 = concatenate([Conv2DTranspose(flt*4, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
+    conv7 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(up7)
+    conv7 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(conv7)
+
+    up8 = concatenate([Conv2DTranspose(flt*2, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
+    conv8 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(up8)
+    conv8 = Conv2D(flt, (3, 3), activation='relu', padding='same')(conv8)
+
+    up9 = concatenate([Conv2DTranspose(flt, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
+    conv9 = Conv2D(flt, (3, 3), activation='relu', padding='same')(up9)
+    conv9 = Conv2D(flt, (3, 3), activation='relu', padding='same')(conv9)
+
+    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
+
+    model = Model(inputs=[inputs], outputs=[conv10])
+
+    model.compile(optimizer=Adam(lr=init_lr), loss=dice_coef_loss, metrics=[dice_coef])
+
+    return model
+
+
+def train(fold, plane, batch_size, nb_epoch, init_lr):
+    """
+    Train a U-Net model with data from load_train_data().
+
+    Parameters
+    ----------
+    fold : string
+        which fold of the 4-fold cross-validation to train on; one of 0/1/2/3
+
+    plane : char
+        which plane to train on; one of 'X'/'Y'/'Z'
+
+    batch_size : int
+        size of mini-batch
+
+    nb_epoch : int
+        number of epochs to train NN
+
+    init_lr : float
+        initial learning rate
+    """
+
+    print("number of epochs:", nb_epoch)
+    print("learning rate:", init_lr)
+
+    # --------------------- load and preprocess training data -----------------
+    print('-' * 80)
+    print('        Loading and preprocessing train data...')
+    print('-' * 80)
+
+    imgs_train, imgs_mask_train = load_train_data(fold, plane)
+
+    imgs_row = imgs_train.shape[1]
+    imgs_col = imgs_train.shape[2]
+
+    imgs_train = preprocess(imgs_train)
+    imgs_mask_train = preprocess(imgs_mask_train)
+
+    imgs_train = imgs_train.astype('float32')
+    imgs_mask_train = imgs_mask_train.astype('float32')
+
+    # ---------------------- create, compile, and train model -----------------
+    print('-' * 80)
+    print('        Creating and compiling model...')
+    print('-' * 80)
+
+    model = get_unet((imgs_row, imgs_col), pool_size=(2, 2), init_lr=init_lr)
+    model.summary()
+
+    print('-' * 80)
+    print('        Fitting model...')
+    print('-' * 80)
+
+    ver = 'unet_fd%s_%s_ep%s_lr%s' % (fold, plane, nb_epoch, init_lr)
+    csv_logger = CSVLogger(log_path + ver + '.csv')
+    model_checkpoint = ModelCheckpoint(model_path + ver + '.h5',
+                                       monitor='loss',
+                                       save_best_only=False,
+                                       period=10)
+
+    model.fit(imgs_train, imgs_mask_train,
+              batch_size=batch_size, epochs=nb_epoch, verbose=1, shuffle=True,
+              callbacks=[model_checkpoint, csv_logger])
+
+
+if __name__ == "__main__":
+
+    train(cur_fold, plane, batch_size, epoch, init_lr)
+
+    print("training done")
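
Usage note: the script takes five positional arguments in order: the data directory (expected to contain models/ and logs/ subdirectories), the cross-validation fold (0-3), the slicing plane (X/Y/Z), the number of epochs, and the initial learning rate. A typical invocation would look like the following, where the path and hyperparameter values are illustrative, not taken from the repo:

    python unet.py /path/to/data 0 Z 40 1e-5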
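
For review purposes, here is a minimal sanity check of the Dice coefficient used as the training metric and (negated) loss. It restates dice_coef from the patch so the snippet runs standalone, and assumes Keras 2.x with the TensorFlow 1.x backend this file targets, where K.eval works on constant tensors:

    import numpy as np
    from keras import backend as K

    smooth = 1.

    def dice_coef(y_true, y_pred):
        y_true_f = K.flatten(y_true)
        y_pred_f = K.flatten(y_pred)
        intersection = K.sum(y_true_f * y_pred_f)
        return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

    # One true positive, one false negative:
    # dice = (2*1 + 1) / (2 + 1 + 1) = 0.75
    y_true = K.constant(np.array([[0., 1., 1., 0.]]))
    y_pred = K.constant(np.array([[0., 1., 0., 0.]]))
    print(K.eval(dice_coef(y_true, y_pred)))  # ~0.75

The smooth = 1 term also keeps the ratio defined when both masks are empty, in which case the coefficient evaluates to 1, so the loss stays finite on all-background slices.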