import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import random
from scipy import ndimage
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv3D, MaxPool3D, BatchNormalization, GlobalAveragePooling3D, Dropout
def read_scan(filepath):
    """Load the image stored at ``filepath`` and return its voxel data.

    The file is opened with ``nib.load`` and the array produced by the
    image's ``get_fdata()`` is returned unchanged.
    """
    return nib.load(filepath).get_fdata()
def normalize(volume, lower=-1000, upper=400):
    """Clip ``volume`` to ``[lower, upper]`` and rescale it to ``[0, 1]``.

    Unlike an in-place clamp, the input array is left untouched; a new
    ``float32`` array is returned.

    Args:
        volume: Array-like of intensities.
        lower: Values below this bound are raised to it (maps to 0.0).
        upper: Values above this bound are lowered to it (maps to 1.0).

    Returns:
        A ``float32`` NumPy array with the same shape as ``volume``.

    Raises:
        ValueError: If ``upper`` is not strictly greater than ``lower``.
    """
    if upper <= lower:
        raise ValueError("upper must be strictly greater than lower")
    # np.clip allocates a new array, so the caller's data is not modified.
    clipped = np.clip(np.asarray(volume), lower, upper)
    scaled = (clipped - lower) / (upper - lower)
    return scaled.astype("float32")
def resize_volume(img, *, desired_depth=64, desired_width=128, desired_height=128):
    """Rotate a 3D volume by 90 degrees and resize it along all three axes.

    Args:
        img: 3D array indexed as ``(width, height, depth)``; the first,
            second and last axes are scaled to the requested sizes.
        desired_depth: Target size of the last axis.
        desired_width: Target size of the first axis.
        desired_height: Target size of the second axis.

    Returns:
        The rotated volume, linearly interpolated (``order=1``) to
        ``(desired_width, desired_height, desired_depth)``.
    """
    current_width, current_height = img.shape[0], img.shape[1]
    current_depth = img.shape[-1]

    # Per-axis zoom factor: target size divided by current size.
    width_factor = desired_width / current_width
    height_factor = desired_height / current_height
    depth_factor = desired_depth / current_depth

    # Rotate in the plane of the first two axes; reshape=False keeps the
    # array shape unchanged so the factors computed above stay valid.
    img = ndimage.rotate(img, 90, reshape=False)
    # Resize every axis with linear interpolation.
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img
def process_scan(path):
    """Load the scan at ``path``, normalize it, then resize it.

    The three stages run in order: ``read_scan`` -> ``normalize`` ->
    ``resize_volume``; the resized volume is returned.
    """
    raw = read_scan(path)
    normalized = normalize(raw)
    return resize_volume(normalized)
@tf.function
def rotate(volume):
    """Rotate ``volume`` by a randomly chosen small angle.

    The rotation runs in SciPy through ``tf.numpy_function`` and the
    result is returned as a ``tf.float32`` tensor.
    """

    def scipy_rotate(volume):
        # Candidate angles in degrees; one is drawn on every call.
        angle = random.choice([-20, -10, -5, 5, 10, 20])
        rotated = ndimage.rotate(volume, angle, reshape=False)
        # Clamp the rotated values back into [0, 1] in place.
        np.clip(rotated, 0, 1, out=rotated)
        return rotated

    return tf.numpy_function(scipy_rotate, [volume], tf.float32)
def train_preprocessing(volume, label):
    """Augment a training sample: random rotation, then a new axis at index 3.

    The label is passed through unchanged.
    """
    augmented = tf.expand_dims(rotate(volume), axis=3)
    return augmented, label
def validation_preprocessing(volume, label):
    """Prepare a validation sample by inserting a new axis at index 3.

    No augmentation is applied; the label is passed through unchanged.
    """
    return tf.expand_dims(volume, axis=3), label
def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a borderless montage of ``num_rows * num_columns`` slices.

    Args:
        num_rows: Number of rows in the montage.
        num_columns: Number of columns in the montage.
        width: First dimension of each slice after reshaping.
        height: Second dimension of each slice after reshaping.
        data: Array-like that, after a 90-degree rotation and a transpose,
            can be reshaped to ``(num_rows, num_columns, width, height)``.
    """
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    # One entry per grid row / column, used to keep the figure's aspect
    # ratio in line with the montage.
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    # squeeze=False keeps axarr two-dimensional even for a single row or
    # column, so the axarr[i, j] indexing below is always valid.
    _, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
        squeeze=False,
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()
def build_model(width=128, height=128, depth=64):
    """Assemble the 3D CNN and return it as a ``tf.keras.Model``.

    The network takes inputs of shape ``(width, height, depth, 1)``, runs
    four Conv3D -> MaxPool3D -> BatchNormalization stages, then global
    average pooling, a 512-unit dense layer with dropout, and a single
    sigmoid output unit.
    """
    inputs = tf.keras.Input((width, height, depth, 1))

    # Convolutional stages differ only in their filter count.
    features = inputs
    for n_filters in (64, 64, 128, 256):
        features = Conv3D(filters=n_filters, kernel_size=3, activation="relu")(features)
        features = MaxPool3D(pool_size=2)(features)
        features = BatchNormalization()(features)

    # Classification head.
    features = GlobalAveragePooling3D()(features)
    features = Dense(units=512, activation="relu")(features)
    features = Dropout(0.3)(features)
    outputs = Dense(units=1, activation="sigmoid")(features)

    return tf.keras.Model(inputs, outputs, name="3dctscan")