|
a |
|
b/helper.py |
|
|
1 |
import matplotlib.pyplot as plt |
|
|
2 |
import nibabel as nib |
|
|
3 |
import numpy as np |
|
|
4 |
import random |
|
|
5 |
from scipy import ndimage |
|
|
6 |
import tensorflow as tf |
|
|
7 |
from tensorflow.keras.layers import Dense, Conv3D, MaxPool3D, BatchNormalization, GlobalAveragePooling3D, Dropout |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
def read_scan(filepath):
    """Load the image stored at ``filepath`` with nibabel and return its voxel data.

    The image is opened via ``nib.load`` and its data array is obtained with
    ``get_fdata()``.
    """
    return nib.load(filepath).get_fdata()
|
|
17 |
|
|
|
18 |
|
|
|
19 |
def normalize(volume, *, min_value=-1000, max_value=400):
    """Clip a volume to an intensity window and rescale it linearly to [0, 1].

    Values below ``min_value`` become ``min_value`` and values above
    ``max_value`` become ``max_value``. The clipped data is then mapped so that
    ``min_value`` -> 0.0 and ``max_value`` -> 1.0.

    NOTE(review): the default window (-1000, 400) presumably refers to
    Hounsfield units of CT data — confirm against the scans being loaded.

    Args:
        volume: array-like of intensities. It is NOT modified.
        min_value: lower bound of the window (keyword-only).
        max_value: upper bound of the window (keyword-only).

    Returns:
        A new ``float32`` array with the same shape as ``volume``.

    Raises:
        ValueError: if ``max_value`` is not greater than ``min_value``.
    """
    if max_value <= min_value:
        raise ValueError(
            f"max_value ({max_value}) must be greater than min_value ({min_value})"
        )
    # np.clip returns a fresh array, so the caller's data is left untouched
    # (boolean-mask assignment would have mutated the argument in place).
    clipped = np.clip(volume, min_value, max_value)
    scaled = (clipped - min_value) / (max_value - min_value)
    return scaled.astype("float32")
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def resize_volume(img, *, desired_depth=64, desired_width=128, desired_height=128):
    """Rotate a 3-D volume by 90 degrees in-plane and resample it to a target size.

    Axis convention used here: ``img.shape[0]`` is width, ``img.shape[1]`` is
    height and the last axis is depth.

    Steps:
      1. ``ndimage.rotate(img, 90, reshape=False)`` rotates in the plane of the
         first two axes while keeping the array shape. When width != height,
         parts of the volume are cropped or filled with 0.
      2. ``ndimage.zoom(..., order=1)`` linearly resamples every axis so the
         output shape is ``(desired_width, desired_height, desired_depth)``.

    Args:
        img: 3-D array to resize.
        desired_depth: target size of the last axis (keyword-only).
        desired_width: target size of axis 0 (keyword-only).
        desired_height: target size of axis 1 (keyword-only).

    Returns:
        The rotated and resampled array.

    Raises:
        ValueError: if any desired dimension is not positive.
    """
    if min(desired_width, desired_height, desired_depth) <= 0:
        raise ValueError("desired dimensions must be positive")

    current_width, current_height = img.shape[0], img.shape[1]
    current_depth = img.shape[-1]

    # Per-axis zoom factor is simply target / current.
    zoom_factors = (
        desired_width / current_width,
        desired_height / current_height,
        desired_depth / current_depth,
    )

    img = ndimage.rotate(img, 90, reshape=False)
    return ndimage.zoom(img, zoom_factors, order=1)
|
|
52 |
|
|
|
53 |
|
|
|
54 |
def process_scan(path):
    """Produce a model-ready volume from the scan file at ``path``.

    The pipeline is ``read_scan`` (load voxels) -> ``normalize`` (window and
    rescale intensities) -> ``resize_volume`` (rotate and resample).
    """
    return resize_volume(normalize(read_scan(path)))
