# helper.py
1
import matplotlib.pyplot as plt
2
import nibabel as nib
3
import numpy as np
4
import random
5
from scipy import ndimage
6
import tensorflow as tf
7
from tensorflow.keras.layers import Dense, Conv3D, MaxPool3D, BatchNormalization, GlobalAveragePooling3D, Dropout
8
9
10
def read_scan(filepath):
    """Open an image file with nibabel and return its voxel data.

    Args:
        filepath: Path of the file passed to ``nib.load``.

    Returns:
        The array produced by ``get_fdata()`` on the loaded image.
    """
    image = nib.load(filepath)
    return image.get_fdata()


def normalize(volume, min_hu=-1000, max_hu=400):
    """Clip a volume to an intensity window and rescale it linearly to [0, 1].

    The default window is -1000 to 400. For CT input these are presumably
    Hounsfield units (confirm with the data source).

    Args:
        volume: Array-like of intensities. It is not modified; a new array is
            returned.
        min_hu: Lower window bound; values at or below it map to 0.
        max_hu: Upper window bound; values at or above it map to 1.

    Returns:
        A ``float32`` array with the same shape as ``volume``.

    Raises:
        ValueError: If ``max_hu`` is not greater than ``min_hu`` (the window
            would be empty or inverted and the rescale would divide by <= 0).
    """
    if max_hu <= min_hu:
        raise ValueError(
            f"max_hu ({max_hu}) must be greater than min_hu ({min_hu})"
        )
    # np.clip returns a fresh array, so the caller's data is left untouched.
    clipped = np.clip(volume, min_hu, max_hu)
    scaled = (clipped - min_hu) / (max_hu - min_hu)
    return scaled.astype("float32")


def resize_volume(img, desired_depth=64, desired_width=128, desired_height=128):
    """Rotate a 3-D volume by 90 degrees in-plane and resample it to a fixed size.

    Axis 0 is treated as width, axis 1 as height and the last axis as depth
    (TODO confirm this matches the orientation of the loaded scans).

    Args:
        img: 3-D array to resample.
        desired_depth: Target length of the last axis.
        desired_width: Target length of axis 0.
        desired_height: Target length of axis 1.

    Returns:
        Array of shape ``(desired_width, desired_height, desired_depth)``,
        resampled with linear interpolation (``order=1``).

    Raises:
        ZeroDivisionError: If axis 0, axis 1 or the last axis has length 0.
    """
    current_width = img.shape[0]
    current_height = img.shape[1]
    current_depth = img.shape[-1]
    # Per-axis scale factors: output length = input length * factor.
    zoom_factors = (
        desired_width / current_width,
        desired_height / current_height,
        desired_depth / current_depth,
    )
    # Rotate in the plane of axes 0 and 1. reshape=False keeps the array shape
    # unchanged, so the factors computed from the pre-rotation shape stay valid.
    img = ndimage.rotate(img, 90, reshape=False)
    # Resample all three axes at once.
    return ndimage.zoom(img, zoom_factors, order=1)


def process_scan(path):
    """Load a scan, normalize its intensities and resize it.

    Args:
        path: File path forwarded to ``read_scan``.

    Returns:
        The output of ``resize_volume`` applied to the normalized volume.
    """
    return resize_volume(normalize(read_scan(path)))


@tf.function
def rotate(volume):
    """Rotate a volume by a small random angle for data augmentation.

    The rotation runs in SciPy inside ``tf.numpy_function``. The angle is drawn
    from {-20, -10, -5, 5, 10, 20} degrees with Python's ``random`` module on
    each execution of the wrapped function.

    Args:
        volume: float32 tensor (``tf.numpy_function`` is declared to return
            ``tf.float32``). Presumably a single volume with values in [0, 1];
            TODO confirm against callers.

    Returns:
        float32 tensor with the same shape as ``volume``, values clipped to
        [0, 1].
    """

    def scipy_rotate(volume):
        angle = random.choice([-20, -10, -5, 5, 10, 20])
        # reshape=False keeps the output shape equal to the input shape.
        rotated = ndimage.rotate(volume, angle, reshape=False)
        # Spline interpolation can overshoot; bring values back into [0, 1].
        return np.clip(rotated, 0, 1)

    augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
    # numpy_function outputs carry no static shape. The rotation preserves
    # shape, so restore it for downstream ops and shape checks.
    augmented_volume.set_shape(volume.shape)
    return augmented_volume


