--- a
+++ b/uNet_Functional.py
@@ -0,0 +1,617 @@
+# %% importing packages
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras import mixed_precision
+from tensorflow.python.ops.numpy_ops import np_config
+np_config.enable_numpy_behavior()
+from skimage import measure
+import cv2 as cv
+import os
+import tqdm
+import matplotlib.pyplot as plt
+import gc
+
+
+# %% Citations
+#############################################################
+#############################################################
+# https://www.tensorflow.org/guide/keras/functional
+# https://www.tensorflow.org/tutorials/customization/custom_layers
+# https://keras.io/examples/keras_recipes/tfrecord/
+# https://arxiv.org/abs/1505.04597
+# https://www.tensorflow.org/guide/gpu
+
+# Defining Functions
+#############################################################
+#############################################################
+
+def parse_tf_elements(element):
+    '''This function is the mapper function for retrieving examples from the
+    tfrecord.'''
+
+    # create placeholders for all the features in each example
+    data = {
+        'height' : tf.io.FixedLenFeature([],tf.int64),
+        'width' : tf.io.FixedLenFeature([],tf.int64),
+        'raw_image' : tf.io.FixedLenFeature([],tf.string),
+        'raw_seg' : tf.io.FixedLenFeature([],tf.string),
+        'bbox_x' : tf.io.VarLenFeature(tf.float32),
+        'bbox_y' : tf.io.VarLenFeature(tf.float32),
+        'bbox_height' : tf.io.VarLenFeature(tf.float32),
+        'bbox_width' : tf.io.VarLenFeature(tf.float32)
+    }
+
+    # pull out the current example
+    content = tf.io.parse_single_example(element, data)
+
+    # pull out each feature from the example
+    height = content['height']
+    width = content['width']
+    raw_seg = content['raw_seg']
+    raw_image = content['raw_image']
+    bbox_x = content['bbox_x']
+    bbox_y = content['bbox_y']
+    bbox_height = content['bbox_height']
+    bbox_width = content['bbox_width']
+
+    # parse the image and segmentation back into uint8 tensors, and reshape
+    # them accordingly
+    image = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
+    image = tf.reshape(image,shape=[height,width,3])
+    segmentation = tf.io.parse_tensor(raw_seg, out_type=tf.uint8)
+    segmentation = tf.reshape(segmentation,shape=[height,width,1])
+    one_hot_seg = tf.one_hot(tf.squeeze(segmentation-1),4,axis=-1)
+
+    # there is currently a bug with returning the bbox, but it isn't necessary
+    # to fix for creating the initial uNet for segmentation exploration
+
+    # bbox = [bbox_x,bbox_y,bbox_height,bbox_width]
+
+    return(image,one_hot_seg)
+
+#############################################################
+
+def load_dataset(file_names):
+    '''Receives a list of file names from a folder that contains tfrecord files
+    compiled previously. 
Takes these names and creates a tensorflow dataset
+    from them.'''
+
+    ignore_order = tf.data.Options()
+    ignore_order.experimental_deterministic = False
+    dataset = tf.data.TFRecordDataset(file_names)
+
+    # you can shard the dataset here if you need to reduce its size
+    # dataset = dataset.shard(num_shards=2,index=1)
+
+    # the order of the file names doesn't really matter, so ignore it
+    dataset = dataset.with_options(ignore_order)
+
+    # map the dataset using the parse_tf_elements function defined earlier
+    dataset = dataset.map(parse_tf_elements,num_parallel_calls=1)
+
+    return(dataset)
+
+#############################################################
+
+def get_dataset(file_names,batch_size):
+    '''Receives a list of file names of tfrecord shards from a dataset as well
+    as a batch size for the dataset.'''
+
+    # use the load_dataset function to retrieve the files and put them into a
+    # dataset.
+    dataset = load_dataset(file_names)
+
+    # create a shuffle buffer of 1000. The number was arbitrarily chosen; feel
+    # free to alter it to fit your hardware.
+    dataset = dataset.shuffle(1000)
+
+    # add the batch size to the dataset
+    dataset = dataset.batch(batch_size=batch_size)
+
+    return(dataset)
+
+#############################################################
+
+def weighted_cce_loss(y_true,y_pred):
+    '''This function does essentially what the "class_weight" argument of
+    fit() does when training a network. I had to create this separate custom
+    loss function because apparently, when using tfrecord files to read your
+    dataset, a check is performed comparing the input, ground truth, and
+    weight values to each other, and the comparison between the empty None
+    passed during the build call of the model and the weight array/dictionary
+    returns an error. Thus, here is a custom loss function that applies a
+    weighting to the different classes based on the distribution of the
+    classes within the entire dataset. For thoroughness' sake, a future
+    iteration will base the weights only on the dataset used for training,
+    not the whole dataset (a sketch of this is included at the end of this
+    script).'''
+
+    # weights for each class, as background, connective, muscle, and vasculature
+    weights = [28.78661087,3.60830475,1.63037567,14.44688883]
+
+    # flatten the one-hot ground truth to class indices, as a numpy array and
+    # not an eager tensor, to allow for binary index masking.
+    class_indices = np.asarray(tf.argmax(y_true,axis=-1)).copy()
+
+    # build a per-pixel weight map for the current batch (the weighting for
+    # categorical crossentropy needs one sample weight per pixel)
+    current_weights = np.zeros(class_indices.shape,dtype=np.float64)
+    for idx,weight in enumerate(weights):
+        # every pixel belonging to the current class receives that class's
+        # weight, which is then passed to the loss function.
+        current_weights[class_indices==idx] = weight
+
+    cce = tf.keras.losses.CategoricalCrossentropy()
+    cce_loss = cce(y_true,y_pred,current_weights)
+
+    return(cce_loss)
+
+#############################################################
+#############################################################
+
+# %% Setting up the GPU, and setting memory growth to true so that it is easier
+# to see exactly how much memory the training process is taking up. This code is
+# from a tensorflow tutorial. 
+ +gpus = tf.config.list_physical_devices('GPU') +if gpus: + try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + logical_gpus = tf.config.list_logical_devices('GPU') + + print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") + except RuntimeError as e: + print(e) + +# use this to set mixed precision for higher efficiency later if you would like +# mixed_precision.set_global_policy('mixed_float16') + + +# %% setting up datasets and building model + +# directory where the dataset shards are stored +shard_dataset_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_ScaleFactor2' + +os.chdir(shard_dataset_directory) + +# only get the file names that follow the shard naming convention +file_names = tf.io.gfile.glob(shard_dataset_directory + \ + "/shard_*_of_*.tfrecords") + +# first 70% of names go to the training dataset. Following 20% go to the val +# dataset, followed by last 10% go to the testing dataset. +val_split_idx = int(0.7*len(file_names)) +test_split_idx = int(0.9*len(file_names)) + +# separate the file names out +train_files, val_files, test_files = file_names[:val_split_idx],\ + file_names[val_split_idx:test_split_idx],\ + file_names[test_split_idx:] + +# create the datasets. Because of how batches are run for training, we set +# the dataset to repeat() because the batches and epochs are altered from +# standard practice to fit on graphics cards and provide more meaningful and +# frequent updates to the console. +training_dataset = get_dataset(train_files,batch_size=15) +training_dataset = training_dataset.repeat() +validation_dataset = get_dataset(val_files,batch_size = 5) +# testing has a batch size of 1 to facilitate visualization of predictions +testing_dataset = get_dataset(test_files,batch_size=1) + +# %% Putting together the network + +# filter multiplier provided creates largest filter depth of 256 with a +# multiplier of 8. 
+filter_multiplier = 8 +# encoder convolution parameters +enc_kernel = (3,3) +enc_strides = (1,1) + +# encoder max-pooling parameters +enc_pool_size = (2,2) +enc_pool_strides = (2,2) + +# setting the input size +net_input = keras.Input(shape=(512,512,3),name='original_image') + +################## Encoder ################## +# encoder, block 1 + +# including the image normalization within the network for easier image +# processing during inference +normalized = layers.Normalization()(net_input) + +enc1 = layers.Conv2D(filters=2*filter_multiplier, + kernel_size=enc_kernel, + strides=enc_strides, + padding='same', + name='enc1_conv1')(normalized) + +enc1 = tf.keras.layers.BatchNormalization()(enc1) +enc1 = layers.ReLU()(enc1) + +enc1 = layers.Conv2D(filters=2*filter_multiplier, + kernel_size=enc_kernel, + strides=enc_strides, + padding='same', + name='enc1_conv2')(enc1) + +enc1 = tf.keras.layers.BatchNormalization()(enc1) +enc1 = layers.ReLU()(enc1) + +enc1_pool = layers.MaxPooling2D(pool_size=enc_pool_size, + strides=enc_pool_strides, + padding='same', + name='enc1_pool')(enc1) + + +# encoder, block 2 +enc2 = layers.Conv2D(filters=4*filter_multiplier, + kernel_size=enc_kernel, + strides=enc_strides, + padding='same', + name='enc2_conv1')(enc1_pool) + +enc2 = tf.keras.layers.BatchNormalization()(enc2) +enc2 = layers.ReLU()(enc2) + +enc2 = layers.Conv2D(filters=4*filter_multiplier, + kernel_size=enc_kernel, + strides=enc_strides, + padding='same', + name='enc2_conv2')(enc2) + +enc2 = tf.keras.layers.BatchNormalization()(enc2) +enc2 = layers.ReLU()(enc2) + +enc2_pool = layers.MaxPooling2D(pool_size=enc_pool_size, + strides=enc_pool_strides, + padding='same', + name='enc2_pool')(enc2) + + +# encoder, block 3 +enc3 = layers.Conv2D(filters=8*filter_multiplier, + kernel_size=enc_kernel, + strides=enc_strides, + padding='same', + name='enc3_conv1')(enc2_pool) + +enc3 = tf.keras.layers.BatchNormalization()(enc3) +enc3 = layers.ReLU()(enc3) + +enc3 = layers.Conv2D(filters=8*filter_multiplier, + kernel_size=enc_kernel, + strides=enc_strides, + padding='same', + name='enc3_conv2')(enc3) + +enc3 = tf.keras.layers.BatchNormalization()(enc3) +enc3 = layers.ReLU()(enc3) + +enc3_pool = layers.MaxPooling2D(pool_size=enc_pool_size, + strides=enc_pool_strides, + padding='same', + name='enc3_pool')(enc3) + +# encoder, block 4 +enc4 = layers.Conv2D(filters=16*filter_multiplier, + kernel_size=enc_kernel, + strides=enc_strides, + padding='same', + name='enc4_conv1')(enc3_pool) + +enc4 = tf.keras.layers.BatchNormalization()(enc4) +enc4 = layers.ReLU()(enc4) + +enc4 = layers.Conv2D(filters=16*filter_multiplier, + kernel_size=enc_kernel, + strides=enc_strides, + padding='same', + name='enc4_conv2')(enc4) + +enc4 = tf.keras.layers.BatchNormalization()(enc4) +enc4 = layers.ReLU()(enc4) + +enc4_pool = layers.MaxPooling2D(pool_size=enc_pool_size, + strides=enc_pool_strides, + padding='same', + name='enc4_pool')(enc4) + + +# encoder, block 5 +enc5 = layers.Conv2D(filters=32*filter_multiplier, + kernel_size=enc_kernel, + strides=enc_strides, + padding='same', + name='enc5_conv1')(enc4_pool) + +enc5 = tf.keras.layers.BatchNormalization()(enc5) +enc5 = layers.ReLU()(enc5) + +enc5 = layers.Conv2D(filters=32*filter_multiplier, + kernel_size=enc_kernel, + strides=enc_strides, + padding='same', + name='enc5_conv2')(enc5) + +enc5 = tf.keras.layers.BatchNormalization()(enc5) +enc5 = layers.ReLU()(enc5) + +################## Decoder ################## + +# decoder upconv parameters +dec_upconv_kernel = (2,2) +dec_upconv_stride = 
(2,2) + +# decoder forward convolution parameters +dec_conv_stride = (1,1) +dec_conv_kernel = (3,3) + +# Decoder, block 4 +dec4_up = layers.Conv2DTranspose(filters=16*filter_multiplier, + kernel_size=dec_upconv_kernel, + strides=dec_upconv_stride, + padding='same', + name='dec4_upconv')(enc5) + +dec4_conc = layers.concatenate([dec4_up,enc4],axis=-1) + +dec4 = layers.Conv2D(filters=16*filter_multiplier, + kernel_size=dec_conv_kernel, + strides=dec_conv_stride, + padding='same', + name='dec4_conv1')(dec4_conc) + +dec4 = tf.keras.layers.BatchNormalization()(dec4) +dec4 = layers.ReLU()(dec4) + +dec4 = layers.Conv2D(filters=16*filter_multiplier, + kernel_size=dec_conv_kernel, + strides=dec_conv_stride, + padding='same', + name='dec4_conv2')(dec4) + +dec4 = tf.keras.layers.BatchNormalization()(dec4) +dec4 = layers.ReLU()(dec4) + + +# Decoder, block 3 +dec3_up = layers.Conv2DTranspose(filters=8*filter_multiplier, + kernel_size=dec_upconv_kernel, + strides=dec_upconv_stride, + padding='same', + name='dec3_upconv')(dec4) + +dec3_conc = layers.concatenate([dec3_up,enc3],axis=-1) + +dec3 = layers.Conv2D(filters=8*filter_multiplier, + kernel_size=dec_conv_kernel, + strides=dec_conv_stride, + padding='same', + name='dec3_conv1')(dec3_conc) + +dec3 = tf.keras.layers.BatchNormalization()(dec3) +dec3 = layers.ReLU()(dec3) + +dec3 = layers.Conv2D(filters=8*filter_multiplier, + kernel_size=dec_conv_kernel, + strides=dec_conv_stride, + padding='same', + name='dec3_conv2')(dec3) + +dec3 = tf.keras.layers.BatchNormalization()(dec3) +dec3 = layers.ReLU()(dec3) + + +# Decoder, block 2 +dec2_up = layers.Conv2DTranspose(filters=4*filter_multiplier, + kernel_size=dec_upconv_kernel, + strides=dec_upconv_stride, + padding='same', + name='dec2_upconv')(dec3) + +dec2_conc = layers.concatenate([dec2_up,enc2],axis=-1) + +dec2 = layers.Conv2D(filters=4*filter_multiplier, + kernel_size=dec_conv_kernel, + strides=dec_conv_stride, + padding='same', + name='dec2_conv1')(dec2_conc) + +dec2 = tf.keras.layers.BatchNormalization()(dec2) +dec2 = layers.ReLU()(dec2) + +dec2 = layers.Conv2D(filters=4*filter_multiplier, + kernel_size=dec_conv_kernel, + strides=dec_conv_stride, + padding='same', + name='dec2_conv2')(dec2) + +dec2 = tf.keras.layers.BatchNormalization()(dec2) +dec2 = layers.ReLU()(dec2) + + +# Decoder, block 1 +dec1_up = layers.Conv2DTranspose(filters=2*filter_multiplier, + kernel_size=dec_upconv_kernel, + strides=dec_upconv_stride, + padding='same', + name='dec1_upconv')(dec2) + +dec1_conc = layers.concatenate([dec1_up,enc1],axis=-1) + +dec1 = layers.Conv2D(filters=2*filter_multiplier, + kernel_size=dec_conv_kernel, + strides=dec_conv_stride, + padding='same', + name='dec1_conv1')(dec1_conc) + +dec1 = tf.keras.layers.BatchNormalization()(dec1) +dec1 = layers.ReLU()(dec1) + +dec1 = layers.Conv2D(filters=2*filter_multiplier, + kernel_size=dec_conv_kernel, + strides=dec_conv_stride, + padding='same', + name='dec1_conv2')(dec1) + +dec1 = tf.keras.layers.BatchNormalization()(dec1) +dec1 = layers.ReLU()(dec1) + +conv_seg = layers.Conv2D(filters=4, + kernel_size=(1,1), + name='conv_feature_map')(dec1) + +prob_dist = layers.Softmax(dtype='float32')(conv_seg) + +unet = keras.Model(inputs=net_input,outputs=prob_dist,name='uNet') + +unet.summary() + +# %% setting up training + +cce = tf.keras.losses.CategoricalCrossentropy() + +# running network eagerly because it allows us to use convert a tensor to a +# numpy array to help with the weighted loss calculation. 
+unet.compile(
+    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
+    loss=weighted_cce_loss,
+    run_eagerly=True,
+    metrics=[tf.keras.metrics.Precision(name='precision'),
+             tf.keras.metrics.Recall(name='recall')]
+)
+
+# %%
+reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_recall',
+                                                 mode='max',
+                                                 factor=0.8,
+                                                 patience=3,
+                                                 min_lr=0.00001,
+                                                 verbose=True)
+
+checkpoint_cb = tf.keras.callbacks.ModelCheckpoint('unet_seg_functional.h5',
+                                                   save_best_only=True,
+                                                   save_weights_only=True,
+                                                   monitor='val_recall',
+                                                   mode='max',
+                                                   verbose=True)
+
+early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=8,
+                                                     monitor='val_recall',
+                                                     mode='max',
+                                                     restore_best_weights=True,
+                                                     verbose=True)
+
+num_steps = 150
+
+history = unet.fit(training_dataset,
+                   epochs=20,
+                   steps_per_epoch=num_steps,
+                   validation_data=validation_dataset,
+                   callbacks=[checkpoint_cb,
+                              early_stopping_cb,
+                              reduce_lr])
+
+# %%
+# evaluate the network after loading the best weights saved by the checkpoint
+# callback
+unet.load_weights('./unet_seg_functional.h5')
+results = unet.evaluate(testing_dataset)
+
+# %%
+# extracting loss vs epoch
+loss = history.history['loss']
+val_loss = history.history['val_loss']
+# extracting precision vs epoch
+precision = history.history['precision']
+val_precision = history.history['val_precision']
+# extracting recall vs epoch
+recall = history.history['recall']
+val_recall = history.history['val_recall']
+
+epochs = range(len(loss))
+
+figs, axes = plt.subplots(3,1)
+
+# plotting loss and validation loss
+axes[0].plot(epochs,loss)
+axes[0].plot(epochs,val_loss)
+axes[0].legend(['loss','val_loss'])
+axes[0].set(xlabel='epochs',ylabel='crossentropy loss')
+
+# plotting precision and validation precision
+axes[1].plot(epochs,precision)
+axes[1].plot(epochs,val_precision)
+axes[1].legend(['precision','val_precision'])
+axes[1].set(xlabel='epochs',ylabel='precision')
+
+# plotting recall and validation recall
+axes[2].plot(epochs,recall)
+axes[2].plot(epochs,val_recall)
+axes[2].legend(['recall','val_recall'])
+axes[2].set(xlabel='epochs',ylabel='recall')
+
+
+# %% exploring the predictions to better understand what the network is doing
+
+images = []
+gt = []
+predictions = []
+
+# take the next 10 samples from the testing dataset and iterate through them
+for sample in testing_dataset.take(10):
+    # make sure it is producing the correct dimensions
+    print(sample[0].shape)
+    # take the image and convert it back to RGB, store in list
+    image = sample[0]
+    image = cv.cvtColor(np.squeeze(np.asarray(image).copy()),cv.COLOR_BGR2RGB)
+    images.append(image)
+    # extract the ground truth and store in list
+    ground_truth = sample[1]
+    gt.append(ground_truth)
+    # perform inference
+    out = unet.predict(sample[0])
+    predictions.append(out)
+    # show the original input image
+    plt.imshow(image)
+    plt.show()
+    # flatten the ground truth from one-hot encoding along the last axis, and
+    # show the resulting image
+    squeezed_gt = tf.argmax(ground_truth,axis=-1)
+    squeezed_prediction = tf.argmax(out,axis=-1)
+    plt.imshow(squeezed_gt[0,:,:])
+    # print the unique classes present in this tile
+    print(np.unique(squeezed_gt))
+    plt.show()
+    # show the flattened predictions
+    plt.imshow(squeezed_prediction[0,:,:])
+    print(np.unique(squeezed_prediction))
+    plt.show()
+
+# %%
+# select one of the images cycled through above to investigate further
+image_to_investigate = 2
+
+# show the original image
+plt.imshow(images[image_to_investigate])
+plt.show()
+
+# show the ground truth for this tile
+squeezed_gt = 
tf.argmax(gt[image_to_investigate],axis=-1)
+plt.imshow(squeezed_gt[0,:,:])
+# print the unique classes present in the ground truth
+print(np.unique(squeezed_gt))
+plt.show()
+# flatten the prediction, and show the predicted probability map for the
+# vasculature class (channel 3)
+squeezed_prediction = tf.argmax(predictions[image_to_investigate],axis=-1)
+plt.imshow(predictions[image_to_investigate][0,:,:,3])
+plt.show()
+# show the flattened prediction
+plt.imshow(squeezed_prediction[0,:,:])
+print(np.unique(squeezed_prediction))
+plt.show()
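+
+# %% (sketch) per-class Dice overlap for the samples explored above
+# The comparisons above are purely visual. As a rough quantitative companion,
+# this cell computes a per-class Dice coefficient over the ten ground
+# truth/prediction pairs collected in the exploration loop. It is a minimal
+# sketch added for exploration only, not part of the original training
+# pipeline, and the variable names below (dice_sums, dice_counts, etc.) are
+# new to this cell.
+dice_sums = np.zeros(4)
+dice_counts = np.zeros(4)
+for ground_truth, prediction in zip(gt, predictions):
+    # flatten both the one-hot ground truth and the softmax output to labels
+    gt_labels = np.asarray(tf.argmax(ground_truth, axis=-1))
+    pred_labels = np.asarray(tf.argmax(prediction, axis=-1))
+    for idx in range(4):
+        gt_mask = gt_labels == idx
+        pred_mask = pred_labels == idx
+        denominator = gt_mask.sum() + pred_mask.sum()
+        # skip classes that are absent from both the tile and the prediction
+        if denominator == 0:
+            continue
+        intersection = np.logical_and(gt_mask, pred_mask).sum()
+        dice_sums[idx] += 2 * intersection / denominator
+        dice_counts[idx] += 1
+
+# average Dice per class over the tiles in which that class appeared
+print(dice_sums / np.maximum(dice_counts, 1))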
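+
+# %% (sketch) recomputing the class weights from the training split only
+# The weighted_cce_loss docstring notes that the hard-coded class weights were
+# derived from the class distribution of the entire dataset, and that a future
+# iteration should base them on the training split alone. This cell is a
+# minimal sketch of one way to do that: it makes a single pass over the
+# training shards, tallies per-class pixel counts, and converts them to
+# inverse-frequency weights. The normalization used to produce the original
+# hard-coded values is not recorded in this file, so treat these numbers as an
+# approximation rather than a drop-in replacement. Assumes every class appears
+# at least once in the training split.
+class_counts = np.zeros(4, dtype=np.float64)
+for _, one_hot_seg in get_dataset(train_files, batch_size=15):
+    # sum the one-hot masks over batch, height, and width to get pixel counts
+    class_counts += np.asarray(tf.reduce_sum(one_hot_seg, axis=(0, 1, 2)))
+
+# inverse-frequency weights, scaled so that the mean weight is 1
+class_frequencies = class_counts / class_counts.sum()
+candidate_weights = 1.0 / class_frequencies
+candidate_weights = candidate_weights / candidate_weights.mean()
+print(candidate_weights)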