Switch to unified view

a b/src/LiviaNet/startTesting.py
1
""" 
2
Copyright (c) 2016, Jose Dolz .All rights reserved.
3
4
Redistribution and use in source and binary forms, with or without modification,
5
are permitted provided that the following conditions are met:
6
7
    1. Redistributions of source code must retain the above copyright notice,
8
       this list of conditions and the following disclaimer.
9
    2. Redistributions in binary form must reproduce the above copyright notice,
10
       this list of conditions and the following disclaimer in the documentation
11
       and/or other materials provided with the distribution.
12
13
    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14
    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
15
    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
16
    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
17
    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
18
    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
20
    OTHER DEALINGS IN THE SOFTWARE.
21
22
Jose Dolz. Dec, 2016.
23
email: jose.dolz.upv@gmail.com
24
LIVIA Department, ETS, Montreal.
25
"""
26
27
import numpy as np
28
import time
29
import os
30
import pdb
31
32
from Modules.General.Evaluation import computeDice
33
from Modules.General.Utils import getImagesSet
34
from Modules.General.Utils import load_model_from_gzip_file
35
from Modules.IO.ImgOperations.imgOp import applyUnpadding
36
from Modules.IO.loadData import load_imagesSinglePatient
37
from Modules.IO.saveData import saveImageAsNifti
38
from Modules.IO.saveData import saveImageAsMatlab
39
from Modules.IO.sampling import *
40
from Modules.Parsers.parsersUtils import parserConfigIni
41
42
43
def segmentVolume(myNetworkModel,
44
                  i_d,
45
                  imageNames_Test,
46
                  names_Test,
47
                  groundTruthNames_Test,
48
                  roiNames_Test,
49
                  imageType,
50
                  padInputImagesBool,
51
                  receptiveField, 
52
                  sampleSize_Test,
53
                  strideVal,
54
                  batch_Size,
55
                  task # Validation (0) or testing (1)
56
                  ):
57
        # Get info from the network model        
58
        networkName        = myNetworkModel.networkName
59
        folderName         = myNetworkModel.folderName
60
        n_classes          = myNetworkModel.n_classes
61
        sampleSize_Test    = myNetworkModel.sampleSize_Test
62
        receptiveField     = myNetworkModel.receptiveField  
63
        outputShape        = myNetworkModel.lastLayer.outputShapeTest[2:] 
64
        batch_Size         = myNetworkModel.batch_Size
65
        padInputImagesBool = True
66
    
67
        # Get half sample size
68
        sampleHalf = []
69
        for h_i in range(3):
70
            sampleHalf.append((receptiveField[h_i]-1)/2)
71
        
72
        # Load the images to segment
73
        [imgSubject,  
74
        gtLabelsImage, 
75
        roi, 
76
        paddingValues] = load_imagesSinglePatient(i_d,
77
                                                  imageNames_Test,
78
                                                  groundTruthNames_Test,
79
                                                  roiNames_Test,
80
                                                  padInputImagesBool,
81
                                                  receptiveField, 
82
                                                  sampleSize_Test,
83
                                                  imageType, 
84
                                                  )
85
                                                  
86
                                  
87
        # Get image dimensions                                                    
88
        imgDims = list(imgSubject.shape)
89
    
90
        [ sampleCoords ] = sampleWholeImage(imgSubject,
91
                                            roi,
92
                                            sampleSize_Test,
93
                                            strideVal,
94
                                            batch_Size
95
                                            )
96
        
97
        numberOfSamples = len(sampleCoords)
98
        sampleID = 0
99
        numberOfBatches = numberOfSamples/batch_Size
100
101
        #The probability-map that will be constructed by the predictions.
102
        probMaps = np.zeros([n_classes]+imgDims, dtype = "float32")
103
        
104
        # Run over all the batches 
105
        for b_i in xrange(numberOfBatches) :
106
                 
107
            # Get samples for batch b_i
108
            