def train_preprocessing(volume, label):
    """Augment a training sample and give it a channel axis.

    Args:
        volume: Volume tensor, randomly rotated via ``rotate``.
        label: Passed through unchanged.

    Returns:
        ``(volume, label)``, where the rotated volume gains a length-1 axis
        inserted at position 3.
    """
    return tf.expand_dims(rotate(volume), axis=3), label


def validation_preprocessing(volume, label):
    """Give a validation sample a channel axis without any augmentation.

    Args:
        volume: Volume tensor.
        label: Passed through unchanged.

    Returns:
        ``(volume, label)``, where the volume gains a length-1 axis inserted at
        position 3.
    """
    with_channel = tf.expand_dims(volume, axis=3)
    return with_channel, label


def plot_slices(num_rows, num_columns, width, height, data):
    """Show a ``num_rows`` x ``num_columns`` grayscale montage of slices.

    ``data`` is rotated with ``np.rot90``, has its axes reversed with
    ``np.transpose`` and is then reshaped to
    ``(num_rows, num_columns, width, height)``. Each ``(width, height)`` image
    is drawn in its own grid cell. ``data`` must therefore contain exactly
    ``num_rows * num_columns * width * height`` elements; presumably it is a
    ``(width, height, n_slices)`` stack with ``n_slices == num_rows *
    num_columns`` (TODO confirm against callers).

    Args:
        num_rows: Number of montage rows.
        num_columns: Number of montage columns.
        width: First spatial size of each slice after reshaping.
        height: Second spatial size of each slice after reshaping.
        data: Array-like holding the slices.
    """
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]

    # Per-row image heights and per-column image widths set the figure aspect.
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    # squeeze=False keeps axarr 2-D even when there is a single row or column;
    # otherwise plt.subplots returns a 1-D array and 2-D indexing fails.
    _, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
        squeeze=False,
    )
    for ax_row, image_row in zip(axarr, data):
        for ax, image in zip(ax_row, image_row):
            ax.imshow(image, cmap="gray")
            ax.axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()


def build_model(width=128, height=128, depth=64):
    """Construct the 3-D convolutional network.

    The network has four stages of Conv3D (kernel 3, ReLU) -> MaxPool3D
    (pool 2) -> BatchNormalization, with 64, 64, 128 and 256 filters. These are
    followed by GlobalAveragePooling3D, a 512-unit ReLU Dense layer,
    Dropout(0.3) and a single sigmoid output unit.

    Args:
        width: Size of the first spatial input axis.
        height: Size of the second spatial input axis.
        depth: Size of the third spatial input axis.

    Returns:
        A ``tf.keras.Model`` named ``"3dctscan"`` taking inputs of shape
        ``(width, height, depth, 1)``. It is not compiled here.
    """
    inputs = tf.keras.Input((width, height, depth, 1))

    features = inputs
    for n_filters in (64, 64, 128, 256):
        features = Conv3D(filters=n_filters, kernel_size=3, activation="relu")(features)
        features = MaxPool3D(pool_size=2)(features)
        features = BatchNormalization()(features)

    features = GlobalAveragePooling3D()(features)
    features = Dense(units=512, activation="relu")(features)
    features = Dropout(0.3)(features)
    outputs = Dense(units=1, activation="sigmoid")(features)

    return tf.keras.Model(inputs, outputs, name="3dctscan")