|
a |
|
b/2D/predict_2d.py |
|
|
1 |
|
|
|
2 |
from __future__ import print_function |
|
|
3 |
|
|
|
4 |
# import packages |
|
|
5 |
from functools import partial |
|
|
6 |
import os |
|
|
7 |
import time |
|
|
8 |
import numpy as np |
|
|
9 |
from keras.models import Model |
|
|
10 |
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose |
|
|
11 |
from keras.optimizers import Adam |
|
|
12 |
from keras import callbacks |
|
|
13 |
from keras import backend as K |
|
|
14 |
from keras.utils import plot_model |
|
|
15 |
import nibabel as nib |
|
|
16 |
from PIL import Image |
|
|
17 |
from sklearn.feature_extraction.image import extract_patches |
|
|
18 |
|
|
|
19 |
# import load data |
|
|
20 |
from data_handling_2d_patch import load_train_data, load_validatation_data |
|
|
21 |
from train_main_2d_patch import get_unet_default, get_unet_reduced, get_unet_extended |
|
|
22 |
|
|
|
23 |
# import configurations |
|
|
24 |
import configs |
|
|
25 |
|
|
|
26 |
K.set_image_data_format('channels_last') # TF dimension ordering in this code |
|
|
27 |
|
|
|
28 |
image_type = configs.IMAGE_TYPE |
|
|
29 |
|
|
|
30 |
# init configs |
|
|
31 |
image_rows = configs.VOLUME_ROWS |
|
|
32 |
image_cols = configs.VOLUME_COLS |
|
|
33 |
image_depth = configs.VOLUME_DEPS |
|
|
34 |
num_classes = configs.NUM_CLASSES |
|
|
35 |
|
|
|
36 |
# patch extraction parameters |
|
|
37 |
patch_size = configs.PATCH_SIZE |
|
|
38 |
BASE = configs.BASE |
|
|
39 |
smooth = configs.SMOOTH |
|
|
40 |
nb_epochs = configs.NUM_EPOCHS |
|
|
41 |
batch_size = configs.BATCH_SIZE |
|
|
42 |
unet_model_type = configs.MODEL |
|
|
43 |
extraction_step = 1 |
|
|
44 |
|
|
|
45 |
extraction_reconstruct_step = configs.extraction_reconstruct_step |
|
|
46 |
|
|
|
47 |
# init |
|
|
48 |
train_imgs_path = '../data_new/Test_Set' |
|
|
49 |
print('path: ', train_imgs_path) |
|
|
50 |
checkpoint_filename = 'best_4classes_32_default_tuned_8925.h5' |
|
|
51 |
print('weight file: ', checkpoint_filename) |
|
|
52 |
write_path = 'predict2D' |
|
|
53 |
|
|
|
54 |
# for each slice estract patches and stack |
|
|
55 |
def create_slice_testing(slice_number, img_dir_name): |
|
|
56 |
# empty matrix to hold patches |
|
|
57 |
patches_training_imgs_2d = np.empty(shape=[0, patch_size, patch_size], dtype='int16') |
|
|
58 |
patches_training_gtruth_2d = np.empty(shape=[0, patch_size, patch_size, num_classes], dtype='int16') |
|
|
59 |
images_train_dir = os.listdir(train_imgs_path) |
|
|
60 |
j = 0 |
|
|
61 |
|
|
|
62 |
# volume |
|
|
63 |
img_name = img_dir_name + '_hist.nii.gz' |
|
|
64 |
print('Image: ', img_name) |
|
|
65 |
img_name = os.path.join(train_imgs_path, img_dir_name, img_name) |
|
|
66 |
|
|
|
67 |
# mask |
|
|
68 |
img_mask_name = img_dir_name + '_mask.nii.gz' |
|
|
69 |
img_mask_name = os.path.join(train_imgs_path, img_dir_name, img_mask_name) |
|
|
70 |
|
|
|
71 |
# load volume and mask |
|
|
72 |
img = nib.load(img_name) |
|
|
73 |
img_data = img.get_data() |
|
|
74 |
img_data = np.squeeze(img_data) |
|
|
75 |
|
|
|
76 |
img_mask = nib.load(img_mask_name) |
|
|
77 |
img_mask_data = img_mask.get_data() |
|
|
78 |
img_mask_data = np.squeeze(img_mask_data) |
|
|
79 |
|
|
|
80 |
patches_training_imgs_2d_temp = np.empty(shape=[0, patch_size, patch_size], dtype='int16') |
|
|
81 |
patches_training_gtruth_2d_temp = np.empty(shape=[0, patch_size, patch_size, num_classes], dtype='int16') |
|
|
82 |
|
|
|
83 |
rows = []; cols = [] |
|
|
84 |
if np.count_nonzero(img_mask_data[:, :, slice_number]) and np.count_nonzero(img_data[:, :, slice_number]): |
|
|
85 |
# extract patches of the jth volume image |
|
|
86 |
imgs_patches, rows, cols = extract_2d_patches_one_slice(img_data[:, :, slice_number], |
|
|
87 |
img_mask_data[:, :, slice_number]) |
|
|
88 |
|
|
|
89 |
# update database |
|
|
90 |
patches_training_imgs_2d_temp = np.append(patches_training_imgs_2d_temp, imgs_patches, axis=0) |
|
|
91 |
|
|
|
92 |
patches_training_imgs_2d = np.append(patches_training_imgs_2d, patches_training_imgs_2d_temp, axis=0) |
|
|
93 |
j += 1 |
|
|
94 |
|
|
|
95 |
X = patches_training_imgs_2d.shape |
|
|
96 |
Y = patches_training_gtruth_2d.shape |
|
|
97 |
|
|
|
98 |
# convert to single precision |
|
|
99 |
patches_training_imgs_2d = patches_training_imgs_2d.astype('float32') |
|
|
100 |
patches_training_imgs_2d = np.expand_dims(patches_training_imgs_2d, axis=3) |
|
|
101 |
|
|
|
102 |
S = patches_training_imgs_2d.shape |
|
|
103 |
|
|
|
104 |
label_predicted = np.zeros((img_data.shape[0], img_data.shape[1]), dtype=np.uint8) |
|
|
105 |
|
|
|
106 |
return label_predicted, patches_training_imgs_2d, rows, cols |
|
|
107 |
|
|
|
108 |
|
|
|
109 |
# extract patches in one slice |
|
|
110 |
def extract_2d_patches_one_slice(img_data, mask_data): |
|
|
111 |
patch_shape = (patch_size, patch_size) |
|
|
112 |
|
|
|
113 |
# empty matrix to hold patches |
|
|
114 |
imgs_patches_per_slice = np.empty(shape=[0, patch_size, patch_size], dtype='int16') |
|
|
115 |
|
|
|
116 |
img_patches = extract_patches(img_data, patch_shape, extraction_reconstruct_step) |
|
|
117 |
mask_patches = extract_patches(mask_data, patch_shape, extraction_reconstruct_step) |
|
|
118 |
|
|
|
119 |
Sum = np.sum(mask_patches, axis=(2, 3)) |
|
|
120 |
rows, cols = np.nonzero(Sum) |
|
|
121 |
|
|
|
122 |
N = len(rows) |
|
|
123 |
# select non-zero patches index |
|
|
124 |
selected_img_patches = img_patches[rows, cols, :, :] |
|
|
125 |
|
|
|
126 |
# update database |
|
|
127 |
imgs_patches_per_slice = np.append(imgs_patches_per_slice, selected_img_patches, axis=0) |
|
|
128 |
return imgs_patches_per_slice, rows, cols |
|
|
129 |
|
|
|
130 |
# write predicted label to the final result |
|
|
131 |
def write_slice_predict(imgs_valid_predict, rows, cols): |
|
|
132 |
label_predicted_filled = np.zeros((image_rows, image_cols, num_classes)) |
|
|
133 |
label_final = np.zeros((image_rows, image_cols)) |
|
|
134 |
|
|
|
135 |
Count = len(rows) |
|
|
136 |
count_write = len(rows) |
|
|
137 |
|
|
|
138 |
for index in range(0, len(rows)): |
|
|
139 |
row = rows[index]; col = cols[index] |
|
|
140 |
start_row = row * extraction_reconstruct_step |
|
|
141 |
start_col = col * extraction_reconstruct_step |
|
|
142 |
patch_volume = imgs_valid_predict[index, :, :, :] |
|
|
143 |
for i in range(0, patch_size): |
|
|
144 |
for j in range(0, patch_size): |
|
|
145 |
prob_class0_new = patch_volume[i][j][0] |
|
|
146 |
prob_class1_new = patch_volume[i][j][1] |
|
|
147 |
prob_class2_new = patch_volume[i][j][2] |
|
|
148 |
prob_class3_new = patch_volume[i][j][3] |
|
|
149 |
|
|
|
150 |
label_predicted_filled[start_row + i][start_col + j][0] = prob_class0_new |
|
|
151 |
label_predicted_filled[start_row + i][start_col + j][1] = prob_class1_new |
|
|
152 |
label_predicted_filled[start_row + i][start_col + j][2] = prob_class2_new |
|
|
153 |
label_predicted_filled[start_row + i][start_col + j][3] = prob_class3_new |
|
|
154 |
|
|
|
155 |
for i in range(0, 256): |
|
|
156 |
for j in range(0, 128): |
|
|
157 |
prob_class0 = label_predicted_filled[i][j][0] |
|
|
158 |
prob_class1 = label_predicted_filled[i][j][1] |
|
|
159 |
prob_class2 = label_predicted_filled[i][j][2] |
|
|
160 |
prob_class3 = label_predicted_filled[i][j][3] |
|
|
161 |
|
|
|
162 |
prob_max = max(prob_class0, prob_class1, prob_class2, prob_class3) |
|
|
163 |
if prob_class0 == prob_max: |
|
|
164 |
label_final[i][j] = 0 |
|
|
165 |
elif prob_class1 == prob_max: |
|
|
166 |
label_final[i][j] = 1 |
|
|
167 |
elif prob_class2 == prob_max: |
|
|
168 |
label_final[i][j] = 2 |
|
|
169 |
else: |
|
|
170 |
label_final[i][j] = 3 |
|
|
171 |
|
|
|
172 |
print('Number of processed patches: ', count_write) |
|
|
173 |
print('Number of extracted patches: ', Count) |
|
|
174 |
return label_final |
|
|
175 |
|
|
|
176 |
# predict function |
|
|
177 |
def predict(img_dir_name): |
|
|
178 |
if unet_model_type == 'default': |
|
|
179 |
model = get_unet_default() |
|
|
180 |
elif unet_model_type == 'reduced': |
|
|
181 |
model = get_unet_reduced() |
|
|
182 |
elif unet_model_type == 'extended': |
|
|
183 |
model = get_unet_extended() |
|
|
184 |
|
|
|
185 |
checkpoint_filepath = 'outputs/' + checkpoint_filename |
|
|
186 |
model.load_weights(checkpoint_filepath) |
|
|
187 |
model.summary() |
|
|
188 |
|
|
|
189 |
SegmentedVolume = np.zeros((image_rows,image_cols,image_depth)) |
|
|
190 |
|
|
|
191 |
img_mask_name = img_dir_name + '_mask.nii.gz' |
|
|
192 |
img_mask_name = os.path.join(train_imgs_path, img_dir_name, img_mask_name) |
|
|
193 |
|
|
|
194 |
img_mask = nib.load(img_mask_name) |
|
|
195 |
img_mask_data = img_mask.get_data() |
|
|
196 |
|
|
|
197 |
# for each slice, extract patches and predict |
|
|
198 |
for iSlice in range(0,256): |
|
|
199 |
mask = img_mask_data[2:254, 2:127, iSlice] |
|
|
200 |
|
|
|
201 |
if np.sum(mask, axis=(0,1))>0: |
|
|
202 |
print('-' * 30) |
|
|
203 |
print('Slice number: ', iSlice) |
|
|
204 |
label_predicted, patches_training_imgs_2d, rows, cols = create_slice_testing(iSlice, img_dir_name) |
|
|
205 |
imgs_valid_predict = model.predict(patches_training_imgs_2d) |
|
|
206 |
label_predicted_filled = write_slice_predict(imgs_valid_predict, rows, cols) |
|
|
207 |
|
|
|
208 |
for i in range(0, SegmentedVolume.shape[0]): |
|
|
209 |
for j in range(0, SegmentedVolume.shape[1]): |
|
|
210 |
if img_mask_data.item((i, j, iSlice)) == 1: |
|
|
211 |
SegmentedVolume.itemset((i,j,iSlice), label_predicted_filled.item((i, j))) |
|
|
212 |
else: |
|
|
213 |
label_predicted_filled.itemset((i, j), 0) |
|
|
214 |
print ('done') |
|
|
215 |
|
|
|
216 |
# utilize mask to write output |
|
|
217 |
data = SegmentedVolume |
|
|
218 |
img = nib.Nifti1Image(data, np.eye(4)) |
|
|
219 |
if num_classes == 3: |
|
|
220 |
img_name = img_dir_name + '_predicted_3class_' + str(patch_size) + '_' + unet_model_type + '_tuned_8925.nii.gz' |
|
|
221 |
else: |
|
|
222 |
img_name = img_dir_name + '_predicted_4class_' + str(patch_size) + '_' + unet_model_type + '_tuned_8925.nii.gz' |
|
|
223 |
nib.save(img, os.path.join('../data_new', write_path, img_name)) |
|
|
224 |
print('-' * 30) |
|
|
225 |
|
|
|
226 |
# main |
|
|
227 |
if __name__ == '__main__': |
|
|
228 |
# folder to hold outputs |
|
|
229 |
if 'outputs' not in os.listdir(os.curdir): |
|
|
230 |
os.mkdir('outputs') |
|
|
231 |
|
|
|
232 |
images_train_dir = os.listdir(train_imgs_path) |
|
|
233 |
|
|
|
234 |
j=0 |
|
|
235 |
for img_dir_name in images_train_dir: |
|
|
236 |
j=j+1 |
|
|
237 |
if j: |
|
|
238 |
start_time = time.time() |
|
|
239 |
print('*'*50) |
|
|
240 |
print('Segmenting: volume {0} / {1} volume images'.format(j, len(images_train_dir))) |
|
|
241 |
# print('-'*30) |
|
|
242 |
if num_classes == 3: |
|
|
243 |
img_name = img_dir_name + '_predicted_3class_' + str(patch_size) + '_' + unet_model_type + '_tuned_8925.nii.gz' |
|
|
244 |
else: |
|
|
245 |
img_name = img_dir_name + '_predicted_4class_' + str(patch_size) + '_' + unet_model_type + '_tuned_8925.nii.gz' |
|
|
246 |
print ('Path: ', os.path.join('../data_new', write_path, img_name)) |
|
|
247 |
print('*' * 50) |
|
|
248 |
predict(img_dir_name) |
|
|
249 |
end_time = time.time() |
|
|
250 |
elapsed_time = end_time - start_time |
|
|
251 |
print ('Elapsed time is: ', elapsed_time) |