Diff of /3D/predict.py [000000] .. [c9b969]

Switch to unified view

a b/3D/predict.py
1
from __future__ import print_function
2
3
# import packages
4
from functools import partial
5
import os, time
6
import numpy as np
7
from model import unet_model_3d
8
from keras.models import Model
9
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
10
from keras.optimizers import Adam
11
from keras import callbacks
12
from keras import backend as K
13
from keras.utils import plot_model
14
import nibabel as nib
15
from PIL import Image
16
from sklearn.feature_extraction.image import extract_patches
17
18
# import load data
19
from data_handling import load_train_data, load_validatation_data
20
21
# import configurations
22
import configs
23
24
patch_size = configs.PATCH_SIZE
25
config = dict()
26
config["pool_size"] = (2, 2, 2)  # pool size for the max pooling operations
27
config["image_shape"] = (256, 128, 256)  # This determines what shape the images will be cropped/resampled to.
28
config["input_shape"] = (patch_size, patch_size, patch_size, 1)  # switch to None to train on the whole image (64, 64, 64) (64, 64, 64)
29
config["n_labels"] = 4
30
config["all_modalities"] = ['t1']#]["t1", "t1Gd", "flair", "t2"]
31
config["training_modalities"] = config["all_modalities"]  # change this if you want to only use some of the modalities
32
config["nb_channels"] = len(config["training_modalities"])
33
config["deconvolution"] = False  # if False, will use upsampling instead of deconvolution
34
config["batch_size"] = 8
35
config["n_epochs"] = 500  # cutoff the training after this many epochs
36
config["patience"] = 10  # learning rate will be reduced after this many epochs if the validation loss is not improving
37
config["early_stop"] = 20  # training will be stopped after this many epochs without the validation loss improving
38
config["initial_learning_rate"] = 0.0001
39
config["depth"] = configs.DEPTH
40
config["learning_rate_drop"] = 0.5
41
42
extraction_reconstruct_step = configs.EXTRACTION_RECONSTRUCT_STEP
43
image_type = '3d_patches'
44
K.set_image_data_format('channels_last')  # TF dimension ordering in this code
45
46
image_type = configs.IMAGE_TYPE
47
48
# init configs
49
image_rows = configs.VOLUME_ROWS
50
image_cols = configs.VOLUME_COLS
51
image_depth = configs.VOLUME_DEPS
52
num_classes = configs.NUM_CLASSES
53
54
# patch extraction parameters
55
56
BASE = configs.BASE
57
smooth = configs.SMOOTH
58
nb_epochs  = configs.NUM_EPOCHS
59
batch_size  = configs.BATCH_SIZE
60
unet_model_type = configs.MODEL
61
extraction_step = 1
62
63
train_imgs_path = '../data_new/Validation_Set'
64
checkpoint_filename = 'best_3classes_32_85_93_92_default.h5'
65
write_path = 'predict'
66
67
# for each slice estract patches and stack
68
def create_slice_testing(img_dir_name):
69
    # empty matrix to hold patches
70
    patches_training_imgs_3d = np.empty(shape=[0, patch_size, patch_size, patch_size], dtype='int16')
71
    patches_training_gtruth_3d = np.empty(shape=[0, patch_size, patch_size, patch_size, num_classes], dtype='int16')
72
    images_train_dir = os.listdir(train_imgs_path)
73
    j = 0
74
75
    # volume 
76
    img_name = img_dir_name + '_hist.nii.gz'
77
    img_name = os.path.join(train_imgs_path, img_dir_name, img_name)
78
    print(img_name)
79
80
    # mask 
81
    img_mask_name = img_dir_name + '_mask.nii.gz'
82
    img_mask_name = os.path.join(train_imgs_path, img_dir_name, img_mask_name)
83
84
    # load volume and mask
85
    img = nib.load(img_name)
86
    img_data = img.get_data()
87
    img_data = np.squeeze(img_data)
88
89
    img_mask = nib.load(img_mask_name)