|
|
63 |
|
|
|
64 |
@tf.function
def rotate(volume):
    """Rotate a volume by a small random angle (training-time augmentation).

    The rotation runs in NumPy/SciPy through ``tf.numpy_function``. On every
    call an angle is drawn with ``random.choice`` from -20, -10, -5, 5, 10 and
    20 degrees. After interpolation the voxel values are clipped to [0, 1].
    NOTE(review): this presumably assumes the input was already scaled to
    that range; confirm against the preprocessing that feeds this.

    Args:
        volume: tensor holding the volume to augment.

    Returns:
        A ``tf.float32`` tensor with the same static shape as ``volume``.
    """

    def scipy_rotate(vol):
        angle = random.choice([-20, -10, -5, 5, 10, 20])
        # reshape=False keeps the array shape. Spline interpolation can
        # overshoot the input range, hence the clip.
        rotated = ndimage.rotate(vol, angle, reshape=False)
        rotated = np.clip(rotated, 0, 1)
        # tf.numpy_function requires the returned dtype to match Tout exactly.
        return rotated.astype(np.float32, copy=False)

    augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
    # numpy_function discards static shape information; restore it so that
    # downstream ops see the same shape as the input.
    augmented_volume.set_shape(volume.shape)
    return augmented_volume
|
|
81 |
|
|
|
82 |
|
|
|
83 |
def train_preprocessing(volume, label):
    """Augment one training sample and give it a channel axis.

    ``volume`` is passed through ``rotate`` (random small-angle rotation) and
    then gets a size-1 axis inserted at position 3. ``label`` is returned
    unchanged.
    """
    augmented = rotate(volume)
    with_channel = tf.expand_dims(augmented, axis=3)
    return with_channel, label
|
|
89 |
|
|
|
90 |
|
|
|
91 |
def validation_preprocessing(volume, label):
    """Prepare one validation sample: add a channel axis, apply no augmentation.

    A size-1 axis is inserted at position 3 of ``volume``. ``label`` passes
    through untouched.
    """
    return tf.expand_dims(volume, axis=3), label
|
|
95 |
|
|
|
96 |
|
|
|
97 |
|
|
|
98 |
|
|
|
99 |
def plot_slices(num_rows, num_columns, width, height, data):
    """Display slices of a volume as a gray-scale montage on a grid.

    ``data`` is rotated with ``np.rot90`` (first two axes) and has all of its
    axes reversed with ``np.transpose``. The result is reshaped to
    ``(num_rows, num_columns, width, height)``, and entry ``[i, j]`` is drawn
    as one image in grid cell ``(i, j)``. The total element count of ``data``
    must therefore equal ``num_rows * num_columns * width * height``.

    NOTE(review): presumably ``data`` is a ``(width, height, n_slices)`` volume
    with ``n_slices == num_rows * num_columns``; confirm with callers.

    Args:
        num_rows: number of grid rows.
        num_columns: number of grid columns.
        width: first spatial size of each panel after the reshape.
        height: second spatial size of each panel after the reshape.
        data: array-like volume to display.
    """
    data = np.transpose(np.rot90(np.array(data)))
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]

    # Panel sizes drive the per-row height ratios and the figure aspect ratio.
    heights = [row[0].shape[0] for row in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)

    # squeeze=False keeps axarr 2-D even for a single row or column;
    # otherwise axarr[i, j] fails for 1xN / Nx1 / 1x1 grids.
    _, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
        squeeze=False,
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()
|
|
121 |
|
|
|
122 |
|
|
|
123 |
def build_model(width=128, height=128, depth=64):
    """Assemble the 3-D convolutional network.

    Architecture:
      * four stages of Conv3D(kernel 3, ReLU) -> MaxPool3D(2) ->
        BatchNormalization, with 64, 64, 128 and 256 filters;
      * GlobalAveragePooling3D -> Dense(512, ReLU) -> Dropout(0.3);
      * a single-unit Dense output with sigmoid activation.

    Args:
        width: size of input axis 0.
        height: size of input axis 1.
        depth: size of input axis 2.
            The input tensor shape is ``(width, height, depth, 1)``.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"3dctscan"``.
    """
    inputs = tf.keras.Input((width, height, depth, 1))

    # Layers are created in the same order as four explicit stages would be.
    features = inputs
    for n_filters in (64, 64, 128, 256):
        features = Conv3D(filters=n_filters, kernel_size=3, activation="relu")(features)
        features = MaxPool3D(pool_size=2)(features)
        features = BatchNormalization()(features)

    head = GlobalAveragePooling3D()(features)
    head = Dense(units=512, activation="relu")(head)
    head = Dropout(0.3)(head)
    outputs = Dense(units=1, activation="sigmoid")(head)

    return tf.keras.Model(inputs, outputs, name="3dctscan")
|
|
153 |
|