Diff of /predict.py [000000] .. [beb348]

Switch to unified view

a b/predict.py
1
import numpy as np
2
import random
3
from glob import glob
4
import os
5
import SimpleITK as sitk
6
from evaluation_metrics import *
7
from model import Unet_model
8
9
10
11
12
class Prediction(object):
13
14
    def __init__(self, batch_size_test,load_model_path):
15
16
        self.batch_size_test=batch_size_test
17
        unet=Unet_model(img_shape=(240,240,4),load_model_weights=load_model_path)
18
        self.model=unet.model
19
        print ('U-net CNN compiled!\n')
20
21
22
    def predict_volume(self, filepath_image,show):
23
24
        '''
25
        segment the input volume
26
        INPUT   (1) str 'filepath_image': filepath of the volume to predict 
27
                (2) bool 'show': True to ,
28
        OUTPUt  (1) np array of the predicted volume
29
                (2) np array of the corresping ground truth
30
        '''
31
32
        #read the volume
33
        flair = glob( filepath_image + '/*_flair.nii.gz')
34
        t2 = glob( filepath_image + '/*_t2.nii.gz')
35
        gt = glob( filepath_image + '/*_seg.nii.gz')
36
        t1s = glob( filepath_image + '/*_t1.nii.gz')
37
        t1c = glob( filepath_image + '/*_t1ce.nii.gz')
38
        t1=[scan for scan in t1s if scan not in t1c]
39
        if (len(flair)+len(t2)+len(gt)+len(t1)+len(t1c))<5:
40
            print("there is a problem here!!! the problem lies in this patient :")
41
        scans_test = [flair[0], t1[0], t1c[0], t2[0], gt[0]]
42
        test_im = [sitk.GetArrayFromImage(sitk.ReadImage(scans_test[i])) for i in range(len(scans_test))]
43
44
45
        test_im=np.array(test_im).astype(np.float32)
46
        test_image = test_im[0:4]
47
        gt=test_im[-1]
48
        gt[gt==4]=3
49
50
        #normalize each slice following the same scheme used for training
51
        test_image=self.norm_slices(test_image)
52
        
53
        #transform teh data to channels_last keras format
54
        test_image = test_image.swapaxes(0,1)
55
        test_image=np.transpose(test_image,(0,2,3,1))
56
57
        if show:
58
            verbose=1
59
        else:
60
            verbose=0
61
        # predict classes of each pixel based on the model
62
        prediction = self.model.predict(test_image,batch_size=self.batch_size_test,verbose=verbose)   
63
        prediction = np.argmax(prediction, axis=-1)
64
        prediction=prediction.astype(np.uint8)
65
        #reconstruct the initial target values .i.e. 0,1,2,4 for prediction and ground truth
66
        prediction[prediction==3]=4
67
        gt[gt==3]=4
68
        
69
        return np.array(prediction),np.array(gt)
70
71
72
73
    def evaluate_segmented_volume(self, filepath_image,save,show,save_path):
74
        '''
75
        computes the evaluation metrics on the segmented volume
76
        INPUT   (1) str 'filepath_image': filepath to test image for segmentation, including file extension
77
                (2) bool 'save': whether to save to disk or not
78
                (3) bool 'show': If true, prints the evaluation metrics
79
        OUTPUT np array of all evaluation metrics
80
        '''
81
        
82
        predicted_images,gt= self.predict_volume(filepath_image,show)
83
84
        if save:
85
            tmp=sitk.GetImageFromArray(predicted_images)
86
            sitk.WriteImage ( tmp,'predictions/{}.nii.gz'.format(save_path) )
87
88
        #compute the evaluation metrics 
89
        Dice_complete=DSC_whole(predicted_images,gt)
90
        Dice_enhancing=DSC_en(predicted_images,gt)
91
        Dice_core=DSC_core(predicted_images,gt)
92
93
        Sensitivity_whole=sensitivity_whole(predicted_images,gt)
94
        Sensitivity_en=sensitivity_en(predicted_images,gt)
95
        Sensitivity_core=sensitivity_core(predicted_images,gt)
96
        
97
98
        Specificity_whole=specificity_whole(predicted_images,gt)
99
        Specificity_en=specificity_en(predicted_images,gt)
100
        Specificity_core=specificity_core(predicted_images,gt)
101
102
103
        Hausdorff_whole=hausdorff_whole(predicted_images,gt)
104
        Hausdorff_en=hausdorff_en(predicted_images,gt)
105
        Hausdorff_core=hausdorff_core(predicted_images,gt)
106
107
        if show:
108
            print("************************************************************")
109
            print("Dice complete tumor score : {:0.4f}".format(Dice_complete))
