Diff of /code/test-mask.py [000000] .. [ebf7be]

Switch to side-by-side view

--- a
+++ b/code/test-mask.py
@@ -0,0 +1,115 @@
+"""
+    Purpose: train a machine learning segmenter that can segment out the nodules on a given 2D patient CT scan slice
+    Note:
+    - this will train from scratch, with no preloaded weights
+    - weights are saved to unet.hdf5 in the specified output folder
+"""
+
+from __future__ import print_function
+
+import numpy as np
+from keras.models import Model
+from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D
+from keras.optimizers import Adam
+from keras.optimizers import SGD
+from keras.callbacks import ModelCheckpoint, LearningRateScheduler
+from keras import backend as K
+import matplotlib.pyplot as plt
+
+WORKING_PATH = "/home/marshallee/Documents/lung/subset0/"
+IMG_ROWS = 512
+IMG_COLS = 512
+
+SMOOTH = 1.
+
+K.set_image_dim_ordering('th')  # Theano dimension ordering in this code
+
+def dice_coef(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (2. * intersection + SMOOTH) / (K.sum(y_true_f) + K.sum(y_pred_f) + SMOOTH)
+
+def dice_coef_loss(y_true, y_pred):
+    return -dice_coef(y_true, y_pred)
+
+def dice_coef_np(y_true,y_pred):
+    y_true_f = y_true.flatten()
+    y_pred_f = y_pred.flatten()
+    intersection = np.sum(y_true_f * y_pred_f)
+    return (2. * intersection + SMOOTH) / (np.sum(y_true_f) + np.sum(y_pred_f) + SMOOTH)
+
+
+def get_unet():
+    """
+        U-net architecture
+    """
+    inputs = Input((1,IMG_ROWS, IMG_COLS))
+    conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs)
+    conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1)
+    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
+
+    conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(pool1)
+    conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2)
+    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
+
+    conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2)
+    conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3)
+    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
+
+    conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(pool3)
+    conv4 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv4)
+    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
+
+    conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(pool4)
+    conv5 = Convolution2D(512, 3, 3, activation='relu', border_mode='same')(conv5)
+
+    up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1)
+    conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(up6)
+    conv6 = Convolution2D(256, 3, 3, activation='relu', border_mode='same')(conv6)
+
+    up7 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=1)
+    conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(up7)
+    conv7 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv7)
+
+    up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1)
+    conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up8)
+    conv8 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv8)
+
+    up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1)
+    conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up9)
+    conv9 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv9)
+
+    conv10 = Convolution2D(1, 1, 1, activation='sigmoid')(conv9)
+
+    model = Model(input=inputs, output=conv10)
+
+    model.compile(optimizer=Adam(lr=1.0e-5), loss=dice_coef_loss, metrics=[dice_coef])
+
+    return model
+
+
+def test():
+    imgs_test = np.load(WORKING_PATH+"trainImages.npy").astype(np.float32)
+    imgs_mask_test_true = np.load(WORKING_PATH+"trainMasks.npy").astype(np.float32)
+    num = len(imgs_test)
+
+    mean_test = np.mean(imgs_test)  # mean for data centering
+    std_test = np.std(imgs_test)  # std for data normalization
+    imgs_test -= mean_test  # images should already be standardized, but just in case
+    imgs_test /= std_test
+    model = get_unet()
+    model.load_weights('unet2.hdf5')  
+    predMask = model.predict(imgs_test)
+    print('pred shape: ', predMask.shape)
+
+    fig,ax = plt.subplots(2,2,figsize=[8,8])
+    ax[0,0].imshow(imgs_test[i][0],cmap='gray')
+    ax[0,1].imshow(predMask[i][0],cmap='gray')
+    ax[1,0].imshow(imgs_test[i][0]*predMask[i][0],cmap='gray')
+    plt.show()
+
+
+    
+if __name__ == '__main__':
+    test()
\ No newline at end of file