Diff of /evaluation.py [000000] .. [8ff467]

Switch to unified view

a b/evaluation.py
1
# -*- coding: utf-8 -*-
2
3
import difflib
4
import numpy as np
5
import os
6
import SimpleITK as sitk
7
import scipy.spatial
8
9
# Set the path to the source data (e.g. the training data for self-testing)
10
# and the output directory of that subject
11
testDir        = 'evaluation' # For example: '/input/2'
12
participantDir = 'evaluation' # For example: '/output/2'
13
14
15
labels = {1: 'Cortical gray matter',
16
          2: 'Basal ganglia',
17
          3: 'White matter',
18
          4: 'White matter lesions',
19
          5: 'Cerebrospinal fluid in the extracerebral space',
20
          6: 'Ventricles',
21
          7: 'Cerebellum',
22
          8: 'Brain stem',
23
          # The two labels below are ignored:
24
          #9: 'Infarction',
25
          #10: 'Other',
26
          }
27
28
29
def do():
30
    """Main function"""    
31
    resultFilename = getResultFilename(participantDir)  
32
        
33
    testImage, resultImage = getImages(os.path.join(testDir, 'segm.nii.gz'), resultFilename)
34
    
35
    dsc = getDSC(testImage, resultImage)
36
    h95 = getHausdorff(testImage, resultImage)
37
    vs  = getVS(testImage, resultImage)
38
    
39
    print('Dice',                dsc,       '(higher is better, max=1)')
40
    print('HD',                  h95, 'mm',  '(lower is better, min=0)')
41
    print('VS',                   vs,       '(higher is better, max=1)')
42
    
43
    
44
    
45
def getResultFilename(participantDir):
46
    """Find the filename of the result image.
47
    
48
    This should be result.nii.gz or result.nii. If these files are not present,
49
    it tries to find the closest filename."""
50
    files = os.listdir(participantDir)
51
    
52
    if not files:
53
        raise Exception("No results in "+ participantDir)
54
    
55
    resultFilename = None
56
    if 'result.nii.gz' in files:
57
        resultFilename = os.path.join(participantDir, 'result.nii.gz')
58
    elif 'result.nii' in files:
59
        resultFilename = os.path.join(participantDir, 'result.nii')
60
    else:
61
        # Find the filename that is closest to 'result.nii.gz'
62
        maxRatio = -1
63
        for f in files:
64
            currentRatio = difflib.SequenceMatcher(a = f, b = 'result.nii.gz').ratio()
65
            
66
            if currentRatio > maxRatio:
67
                resultFilename = os.path.join(participantDir, f)
68
                maxRatio = currentRatio
69
                
70
    return resultFilename
71
    
72
73
def getImages(testFilename, resultFilename):
74
    """Return the test and result images, thresholded and pathology masked."""
75
    testImage   = sitk.ReadImage(testFilename)
76
    resultImage = sitk.ReadImage(resultFilename)
77
    
78
    # Check for equality
79
    assert testImage.GetSize() == resultImage.GetSize()
80
    
81
    # Get meta data from the test-image, needed for some sitk methods that check this
82
    resultImage.CopyInformation(testImage)    
83
    
84
    # Remove pathology from the test and result images, since we don't evaluate on that
85
    pathologyImage = sitk.BinaryThreshold(testImage, 9, 11, 0, 1)  # pathology == 9 or 10
86
    
87
    maskedTestImage   = sitk.Mask(testImage,   pathologyImage)     # tissue    == 1 --  8
88
    maskedResultImage = sitk.Mask(resultImage, pathologyImage)
89
    
90
    # Force integer
91
    if not 'integer' in maskedResultImage.GetPixelIDTypeAsString():
92
        maskedResultImage = sitk.Cast(maskedResultImage, sitk.sitkUInt8)
93
            
94
    return maskedTestImage, maskedResultImage
95
    
96
    
97
def getDSC(testImage, resultImage):    
98
    """Compute the Dice Similarity Coefficient."""        
99
    dsc = dict()
100
    for k in labels.keys():
101
        testArray   = sitk.GetArrayFromImage(sitk.BinaryThreshold(  testImage, k, k, 1, 0)).flatten()
102
        resultArray = sitk.GetArrayFromImage(sitk.BinaryThreshold(resultImage, k, k, 1, 0)).flatten()
103
        
104
        # similarity = 1.0 - dissimilarity
