--- a +++ b/code/train-mask.py @@ -0,0 +1,125 @@ +""" + Purpose: train a machine learning segmenter that can segment out the nodules on a given 2D patient CT scan slice + Note: + - this will train from scratch, with no preloaded weights + - weights are saved to unet.hdf5 in the specified output folder +""" + +from __future__ import print_function + +import numpy as np +from keras.models import Model +from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D +from keras.optimizers import Adam +from keras.optimizers import SGD +from keras.callbacks import ModelCheckpoint, LearningRateScheduler +from keras import backend as K + +WORKING_PATH = "/home/ubuntu/data/output/" +IMG_ROWS = 512 +IMG_COLS = 512 + +SMOOTH = 1. + +K.set_image_dim_ordering('th') # Theano dimension ordering in this code + +def dice_coef(y_true, y_pred): + y_true_f = K.flatten(y_true) + y_pred_f = K.flatten(y_pred) + intersection = K.sum(y_true_f * y_pred_f) + return (2. * intersection + SMOOTH) / (K.sum(y_true_f) + K.sum(y_pred_f) + SMOOTH) + +def dice_coef_loss(y_true, y_pred): + return -dice_coef(y_true, y_pred) + +def dice_coef_np(y_true,y_pred): + y_true_f = y_true.flatten() + y_pred_f = y_pred.flatten() + intersection = np.sum(y_true_f * y_pred_f) + return (2. * intersection + SMOOTH) / (np.sum(y_true_f) + np.sum(y_pred_f) + SMOOTH) + +def get_unet(): + inputs = Input((1,IMG_ROWS, IMG_COLS)) + conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs) + conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1) + pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) + + conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(pool1) + conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2) + pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) + + conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2) + conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3) + pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) + + conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(pool3) + conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv4) + pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) + + conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(pool4) + conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(conv5) + + up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1) + conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(up6) + conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv6) + + up7 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=1) + conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(up7) + conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv7) + + up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1) + conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up8) + conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv8) + + up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1) + conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up9) + conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv9) + + conv10 = Convolution2D(1, 1, 1, activation='sigmoid')(conv9) + + model = Model(input=inputs, output=conv10) + + model.compile(optimizer=Adam(lr=1.0e-5), loss=dice_coef_loss, metrics=[dice_coef]) + + return model + + +def train_and_evaluate(): + print('-'*30) + print('Loading and preprocessing train data...') + print('-'*30) + imgs_train = np.load(WORKING_PATH+"trainImages.npy").astype(np.float32) + imgs_mask_train = np.load(WORKING_PATH+"trainMasks.npy").astype(np.float32) + + imgs_test = np.load(WORKING_PATH+"testImages.npy").astype(np.float32) + imgs_mask_test_true = np.load(WORKING_PATH+"testMasks.npy").astype(np.float32) + + mean = np.mean(imgs_train) # mean for data centering + std = np.std(imgs_train) # std for data normalization + imgs_train -= mean # images should already be standardized, but just in case + imgs_train /= std + + mean_test = np.mean(imgs_test) # mean for data centering + std_test = np.std(imgs_test) # std for data normalization + imgs_test -= mean_test # images should already be standardized, but just in case + imgs_test /= std_test + + print('-'*30) + print('Creating and compiling model...') + print('-'*30) + model = get_unet() + + # Saving weights to unet.hdf5 at checkpoints + model_checkpoint = ModelCheckpoint('unet.hdf5', monitor='loss', save_best_only=True) + + print('-'*30) + print('Fitting model...') + print('-'*30) + model.fit(imgs_train, imgs_mask_train, batch_size=2, nb_epoch=20, verbose=1, shuffle=True, callbacks=[model_checkpoint]) + print('Fitting ends...') + + print('start evaluation...') + print('evaluation result is: ', model.eva) + +if __name__ == '__main__': + train_and_predict() \ No newline at end of file