109
            sampleCoords_b = sampleCoords[ b_i*batch_Size : (b_i+1)*batch_Size ]
110
            
111
            [imgSamples] = extractSamples(imgSubject,
112
                                          sampleCoords_b,
113
                                          sampleSize_Test,
114
                                          receptiveField)
115
116
            # Load the data of the batch on the GPU
117
            myNetworkModel.testingData_x.set_value(imgSamples, borrow=True)
118
           
119
            # Call the testing Theano function            
120
            predictions = myNetworkModel.networkModel_Test(0)
121
            
122
            predOutput = predictions[-1]
123
            
124
            # --- Now we can generate the probability maps from the predictions ----
125
            # Run over all the regions
126
            for r_i in xrange(batch_Size) :
127
 
128
                sampleCoords_i = sampleCoords[sampleID]
129
                coords = [ sampleCoords_i[0][0], sampleCoords_i[1][0], sampleCoords_i[2][0] ]
130
131
                # Get the min and max coords
132
                xMin = coords[0] + sampleHalf[0]
133
                xMax = coords[0] + sampleHalf[0] + strideVal[0]
134
135
                yMin = coords[1] + sampleHalf[1]
136
                yMax = coords[1] + sampleHalf[1] + strideVal[1]
137
138
                zMin = coords[2] + sampleHalf[2]
139
                zMax = coords[2] + sampleHalf[2] + strideVal[2]
140
                
141
                probMaps[:,xMin:xMax, yMin:yMax, zMin:zMax] = predOutput[r_i]
142
143
                sampleID += 1
144
            
145
        # Release data
146
        myNetworkModel.testingData_x.set_value(np.zeros([1,1,1,1,1], dtype="float32"))
147
148
        # Segmentation has been done in this point.
149
        
150
        # Now: Save the data
151
        # Get the segmentation from the probability maps ---
152
        segmentationImage = np.argmax(probMaps, axis=0) 
153
        
154
        #Save Result:
155
        npDtypeForPredictedImage = np.dtype(np.int16)
156
        suffixToAdd = "_Segm"
157
 
158
        # Apply unpadding if specified
159
        if padInputImagesBool == True:
160
            segmentationRes = applyUnpadding(segmentationImage, paddingValues)
161
        else:
162
            segmentationRes = segmentationImage
163
164
        # Generate folders to store the model
165
        BASE_DIR = os.getcwd()
166
        path_Temp = os.path.join(BASE_DIR,'outputFiles')
167
168
        # For the predictions
169
        predlFolderName = os.path.join(path_Temp,myNetworkModel.folderName)
170
        predlFolderName = os.path.join(predlFolderName,'Pred')
171
        if task == 0:
172
            predTestFolderName = os.path.join(predlFolderName,'Validation')
173
        else:
174
            predTestFolderName = os.path.join(predlFolderName,'Testing')
175
        
176
        nameToSave = predTestFolderName + '/Segmentation_'+ names_Test[i_d]
177
        
178
        # Save Segmentation image
179
        
180
        print(" ... Saving segmentation result..."),
181
        if imageType == 0: # nifti
182
            imageTypeToSave = np.dtype(np.int16)
183
            saveImageAsNifti(segmentationRes,
184
                             nameToSave,
185
                             imageNames_Test[i_d],
186
                             imageTypeToSave)
187
        else: # Matlab
188
            # Cast to int8 for saving purposes
189
            saveImageAsMatlab(segmentationRes.astype('int8'),
190
                              nameToSave)
191
192
193
        # Save the prob maps for each class (except background)
194
        for c_i in xrange(1, n_classes) :
195
            
196
            
197
            nameToSave = predTestFolderName + '/ProbMap_class_'+ str(c_i) + '_' + names_Test[i_d] 
198
199
            probMapClass = probMaps[c_i,:,:,:]
200
201
            # Apply unpadding if specified
202
            if padInputImagesBool == True:
203
                probMapClassRes = applyUnpadding(probMapClass, paddingValues)
204
            else:
