a b/2D/predict_2d.py
1
2
from __future__ import print_function
3
4
# import packages
5
from functools import partial
6
import os
7
import time
8
import numpy as np
9
from keras.models import Model
10
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
11
from keras.optimizers import Adam
12
from keras import callbacks
13
from keras import backend as K
14
from keras.utils import plot_model
15
import nibabel as nib
16
from PIL import Image
17
from sklearn.feature_extraction.image import extract_patches
18
19
# import load data
20
from data_handling_2d_patch import load_train_data, load_validatation_data
21
from train_main_2d_patch import get_unet_default, get_unet_reduced, get_unet_extended
22
23
# import configurations
24
import configs
25
26
K.set_image_data_format('channels_last')  # TF dimension ordering in this code

# ---------------------------------------------------------------------------
# Run-time settings, mirrored from the project-level ``configs`` module.
# ---------------------------------------------------------------------------
image_type = configs.IMAGE_TYPE

# Volume geometry (rows x cols x depth) and number of segmentation classes.
image_rows, image_cols, image_depth = (
    configs.VOLUME_ROWS,
    configs.VOLUME_COLS,
    configs.VOLUME_DEPS,
)
num_classes = configs.NUM_CLASSES

# Patch extraction / model hyper-parameters.
patch_size = configs.PATCH_SIZE
BASE = configs.BASE
smooth = configs.SMOOTH
nb_epochs = configs.NUM_EPOCHS
batch_size = configs.BATCH_SIZE
unet_model_type = configs.MODEL
extraction_step = 1

# Stride between patches at prediction time; also converts patch-grid
# indices back into pixel offsets when the slice is reassembled.
extraction_reconstruct_step = configs.extraction_reconstruct_step

# Input volumes, trained weights and the output sub-directory name.
train_imgs_path = '../data_new/Test_Set'
print('path: ', train_imgs_path)
checkpoint_filename = 'best_4classes_32_default_tuned_8925.h5'
print('weight file: ', checkpoint_filename)
write_path = 'predict2D'
# For each slice, extract patches and stack them.
def create_slice_testing(slice_number, img_dir_name):
    """Extract the network input patches for one slice of one subject.

    Loads ``<img_dir_name>_hist.nii.gz`` and ``<img_dir_name>_mask.nii.gz``
    from ``train_imgs_path/<img_dir_name>``. When both the image slice and the
    mask slice contain non-zero voxels, the patches selected by
    ``extract_2d_patches_one_slice`` are returned; otherwise no patches are.

    Parameters
    ----------
    slice_number : int
        Index along the third axis of the (squeezed) volume.
    img_dir_name : str
        Subject sub-directory name; also used as the file-name prefix.

    Returns
    -------
    label_predicted : np.ndarray
        Zero-filled uint8 array with the in-plane shape of the volume.
    patches : np.ndarray
        float32 patches of shape (n_patches, patch_size, patch_size, 1);
        ``n_patches`` is 0 when the slice was skipped.
    rows, cols
        Patch-grid coordinates of the selected patches, or empty lists when
        the slice was skipped.
    """
    img_name = img_dir_name + '_hist.nii.gz'
    print('Image: ', img_name)
    subject_dir = os.path.join(train_imgs_path, img_dir_name)
    img_mask_name = img_dir_name + '_mask.nii.gz'

    # ``get_data()`` was deprecated in nibabel 3.0 and removed in 5.0;
    # ``np.asanyarray(img.dataobj)`` is its documented replacement.
    img_data = np.squeeze(
        np.asanyarray(nib.load(os.path.join(subject_dir, img_name)).dataobj))
    img_mask_data = np.squeeze(
        np.asanyarray(nib.load(os.path.join(subject_dir, img_mask_name)).dataobj))

    img_slice = img_data[:, :, slice_number]
    mask_slice = img_mask_data[:, :, slice_number]

    if np.count_nonzero(mask_slice) and np.count_nonzero(img_slice):
        patches, rows, cols = extract_2d_patches_one_slice(img_slice, mask_slice)
    else:
        # Nothing to segment: keep a (0, p, p) array so the channel axis can
        # still be added below.
        patches = np.empty((0, patch_size, patch_size), dtype='int16')
        rows, cols = [], []

    # Single precision plus a trailing channel axis (channels_last model input).
    patches = np.expand_dims(patches.astype('float32'), axis=3)

    label_predicted = np.zeros(img_data.shape[:2], dtype=np.uint8)
    return label_predicted, patches, rows, cols
