--- a +++ b/helper.py @@ -0,0 +1,153 @@ +import matplotlib.pyplot as plt +import nibabel as nib +import numpy as np +import random +from scipy import ndimage +import tensorflow as tf +from tensorflow.keras.layers import Dense, Conv3D, MaxPool3D, BatchNormalization, GlobalAveragePooling3D, Dropout + + +def read_scan(filepath): + """Read and load volume""" + # Read file + scan = nib.load(filepath) + # Get raw data + scan = scan.get_fdata() + return scan + + +def normalize(volume): + """Normalize the volume""" + min = -1000 + max = 400 + volume[volume < min] = min + volume[volume > max] = max + volume = (volume - min) / (max - min) + volume = volume.astype("float32") + return volume + + +def resize_volume(img): + """Resize across z-axis""" + # Set the desired depth + desired_depth = 64 + desired_width = 128 + desired_height = 128 + # Get current depth + current_depth = img.shape[-1] + current_width = img.shape[0] + current_height = img.shape[1] + # Compute depth factor + depth = current_depth / desired_depth + width = current_width / desired_width + height = current_height / desired_height + depth_factor = 1 / depth + width_factor = 1 / width + height_factor = 1 / height + # Rotate + img = ndimage.rotate(img, 90, reshape=False) + # Resize across z-axis + img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1) + return img + + +def process_scan(path): + """Read and resize volume""" + # Read scan + volume = read_scan(path) + # Normalize + volume = normalize(volume) + # Resize width, height and depth + volume = resize_volume(volume) + return volume + +@tf.function +def rotate(volume): + """Rotate the volume by a few degrees""" + + def scipy_rotate(volume): + # define some rotation angles + angles = [-20, -10, -5, 5, 10, 20] + # pick angles at random + angle = random.choice(angles) + # rotate volume + volume = ndimage.rotate(volume, angle, reshape=False) + volume[volume < 0] = 0 + volume[volume > 1] = 1 + return volume + + augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32) + return augmented_volume + + +def train_preprocessing(volume, label): + """Process training data by rotating and adding a channel.""" + # Rotate volume + volume = rotate(volume) + volume = tf.expand_dims(volume, axis=3) + return volume, label + + +def validation_preprocessing(volume, label): + """Process validation data by only adding a channel.""" + volume = tf.expand_dims(volume, axis=3) + return volume, label + + + + +def plot_slices(num_rows, num_columns, width, height, data): + """Plot a montage of 20 CT slices""" + data = np.rot90(np.array(data)) + data = np.transpose(data) + data = np.reshape(data, (num_rows, num_columns, width, height)) + rows_data, columns_data = data.shape[0], data.shape[1] + heights = [slc[0].shape[0] for slc in data] + widths = [slc.shape[1] for slc in data[0]] + fig_width = 12.0 + fig_height = fig_width * sum(heights) / sum(widths) + f, axarr = plt.subplots( + rows_data, + columns_data, + figsize=(fig_width, fig_height), + gridspec_kw={"height_ratios": heights}, + ) + for i in range(rows_data): + for j in range(columns_data): + axarr[i, j].imshow(data[i][j], cmap="gray") + axarr[i, j].axis("off") + plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1) + plt.show() + + +def build_model(width=128, height=128, depth=64): + """Build a 3D convolutional neural network model.""" + + inputs = tf.keras.Input((width, height, depth, 1)) + + x = Conv3D(filters=64, kernel_size=3, activation="relu")(inputs) + x = MaxPool3D(pool_size=2)(x) + x = BatchNormalization()(x) + + x = Conv3D(filters=64, kernel_size=3, activation="relu")(x) + x = MaxPool3D(pool_size=2)(x) + x = BatchNormalization()(x) + + x = Conv3D(filters=128, kernel_size=3, activation="relu")(x) + x = MaxPool3D(pool_size=2)(x) + x = BatchNormalization()(x) + + x = Conv3D(filters=256, kernel_size=3, activation="relu")(x) + x = MaxPool3D(pool_size=2)(x) + x = BatchNormalization()(x) + + x = GlobalAveragePooling3D()(x) + x = Dense(units=512, activation="relu")(x) + x = Dropout(0.3)(x) + + outputs = Dense(units=1, activation="sigmoid")(x) + + # Define the model. + model = tf.keras.Model(inputs, outputs, name="3dctscan") + return model +