|
a |
|
b/3D/predict.py |
|
|
1 |
from __future__ import print_function |
|
|
2 |
|
|
|
3 |
# import packages |
|
|
4 |
from functools import partial |
|
|
5 |
import os, time |
|
|
6 |
import numpy as np |
|
|
7 |
from model import unet_model_3d |
|
|
8 |
from keras.models import Model |
|
|
9 |
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose |
|
|
10 |
from keras.optimizers import Adam |
|
|
11 |
from keras import callbacks |
|
|
12 |
from keras import backend as K |
|
|
13 |
from keras.utils import plot_model |
|
|
14 |
import nibabel as nib |
|
|
15 |
from PIL import Image |
|
|
16 |
from sklearn.feature_extraction.image import extract_patches |
|
|
17 |
|
|
|
18 |
# import load data |
|
|
19 |
from data_handling import load_train_data, load_validatation_data |
|
|
20 |
|
|
|
21 |
# import configurations |
|
|
22 |
import configs |
|
|
23 |
|
|
|
24 |
patch_size = configs.PATCH_SIZE |
|
|
25 |
config = dict() |
|
|
26 |
config["pool_size"] = (2, 2, 2) # pool size for the max pooling operations |
|
|
27 |
config["image_shape"] = (256, 128, 256) # This determines what shape the images will be cropped/resampled to. |
|
|
28 |
config["input_shape"] = (patch_size, patch_size, patch_size, 1) # switch to None to train on the whole image (64, 64, 64) (64, 64, 64) |
|
|
29 |
config["n_labels"] = 4 |
|
|
30 |
config["all_modalities"] = ['t1']#]["t1", "t1Gd", "flair", "t2"] |
|
|
31 |
config["training_modalities"] = config["all_modalities"] # change this if you want to only use some of the modalities |
|
|
32 |
config["nb_channels"] = len(config["training_modalities"]) |
|
|
33 |
config["deconvolution"] = False # if False, will use upsampling instead of deconvolution |
|
|
34 |
config["batch_size"] = 8 |
|
|
35 |
config["n_epochs"] = 500 # cutoff the training after this many epochs |
|
|
36 |
config["patience"] = 10 # learning rate will be reduced after this many epochs if the validation loss is not improving |
|
|
37 |
config["early_stop"] = 20 # training will be stopped after this many epochs without the validation loss improving |
|
|
38 |
config["initial_learning_rate"] = 0.0001 |
|
|
39 |
config["depth"] = configs.DEPTH |
|
|
40 |
config["learning_rate_drop"] = 0.5 |
|
|
41 |
|
|
|
42 |
extraction_reconstruct_step = configs.EXTRACTION_RECONSTRUCT_STEP |
|
|
43 |
image_type = '3d_patches' |
|
|
44 |
K.set_image_data_format('channels_last') # TF dimension ordering in this code |
|
|
45 |
|
|
|
46 |
image_type = configs.IMAGE_TYPE |
|
|
47 |
|
|
|
48 |
# init configs |
|
|
49 |
image_rows = configs.VOLUME_ROWS |
|
|
50 |
image_cols = configs.VOLUME_COLS |
|
|
51 |
image_depth = configs.VOLUME_DEPS |
|
|
52 |
num_classes = configs.NUM_CLASSES |
|
|
53 |
|
|
|
54 |
# patch extraction parameters |
|
|
55 |
|
|
|
56 |
BASE = configs.BASE |
|
|
57 |
smooth = configs.SMOOTH |
|
|
58 |
nb_epochs = configs.NUM_EPOCHS |
|
|
59 |
batch_size = configs.BATCH_SIZE |
|
|
60 |
unet_model_type = configs.MODEL |
|
|
61 |
extraction_step = 1 |
|
|
62 |
|
|
|
63 |
train_imgs_path = '../data_new/Validation_Set' |
|
|
64 |
checkpoint_filename = 'best_3classes_32_85_93_92_default.h5' |
|
|
65 |
write_path = 'predict' |
|
|
66 |
|
|
|
67 |
# for each slice estract patches and stack |
|
|
68 |
def create_slice_testing(img_dir_name): |
|
|
69 |
# empty matrix to hold patches |
|
|
70 |
patches_training_imgs_3d = np.empty(shape=[0, patch_size, patch_size, patch_size], dtype='int16') |
|
|
71 |
patches_training_gtruth_3d = np.empty(shape=[0, patch_size, patch_size, patch_size, num_classes], dtype='int16') |
|
|
72 |
images_train_dir = os.listdir(train_imgs_path) |
|
|
73 |
j = 0 |
|
|
74 |
|
|
|
75 |
# volume |
|
|
76 |
img_name = img_dir_name + '_hist.nii.gz' |
|
|
77 |
img_name = os.path.join(train_imgs_path, img_dir_name, img_name) |
|
|
78 |
print(img_name) |
|
|
79 |
|
|
|
80 |
# mask |
|
|
81 |
img_mask_name = img_dir_name + '_mask.nii.gz' |
|
|
82 |
img_mask_name = os.path.join(train_imgs_path, img_dir_name, img_mask_name) |
|
|
83 |
|
|
|
84 |
# load volume and mask |
|
|
85 |
img = nib.load(img_name) |
|
|
86 |
img_data = img.get_data() |
|
|
87 |
img_data = np.squeeze(img_data) |
|
|
88 |
|
|
|
89 |
img_mask = nib.load(img_mask_name) |
|
|
90 |
img_mask_data = img_mask.get_data() |
|
|
91 |
img_mask_data = np.squeeze(img_mask_data) |
|
|
92 |
|
|
|
93 |
patches_training_imgs_3d_temp = np.empty(shape=[0, patch_size, patch_size, patch_size], dtype='int16') |
|
|
94 |
|
|
|
95 |
# extract patches of the jth volume image |
|
|
96 |
imgs_patches, rows, cols, depths = extract_3d_patches_one_slice(img_data, img_mask_data) |
|
|
97 |
|
|
|
98 |
# update database |
|
|
99 |
patches_training_imgs_3d_temp = np.append(patches_training_imgs_3d_temp, imgs_patches, axis=0) |
|
|
100 |
|
|
|
101 |
patches_training_imgs_3d = np.append(patches_training_imgs_3d, patches_training_imgs_3d_temp, axis=0) |
|
|
102 |
|
|
|
103 |
j += 1 |
|
|
104 |
|
|
|
105 |
# convert to single precision |
|
|
106 |
patches_training_imgs_3d = patches_training_imgs_3d.astype('float32') |
|
|
107 |
patches_training_imgs_3d = np.expand_dims(patches_training_imgs_3d, axis=4) |
|
|
108 |
return patches_training_imgs_3d, rows, cols, depths |
|
|
109 |
|
|
|
110 |
|
|
|
111 |
# extract 3D patches |
|
|
112 |
def extract_3d_patches_one_slice(img_data, mask_data): |
|
|
113 |
patch_shape = (patch_size, patch_size, patch_size) |
|
|
114 |
|
|
|
115 |
# empty matrix to hold patches |
|
|
116 |
imgs_patches_per_volume = np.empty(shape=[0, patch_size, patch_size, patch_size], dtype='int16') |
|
|
117 |
mask_patches_per_slice = np.empty(shape=[0, patch_size, patch_size, patch_size], dtype='int16') |
|
|
118 |
STEP = patch_size-1 |
|
|
119 |
img_patches = extract_patches(img_data, patch_shape, extraction_reconstruct_step) |
|
|
120 |
mask_patches = extract_patches(mask_data, patch_shape, extraction_reconstruct_step) |
|
|
121 |
|
|
|
122 |
Sum = np.sum(mask_patches, axis=(3, 4, 5)) |
|
|
123 |
rows, cols, depths = np.nonzero(Sum) |
|
|
124 |
N = len(rows) |
|
|
125 |
# select non-zero patches index |
|
|
126 |
selected_img_patches = img_patches[rows, cols, depths, :, :, :] |
|
|
127 |
|
|
|
128 |
# update database |
|
|
129 |
imgs_patches_per_volume = np.append(imgs_patches_per_volume, selected_img_patches, axis=0) |
|
|
130 |
return imgs_patches_per_volume, rows, cols, depths |
|
|
131 |
|
|
|
132 |
# write predicted label to the final result |
|
|
133 |
def write_predict(imgs_valid_predict, rows, cols, depths): |
|
|
134 |
label_predicted = np.zeros((image_rows, image_cols, image_depth, num_classes)) |
|
|
135 |
label_predicted_filled = label_predicted |
|
|
136 |
label_final = np.zeros((image_rows, image_cols, image_depth)) |
|
|
137 |
|
|
|
138 |
for index in range(0, len(rows)): |
|
|
139 |
print ('Processing patch: ', index + 1, '/', len(rows)) |
|
|
140 |
row = rows[index]; col = cols[index]; dep = depths[index] |
|
|
141 |
start_row = row * extraction_reconstruct_step |
|
|
142 |
start_col = col * extraction_reconstruct_step |
|
|
143 |
start_dep = dep * extraction_reconstruct_step |
|
|
144 |
patch_volume = imgs_valid_predict[index,:,:,:,:] |
|
|
145 |
for i in range (0,patch_size): |
|
|
146 |
for j in range(0, patch_size): |
|
|
147 |
for k in range(0, patch_size): |
|
|
148 |
prob_class0_new = patch_volume[i][j][k][0] |
|
|
149 |
prob_class1_new = patch_volume[i][j][k][1] |
|
|
150 |
prob_class2_new = patch_volume[i][j][k][2] |
|
|
151 |
prob_class3_new = patch_volume[i][j][k][3] |
|
|
152 |
|
|
|
153 |
label_predicted_filled[start_row + i][start_col + j][start_dep + k][0] = prob_class0_new |
|
|
154 |
label_predicted_filled[start_row + i][start_col + j][start_dep + k][1] = prob_class1_new |
|
|
155 |
label_predicted_filled[start_row + i][start_col + j][start_dep + k][2] = prob_class2_new |
|
|
156 |
label_predicted_filled[start_row + i][start_col + j][start_dep + k][3] = prob_class3_new |
|
|
157 |
|
|
|
158 |
for i in range(0, 256): |
|
|
159 |
for j in range(0, 128): |
|
|
160 |
for k in range(0, 256): |
|
|
161 |
prob_class0 = label_predicted_filled[i][j][k][0] |
|
|
162 |
prob_class1 = label_predicted_filled[i][j][k][1] |
|
|
163 |
prob_class2 = label_predicted_filled[i][j][k][2] |
|
|
164 |
prob_class3 = label_predicted_filled[i][j][k][3] |
|
|
165 |
|
|
|
166 |
prob_max = max(prob_class0, prob_class1, prob_class2, prob_class3) |
|
|
167 |
if prob_class0 == prob_max: |
|
|
168 |
label_final[i][j][k] = 0 |
|
|
169 |
elif prob_class1 == prob_max: |
|
|
170 |
label_final[i][j][k] = 1 |
|
|
171 |
elif prob_class2 == prob_max: |
|
|
172 |
label_final[i][j][k] = 2 |
|
|
173 |
else: |
|
|
174 |
label_final[i][j][k] = 3 |
|
|
175 |
|
|
|
176 |
return label_final |
|
|
177 |
|
|
|
178 |
# predict function |
|
|
179 |
def predict(img_dir_name): |
|
|
180 |
# create a model |
|
|
181 |
model = unet_model_3d(input_shape=config["input_shape"], |
|
|
182 |
depth=config["depth"], |
|
|
183 |
pool_size=config["pool_size"], |
|
|
184 |
n_labels=config["n_labels"], |
|
|
185 |
initial_learning_rate=config["initial_learning_rate"], |
|
|
186 |
deconvolution=config["deconvolution"]) |
|
|
187 |
|
|
|
188 |
model.summary() |
|
|
189 |
|
|
|
190 |
checkpoint_filepath = 'outputs/' + checkpoint_filename |
|
|
191 |
model.load_weights(checkpoint_filepath) |
|
|
192 |
|
|
|
193 |
SegmentedVolume = np.zeros((image_rows,image_cols,image_depth)) |
|
|
194 |
|
|
|
195 |
|
|
|
196 |
img_mask_name = img_dir_name + '_mask.nii.gz' |
|
|
197 |
img_mask_name = os.path.join(train_imgs_path, img_dir_name, img_mask_name) |
|
|
198 |
|
|
|
199 |
img_mask = nib.load(img_mask_name) |
|
|
200 |
img_mask_data = img_mask.get_data() |
|
|
201 |
|
|
|
202 |
patches_training_imgs_3d, rows, cols, depths = create_slice_testing(img_dir_name) |
|
|
203 |
imgs_valid_predict = model.predict(patches_training_imgs_3d) |
|
|
204 |
label_final = write_predict(imgs_valid_predict, rows, cols, depths) |
|
|
205 |
|
|
|
206 |
for i in range(0, SegmentedVolume.shape[0]): |
|
|
207 |
for j in range(0, SegmentedVolume.shape[1]): |
|
|
208 |
for k in range(0, SegmentedVolume.shape[2]): |
|
|
209 |
if img_mask_data.item((i, j, k)) == 1: |
|
|
210 |
SegmentedVolume.itemset((i,j,k), label_final.item((i, j,k))) |
|
|
211 |
else: |
|
|
212 |
label_final.itemset((i, j,k), 0) |
|
|
213 |
|
|
|
214 |
print ('done') |
|
|
215 |
|
|
|
216 |
|
|
|
217 |
|
|
|
218 |
data = SegmentedVolume |
|
|
219 |
img = nib.Nifti1Image(data, np.eye(4)) |
|
|
220 |
if num_classes == 3: |
|
|
221 |
img_name = img_dir_name + '_predicted_3class_unet3d.nii.gz' |
|
|
222 |
else: |
|
|
223 |
img_name = img_dir_name + '_predicted_4class_unet3d_extract12_hist15.nii.gz' |
|
|
224 |
nib.save(img, os.path.join('../data_new', write_path, img_name)) |
|
|
225 |
print('-' * 30) |
|
|
226 |
|
|
|
227 |
# main |
|
|
228 |
if __name__ == '__main__': |
|
|
229 |
# folder to hold outputs |
|
|
230 |
if 'outputs' not in os.listdir(os.curdir): |
|
|
231 |
os.mkdir('outputs') |
|
|
232 |
|
|
|
233 |
images_train_dir = os.listdir(train_imgs_path) |
|
|
234 |
|
|
|
235 |
j=0 |
|
|
236 |
|
|
|
237 |
for img_dir_name in images_train_dir: |
|
|
238 |
j=j+1 |
|
|
239 |
if j: |
|
|
240 |
start_time = time.time() |
|
|
241 |
|
|
|
242 |
print('*'*50) |
|
|
243 |
print('Segmenting: volume {0} / {1} volume images'.format(j, len(images_train_dir))) |
|
|
244 |
# print('-'*30) |
|
|
245 |
if num_classes == 3: |
|
|
246 |
img_name = img_dir_name + '_predicted_3class_unet3d.nii.gz' |
|
|
247 |
else: |
|
|
248 |
img_name = img_dir_name + '_predicted_4class_unet3d_extract12_hist15.nii.gz' |
|
|
249 |
print ('Path: ', os.path.join('../data_new', write_path, img_name)) |
|
|
250 |
print('*' * 50) |
|
|
251 |
predict(img_dir_name) |
|
|
252 |
end_time = time.time() |
|
|
253 |
elapsed_time = end_time - start_time |
|
|
254 |
print ('Elapsed time is: ', elapsed_time) |