# extract patches in one slice
def extract_2d_patches_one_slice(img_data, mask_data):
    """Cut one 2D slice into square patches and keep those touching the mask.

    Patches of ``patch_size`` x ``patch_size`` pixels are taken on a grid with
    stride ``extraction_reconstruct_step`` from both the image and the mask. A
    patch is kept when the sum of its mask values is non-zero.

    Returns
    -------
    patches : np.ndarray
        Selected image patches, shape (n, patch_size, patch_size). The dtype is
        the image dtype promoted together with int16.
    rows, cols : np.ndarray
        Patch-grid coordinates of the kept patches (multiply by
        ``extraction_reconstruct_step`` to get the top-left pixel).
    """
    window = (patch_size, patch_size)

    # NOTE(review): sklearn's ``extract_patches`` was deprecated in 0.22 and
    # removed in 0.24 -- pin scikit-learn or port this to numpy stride tricks.
    grid_img = extract_patches(img_data, window, extraction_reconstruct_step)
    grid_mask = extract_patches(mask_data, window, extraction_reconstruct_step)

    # Per-patch mask mass; non-zero entries mark patches overlapping the ROI.
    mask_mass = grid_mask.sum(axis=(2, 3))
    rows, cols = mask_mass.nonzero()

    # Concatenating onto an empty int16 array keeps the original dtype
    # promotion of the returned patches.
    seed = np.empty((0,) + window, dtype='int16')
    selected = np.concatenate((seed, grid_img[rows, cols]), axis=0)
    return selected, rows, cols
# write predicted label to the final result
def write_slice_predict(imgs_valid_predict, rows, cols, image_shape=None,
                        patch_shape=None, step=None, n_classes=None):
    """Stitch per-patch class probabilities back into a slice label map.

    Each patch's probabilities are written at its position in the slice. Where
    patches overlap, the later patch overwrites the earlier one. Every pixel
    then receives the index of its highest-probability class; ties go to the
    lowest index, and pixels covered by no patch get label 0.

    Parameters
    ----------
    imgs_valid_predict : np.ndarray
        Network output indexed as ``[patch, row, col, class]``.
    rows, cols : sequence of int
        Patch-grid coordinates, one pair per patch in ``imgs_valid_predict``.
    image_shape : tuple of int, optional
        In-plane slice size; defaults to ``(image_rows, image_cols)``.
    patch_shape : tuple of int, optional
        Patch size in pixels; defaults to ``(patch_size, patch_size)``.
    step : int, optional
        Grid-to-pixel stride; defaults to ``extraction_reconstruct_step``.
    n_classes : int, optional
        Number of class channels used; defaults to ``num_classes``.

    Returns
    -------
    np.ndarray
        float64 array of shape ``image_shape`` holding class indices.
    """
    # Module-level settings are only looked up when not supplied explicitly.
    if image_shape is None:
        image_shape = (image_rows, image_cols)
    if patch_shape is None:
        patch_shape = (patch_size, patch_size)
    if step is None:
        step = extraction_reconstruct_step
    if n_classes is None:
        n_classes = num_classes

    n_rows, n_cols = image_shape
    patch_h, patch_w = patch_shape
    label_predicted_filled = np.zeros((n_rows, n_cols, n_classes))

    count_write = 0
    for row, col, patch in zip(rows, cols, imgs_valid_predict):
        r0 = row * step
        c0 = col * step
        # Later patches overwrite earlier ones, as in the per-pixel original.
        label_predicted_filled[r0:r0 + patch_h, c0:c0 + patch_w, :] = \
            patch[:patch_h, :patch_w, :n_classes]
        count_write += 1

    # argmax returns the first maximal index, matching the original
    # "class 0 wins ties" if/elif chain -- but for any number of classes and
    # over the whole slice (the old loop was hard-coded to 256 x 128 x 4).
    label_final = np.argmax(label_predicted_filled, axis=-1).astype(np.float64)

    print('Number of processed patches: ', count_write)
    print('Number of extracted patches: ', len(rows))
    return label_final
# predict function
def _build_unet():
    """Instantiate the U-Net variant selected by ``unet_model_type``.

    Raises
    ------
    ValueError
        If ``unet_model_type`` is not 'default', 'reduced' or 'extended'.
    """
    builders = {
        'default': get_unet_default,
        'reduced': get_unet_reduced,
        'extended': get_unet_extended,
    }
    if unet_model_type not in builders:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError('unknown unet_model_type: %r' % (unet_model_type,))
    return builders[unet_model_type]()