105
        # scipy.spatial.distance.dice raises a ZeroDivisionError if both arrays contain only zeros.
106
        try:
107
            dsc[k] = 1.0 - scipy.spatial.distance.dice(testArray, resultArray)
108
        except ZeroDivisionError:
109
            dsc[k] = None
110
    
111
    return dsc
112
113
        
114
def getHausdorff(testImage, resultImage):
115
    """Compute the 95% Hausdorff distance."""    
116
    hd = dict()
117
    for k in labels.keys():
118
        lTestImage   = sitk.BinaryThreshold(  testImage, k, k, 1, 0)
119
        lResultImage = sitk.BinaryThreshold(resultImage, k, k, 1, 0)
120
        
121
        # Hausdorff distance is only defined when something is detected
122
        statistics = sitk.StatisticsImageFilter()
123
        statistics.Execute(lTestImage)
124
        lTestSum = statistics.GetSum()
125
        statistics.Execute(lResultImage)
126
        lResultSum = statistics.GetSum()
127
        if lTestSum == 0 or lResultSum == 0:
128
            hd[k] = None
129
            continue
130
                                
131
        # Edge detection is done by ORIGINAL - ERODED, keeping the outer boundaries of lesions. Erosion is performed in 2D
132
        eTestImage   = sitk.BinaryErode(lTestImage, (1,1,0))
133
        eResultImage = sitk.BinaryErode(lResultImage, (1,1,0))
134
        
135
        hTestImage   = sitk.Subtract(lTestImage, eTestImage)
136
        hResultImage = sitk.Subtract(lResultImage, eResultImage)    
137
        
138
        hTestArray   = sitk.GetArrayFromImage(hTestImage)
139
        hResultArray = sitk.GetArrayFromImage(hResultImage)   
140
            
141
        # Convert voxel location to world coordinates. Use the coordinate system of the test image
142
        # np.nonzero   = elements of the boundary in numpy order (zyx)
143
        # np.flipud    = elements in xyz order
144
        # np.transpose = create tuples (x,y,z)
145
        # testImage.TransformIndexToPhysicalPoint converts (xyz) to world coordinates (in mm)
146
        # (Simple)ITK does not accept all Numpy arrays; therefore we need to convert the coordinate tuples into a Python list before passing them to TransformIndexToPhysicalPoint().
147
        testCoordinates   = [testImage.TransformIndexToPhysicalPoint(x.tolist()) for x in np.transpose( np.flipud( np.nonzero(hTestArray) ))]
148
        resultCoordinates = [testImage.TransformIndexToPhysicalPoint(x.tolist()) for x in np.transpose( np.flipud( np.nonzero(hResultArray) ))]
149
                
150
        # Use a kd-tree for fast spatial search
151
        def getDistancesFromAtoB(a, b):    
152
            kdTree = scipy.spatial.KDTree(a, leafsize=100)
153
            return kdTree.query(b, k=1, eps=0, p=2)[0]
154
        
155
        # Compute distances from test to result and vice versa. 
156
        dTestToResult = getDistancesFromAtoB(testCoordinates, resultCoordinates)
157
        dResultToTest = getDistancesFromAtoB(resultCoordinates, testCoordinates)
158
        hd[k] = max(np.percentile(dTestToResult, 95), np.percentile(dResultToTest, 95))
159
        
160
    return hd
161
162
163
def getVS(testImage, resultImage):   
164
    """Volume similarity.
165
    
166
    VS = 1 - abs(A - B) / (A + B)
167
    
168
    A = ground truth in ML
169
    B = participant segmentation in ML
170
    """    
171
    # Compute statistics of both images
172
    testStatistics   = sitk.StatisticsImageFilter()
173
    resultStatistics = sitk.StatisticsImageFilter()
174
    
175
    vs = dict()
176
    for k in labels.keys():
177
        testStatistics.Execute(sitk.BinaryThreshold(testImage, k, k, 1, 0))
178
        resultStatistics.Execute(sitk.BinaryThreshold(resultImage, k, k, 1, 0))
179
        
180
        numerator = abs(testStatistics.GetSum() - resultStatistics.GetSum())
181
        denominator = testStatistics.GetSum() + resultStatistics.GetSum()               
182
        
183
        if denominator > 0:        
184
            vs[k] = 1 - float(numerator) / denominator
185
        else:
186
            vs[k] = None
187
        
188
    return vs
189
    
190
    
191
if __name__ == "__main__":
192
    do()