90
    img_mask_data = img_mask.get_data()
91
    img_mask_data = np.squeeze(img_mask_data)
92
93
    patches_training_imgs_3d_temp = np.empty(shape=[0, patch_size, patch_size, patch_size], dtype='int16')
94
95
    # extract patches of the jth volume image
96
    imgs_patches, rows, cols, depths = extract_3d_patches_one_slice(img_data, img_mask_data)
97
98
    # update database
99
    patches_training_imgs_3d_temp = np.append(patches_training_imgs_3d_temp, imgs_patches, axis=0)
100
101
    patches_training_imgs_3d = np.append(patches_training_imgs_3d, patches_training_imgs_3d_temp, axis=0)
102
103
    j += 1
104
105
    # convert to single precision
106
    patches_training_imgs_3d = patches_training_imgs_3d.astype('float32')
107
    patches_training_imgs_3d = np.expand_dims(patches_training_imgs_3d, axis=4)
108
    return patches_training_imgs_3d, rows, cols, depths
109
110
111
# extract 3D patches
112
def extract_3d_patches_one_slice(img_data, mask_data):
113
    patch_shape = (patch_size, patch_size, patch_size)
114
115
    # empty matrix to hold patches
116
    imgs_patches_per_volume = np.empty(shape=[0, patch_size, patch_size, patch_size], dtype='int16')
117
    mask_patches_per_slice = np.empty(shape=[0, patch_size, patch_size, patch_size], dtype='int16')
118
    STEP = patch_size-1
119
    img_patches = extract_patches(img_data, patch_shape, extraction_reconstruct_step)
120
    mask_patches = extract_patches(mask_data, patch_shape, extraction_reconstruct_step)
121
    
122
    Sum = np.sum(mask_patches, axis=(3, 4, 5))
123
    rows, cols, depths = np.nonzero(Sum)
124
    N = len(rows)
125
    # select non-zero patches index
126
    selected_img_patches = img_patches[rows, cols, depths, :, :, :]
127
128
    # update database
129
    imgs_patches_per_volume = np.append(imgs_patches_per_volume, selected_img_patches, axis=0)
130
    return imgs_patches_per_volume, rows, cols, depths
131
132
# write predicted label to the final result
133
def write_predict(imgs_valid_predict, rows, cols, depths):
134
    label_predicted = np.zeros((image_rows, image_cols, image_depth, num_classes))
135
    label_predicted_filled = label_predicted
136
    label_final = np.zeros((image_rows, image_cols, image_depth))
137
138
    for index in range(0, len(rows)):
139
        print ('Processing patch: ', index + 1, '/', len(rows))
140
        row = rows[index]; col = cols[index]; dep = depths[index]
141
        start_row = row * extraction_reconstruct_step
142
        start_col = col * extraction_reconstruct_step
143
        start_dep = dep * extraction_reconstruct_step
144
        patch_volume = imgs_valid_predict[index,:,:,:,:]
145
        for i in range (0,patch_size):
146
            for j in range(0, patch_size):
147
                for k in range(0, patch_size):
148
                    prob_class0_new = patch_volume[i][j][k][0]
149
                    prob_class1_new = patch_volume[i][j][k][1]
150
                    prob_class2_new = patch_volume[i][j][k][2]
151
                    prob_class3_new = patch_volume[i][j][k][3]
152
                    
153
                    label_predicted_filled[start_row + i][start_col + j][start_dep + k][0] = prob_class0_new
154
                    label_predicted_filled[start_row + i][start_col + j][start_dep + k][1] = prob_class1_new
155
                    label_predicted_filled[start_row + i][start_col + j][start_dep + k][2] = prob_class2_new
156
                    label_predicted_filled[start_row + i][start_col + j][start_dep + k][3] = prob_class3_new                    
157
158
    for i in range(0, 256):
159
        for j in range(0, 128):
160
            for k in range(0, 256):
161
                prob_class0 = label_predicted_filled[i][j][k][0]