110
            print("Dice core tumor score (tt sauf vert): {:0.4f}".format(Dice_core))
111
            print("Dice enhancing tumor score (jaune):{:0.4f} ".format(Dice_enhancing))
112
            print("**********************************************")
113
            print("Sensitivity complete tumor score : {:0.4f}".format(Sensitivity_whole))
114
            print("Sensitivity core tumor score (tt sauf vert): {:0.4f}".format(Sensitivity_core))
115
            print("Sensitivity enhancing tumor score (jaune):{:0.4f} ".format(Sensitivity_en))
116
            print("***********************************************")
117
            print("Specificity complete tumor score : {:0.4f}".format(Specificity_whole))
118
            print("Specificity core tumor score (tt sauf vert): {:0.4f}".format(Specificity_core))
119
            print("Specificity enhancing tumor score (jaune):{:0.4f} ".format(Specificity_en))
120
            print("***********************************************")
121
            print("Hausdorff complete tumor score : {:0.4f}".format(Hausdorff_whole))
122
            print("Hausdorff core tumor score (tt sauf vert): {:0.4f}".format(Hausdorff_core))
123
            print("Hausdorff enhancing tumor score (jaune):{:0.4f} ".format(Hausdorff_en))
124
            print("***************************************************************\n\n")
125
126
        return np.array((Dice_complete,Dice_core,Dice_enhancing,Sensitivity_whole,Sensitivity_core,Sensitivity_en,Specificity_whole,Specificity_core,Specificity_en,Hausdorff_whole,Hausdorff_core,Hausdorff_en))#))
127
    
128
129
    def predict_multiple_volumes (self, filepath_volumes,save,show):
130
131
        results,Ids=[],[]
132
        for patient in filepath_volumes:
133
            tmp1=patient.split('/')
134
            print("Volume ID: " ,tmp1[-2]+'/'+tmp1[-1])
135
            tmp=self.evaluate_segmented_volume(patient,save=save,show=show,save_path=os.path.basename(patient))
136
            #save the results of each volume
137
            results.append(tmp)
138
            #save each ID for later use
139
            Ids.append(str(tmp1[-2]+'/'+tmp1[-1]))
140
141
        res=np.array(results)     
142
        print("mean : ",np.mean(res,axis=0))
143
        print("std : ",np.std(res,axis=0))
144
        print("median : ",np.median(res,axis=0))
145
        print("25 quantile : ",np.percentile(res,25,axis=0))
146
        print("75 quantile : ",np.percentile(res,75,axis=0))
147
        print("max : ",np.max(res,axis=0))
148
        print("min : ",np.min(res,axis=0))
149
150
        np.savetxt('Results.out', res)
151
        np.savetxt('Volumes_ID.out', Ids,fmt='%s')
152
153
154
    def norm_slices(self,slice_not):
155
        '''
156
            normalizes each slice, excluding gt
157
            subtracts mean and div by std dev for each slice
158
            clips top and bottom one percent of pixel intensities
159
        '''
160
        normed_slices = np.zeros(( 4,155, 240, 240))
161
        for slice_ix in range(4):
162
            normed_slices[slice_ix] = slice_not[slice_ix]
163
            for mode_ix in range(155):
164
                normed_slices[slice_ix][mode_ix] = self._normalize(slice_not[slice_ix][mode_ix])
165
166
        return normed_slices    
167
168
169
    def _normalize(self,slice):
170
171
        b = np.percentile(slice, 99)
172
        t = np.percentile(slice, 1)
173
        slice = np.clip(slice, t, b)
174
        image_nonzero = slice[np.nonzero(slice)]
175
        
176
        if np.std(slice)==0 or np.std(image_nonzero) == 0:
177
            return slice
178
        else:
179
            tmp= (slice - np.mean(image_nonzero)) / np.std(image_nonzero)
180
            tmp[tmp==tmp.min()]=-9
181
            return tmp
182
183
184
185
186
if __name__ == "__main__":
187
188
    #set arguments
189
    model_to_load="models_saved/ResUnet.04_0.646.hdf5" 
190
    #paths for the testing data
191
    path_HGG = glob('Brats2017/Brats17TrainingData/HGG/**')
192
    path_LGG = glob('Brats2017/Brats17TrainingData/LGG/**')
193
194
    test_path=path_HGG+path_LGG
195
    np.random.seed(2022)
196
    np.random.shuffle(test_path)
197
198
    #compile the model
199
    brain_seg_pred = Prediction(batch_size_test=2 ,load_model_path=model_to_load)
200
201
    #predicts each volume and save the results in np array
202
    brain_seg_pred.predict_multiple_volumes(test_path[200:290],save=False,show=True)
203
    
204
205