205
                probMapClassRes = probMapClass
206
207
            print(" ... Saving prob map for class {}...".format(str(c_i))),
208
            if imageType == 0: # nifti
209
                imageTypeToSave = np.dtype(np.float32)
210
                saveImageAsNifti(probMapClassRes,
211
                                 nameToSave,
212
                                 imageNames_Test[i_d],
213
                                 imageTypeToSave)
214
            else:
215
                # Cast to float32 for saving purposes
216
                saveImageAsMatlab(probMapClassRes.astype('float32'),
217
                                  nameToSave)
218
219
        # If segmentation done during evaluation, get dice
220
        if task == 0:
221
            print(" ... Computing Dice scores: ")
222
            DiceArray = computeDice(segmentationImage,gtLabelsImage)
223
            for d_i in xrange(len(DiceArray)):
224
                print(" -------------- DSC (Class {}) : {}".format(str(d_i+1),DiceArray[d_i]))
225
226
""" Main segmentation function """
227
def startTesting(networkModelName,
228
                 configIniName
229
                 ) :
230
231
    padInputImagesBool = True # from config ini
232
    print " ******************************************  STARTING SEGMENTATION ******************************************"
233
234
    print " **********************  Starting segmentation **********************"
235
    myParserConfigIni = parserConfigIni()
236
    myParserConfigIni.readConfigIniFile(configIniName,2)
237
    
238
239
    print " -------- Images to segment -------------"
240
241
    print " -------- Reading Images names for segmentation -------------"
242
    
243
    # -- Get list of images used for testing -- #
244
    (imageNames_Test, names_Test) = getImagesSet(myParserConfigIni.imagesFolder,myParserConfigIni.indexesToSegment)  # Images
245
    (groundTruthNames_Test, gt_names_Test) = getImagesSet(myParserConfigIni.GroundTruthFolder,myParserConfigIni.indexesToSegment) # Ground truth
246
    (roiNames_Test, roi_names_Test) = getImagesSet(myParserConfigIni.ROIFolder,myParserConfigIni.indexesToSegment) # ROI
247
248
    # --------------- Load my LiviaNet3D object  --------------- 
249
    print (" ... Loading model from {}".format(networkModelName))
250
    myLiviaNet3D = load_model_from_gzip_file(networkModelName)
251
    print " ... Network architecture successfully loaded...."
252
253
    # Get info from the network model        
254
    networkName        = myLiviaNet3D.networkName
255
    folderName         = myLiviaNet3D.folderName
256
    n_classes          = myLiviaNet3D.n_classes
257
    sampleSize_Test    = myLiviaNet3D.sampleSize_Test
258
    receptiveField     = myLiviaNet3D.receptiveField  
259
    outputShape        = myLiviaNet3D.lastLayer.outputShapeTest[2:] 
260
    batch_Size         = myLiviaNet3D.batch_Size
261
    padInputImagesBool = myParserConfigIni.applyPadding
262
    imageType          = myParserConfigIni.imageTypes
263
    numberImagesToSegment = len(imageNames_Test)
264
    
265
    strideValues = myLiviaNet3D.lastLayer.outputShapeTest[2:]
266
267
    # Run over the images to segment   
268
    for i_d in xrange(numberImagesToSegment) :
269
        print("**********************  Segmenting subject: {} ....total: {}/{}...**********************".format(names_Test[i_d],str(i_d+1),str(numberImagesToSegment)))
270
        
271
        segmentVolume(myLiviaNet3D,
272
                  i_d,
273
                  imageNames_Test,  # Full path
274
                  names_Test,       # Only image name
275
                  groundTruthNames_Test,
276
                  roiNames_Test,
277
                  imageType,
278
                  padInputImagesBool,
279
                  receptiveField, 
280
                  sampleSize_Test,
281
                  strideValues,
282
                  batch_Size,
283
                  1 # Validation (0) or testing (1)
284
                  )
285
                         
286
       
287
    print(" **************************************************************************************************** ")