162
                prob_class1 = label_predicted_filled[i][j][k][1]
163
                prob_class2 = label_predicted_filled[i][j][k][2]
164
                prob_class3 = label_predicted_filled[i][j][k][3]
165
166
                prob_max = max(prob_class0, prob_class1, prob_class2, prob_class3)
167
                if prob_class0 == prob_max:
168
                    label_final[i][j][k] = 0
169
                elif prob_class1 == prob_max:
170
                    label_final[i][j][k] = 1
171
                elif prob_class2 == prob_max:
172
                    label_final[i][j][k] = 2
173
                else:
174
                    label_final[i][j][k] = 3
175
176
    return label_final
177
178
# predict function
179
def predict(img_dir_name):
180
   # create a model
181
    model = unet_model_3d(input_shape=config["input_shape"],
182
                                depth=config["depth"],
183
                                pool_size=config["pool_size"],
184
                                n_labels=config["n_labels"],
185
                                initial_learning_rate=config["initial_learning_rate"],
186
                                deconvolution=config["deconvolution"])
187
188
    model.summary()  
189
    
190
    checkpoint_filepath = 'outputs/' + checkpoint_filename
191
    model.load_weights(checkpoint_filepath)
192
    
193
    SegmentedVolume = np.zeros((image_rows,image_cols,image_depth))
194
    
195
    
196
    img_mask_name = img_dir_name + '_mask.nii.gz'
197
    img_mask_name = os.path.join(train_imgs_path, img_dir_name, img_mask_name)
198
199
    img_mask = nib.load(img_mask_name)
200
    img_mask_data = img_mask.get_data()
201
202
    patches_training_imgs_3d, rows, cols, depths = create_slice_testing(img_dir_name)
203
    imgs_valid_predict = model.predict(patches_training_imgs_3d)
204
    label_final = write_predict(imgs_valid_predict, rows, cols, depths)
205
206
    for i in range(0, SegmentedVolume.shape[0]):
207
        for j in range(0, SegmentedVolume.shape[1]):
208
            for k in range(0, SegmentedVolume.shape[2]):
209
                if img_mask_data.item((i, j, k)) == 1:
210
                    SegmentedVolume.itemset((i,j,k), label_final.item((i, j,k)))
211
                else:
212
                    label_final.itemset((i, j,k), 0)
213
214
    print ('done')
215
216
217
218
    data = SegmentedVolume
219
    img = nib.Nifti1Image(data, np.eye(4))
220
    if num_classes == 3:
221
        img_name = img_dir_name + '_predicted_3class_unet3d.nii.gz'
222
    else:
223
        img_name = img_dir_name + '_predicted_4class_unet3d_extract12_hist15.nii.gz'
224
    nib.save(img, os.path.join('../data_new', write_path, img_name))
225
    print('-' * 30)
226
227
    # main
228
if __name__ == '__main__':
229
    # folder to hold outputs
230
    if 'outputs' not in os.listdir(os.curdir):
231
        os.mkdir('outputs')  
232
        
233
    images_train_dir = os.listdir(train_imgs_path)
234
235
    j=0
236
    
237
    for img_dir_name in images_train_dir:
238
        j=j+1
239
        if j:
240
            start_time = time.time()
241
   
242
            print('*'*50)
243
            print('Segmenting: volume {0} / {1} volume images'.format(j, len(images_train_dir)))
244
            # print('-'*30)
245
            if num_classes == 3:
246
                img_name = img_dir_name + '_predicted_3class_unet3d.nii.gz'
247
            else:
248
                img_name = img_dir_name + '_predicted_4class_unet3d_extract12_hist15.nii.gz'
249
            print ('Path: ', os.path.join('../data_new', write_path, img_name))
250
            print('*' * 50)
251
            predict(img_dir_name)
252
            end_time = time.time()
253
            elapsed_time = end_time - start_time
254
            print ('Elapsed time is: ', elapsed_time)