def predict(img_dir_name):
    """Segment one subject volume slice by slice and save it as NIfTI.

    The model selected by ``unet_model_type`` is built and its weights are
    loaded from ``outputs/<checkpoint_filename>``. Every slice whose mask
    (minus a small border) has non-zero voxels is processed in three steps:
    it is cut into patches by ``create_slice_testing``, classified by the
    model, and stitched back together by ``write_slice_predict``. Labels are
    kept only at voxels where the mask equals 1. The volume is written to
    ``../data_new/<write_path>/``.

    Parameters
    ----------
    img_dir_name : str
        Subject sub-directory of ``train_imgs_path``; also the file prefix.

    Raises
    ------
    ValueError
        If ``unet_model_type`` names an unknown architecture.
    """
    model = _build_unet()
    checkpoint_filepath = os.path.join('outputs', checkpoint_filename)
    model.load_weights(checkpoint_filepath)
    model.summary()

    segmented_volume = np.zeros((image_rows, image_cols, image_depth))

    img_mask_name = os.path.join(train_imgs_path, img_dir_name,
                                 img_dir_name + '_mask.nii.gz')
    img_mask = nib.load(img_mask_name)
    # ``get_data()`` is removed in nibabel 5; squeeze to match the 3-D
    # indexing used here and in ``create_slice_testing``.
    img_mask_data = np.squeeze(np.asanyarray(img_mask.dataobj))

    n_rows, n_cols = segmented_volume.shape[:2]
    # for each slice, extract patches and predict
    for i_slice in range(image_depth):
        # NOTE(review): the foreground test skips a 2-voxel border (1 voxel on
        # the last column edge); the bounds look tuned for 256 x 128 slices.
        if np.sum(img_mask_data[2:254, 2:127, i_slice]) > 0:
            print('-' * 30)
            print('Slice number: ', i_slice)
            _, patches, rows, cols = create_slice_testing(i_slice, img_dir_name)
            imgs_valid_predict = model.predict(patches)
            slice_labels = write_slice_predict(imgs_valid_predict, rows, cols)

            # Keep predictions only inside the mask (mask value exactly 1).
            inside = img_mask_data[:n_rows, :n_cols, i_slice] == 1
            segmented_volume[:, :, i_slice][inside] = slice_labels[inside]
        print('done')

    # Reuse the mask's affine so the prediction stays aligned with the input
    # volume in viewers (np.eye(4) discarded the source geometry).
    img = nib.Nifti1Image(segmented_volume, img_mask.affine)
    n_class_tag = '3' if num_classes == 3 else '4'
    img_name = (img_dir_name + '_predicted_' + n_class_tag + 'class_' +
                str(patch_size) + '_' + unet_model_type + '_tuned_8925.nii.gz')
    nib.save(img, os.path.join('../data_new', write_path, img_name))
    print('-' * 30)
# main
if __name__ == '__main__':
    # ``predict`` loads its checkpoint from the ``outputs`` folder.
    if not os.path.exists('outputs'):
        os.mkdir('outputs')

    # nib.save does not create missing directories, so make sure the
    # prediction folder exists before the first volume is written.
    output_dir = os.path.join('../data_new', write_path)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # Each subject lives in its own sub-directory; stray files (e.g.
    # .DS_Store) would make ``predict`` fail. Sort for a reproducible order.
    images_train_dir = sorted(
        name for name in os.listdir(train_imgs_path)
        if os.path.isdir(os.path.join(train_imgs_path, name)))
    n_volumes = len(images_train_dir)
    n_class_tag = '3' if num_classes == 3 else '4'

    for j, img_dir_name in enumerate(images_train_dir, 1):
        start_time = time.time()
        print('*' * 50)
        print('Segmenting: volume {0} / {1} volume images'.format(j, n_volumes))
        img_name = (img_dir_name + '_predicted_' + n_class_tag + 'class_' +
                    str(patch_size) + '_' + unet_model_type + '_tuned_8925.nii.gz')
        print('Path: ', os.path.join(output_dir, img_name))
        print('*' * 50)
        predict(img_dir_name)
        elapsed_time = time.time() - start_time
        print('Elapsed time is: ', elapsed_time)