Diff of /SynthSeg/evaluate.py [000000] .. [e571d1]

Switch to unified view

a b/SynthSeg/evaluate.py
1
"""
2
If you use this code, please cite one of the SynthSeg papers:
3
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4
5
Copyright 2020 Benjamin Billot
6
7
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8
compliance with the License. You may obtain a copy of the License at
9
https://www.apache.org/licenses/LICENSE-2.0
10
Unless required by applicable law or agreed to in writing, software distributed under the License is
11
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12
implied. See the License for the specific language governing permissions and limitations under the
13
License.
14
"""
15
16
17
# python imports
18
import os
19
import numpy as np
20
from scipy.stats import wilcoxon
21
from scipy.ndimage.morphology import distance_transform_edt
22
23
# third-party imports
24
from ext.lab2im import utils
25
from ext.lab2im import edit_volumes
26
27
28
def fast_dice(x, y, labels):
    """Fast implementation of Dice scores, based on a joint 2D histogram of the two label maps.
    :param x: input label map
    :param y: input label map of the same size as x
    :param labels: numpy array (or sequence) of labels to evaluate on
    :return: numpy array with Dice scores in the same order as labels. When labels contains a single value, the
    score is computed with dice() and returned as a scalar.
    """

    assert x.shape == y.shape, 'both inputs should have same size, had {} and {}'.format(x.shape, y.shape)

    if len(labels) > 1:

        # sorted unique labels: duplicates would create zero-width bins and break the "odd bins" indexing below
        labels_sorted = np.unique(labels)

        # build bins for histograms: one narrow bin [l - 0.1, l + 0.1) around each label, separated by gap bins.
        # The outermost bins are open-ended, so that a voxel pair is still counted when one of the two maps holds
        # a value below the smallest / above the largest evaluated label. Without them np.histogram2d silently
        # drops the pair, which under-estimates the label volumes and inflates the Dice scores.
        label_edges = np.sort(np.concatenate([labels_sorted - 0.1, labels_sorted + 0.1]))
        label_edges = np.concatenate([[-np.inf], label_edges, [np.inf]])

        # joint histogram: hst[i, j] = number of voxels with x in bin i and y in bin j
        hst = np.histogram2d(x.ravel(), y.ravel(), bins=label_edges)[0]

        # label bins sit at odd positions (1, 3, 5, ...), even positions are the gap/outer bins
        idx = np.arange(start=1, stop=2 * len(labels_sorted), step=2)

        # diagonal = intersections, row/column sums = volumes in x/y; epsilon avoids 0/0 for absent labels
        dice_score = 2 * np.diag(hst)[idx] / (np.sum(hst, 0)[idx] + np.sum(hst, 1)[idx] + 1e-5)

        # re-arrange scores in the initial order of labels
        dice_score = dice_score[np.searchsorted(labels_sorted, labels)]

    else:
        dice_score = dice(x == labels[0], y == labels[0])

    return dice_score
56
57
58
def dice(x, y):
    """Implementation of Dice score for two 0/1 (or boolean) numpy arrays of the same shape.
    :param x: numpy array (boolean or 0/1)
    :param y: numpy array (boolean or 0/1) of the same shape as x
    :return: scalar Dice score 2 * |x and y| / (|x| + |y|). Returns 0 when both arrays are empty (instead of the
    NaN produced by 0 / 0), which is the value that fast_dice gives to labels absent from both maps.
    """
    denominator = np.sum(x) + np.sum(y)
    if denominator == 0:
        return 0.0
    return 2 * np.sum(x * y) / denominator
61
62
63
def surface_distances(x, y, hausdorff_percentile=None, return_coordinate_max_distance=False):
    """Computes the maximum boundary distance (Hausdorff distance), and the average boundary distance of two masks.
    :param x: numpy array (boolean or 0/1)
    :param y: numpy array (boolean or 0/1)
    :param hausdorff_percentile: (optional) percentile (from 0 to 100) for which to compute the Hausdorff distance.
    Set this to 100 to compute the real Hausdorff distance (default). Can also be a list, where HD will be computed for
    the provided values.
    :param return_coordinate_max_distance: (optional) when set to true, the function will return the coordinates of the
    voxel with the highest distance (only if 100 is among the requested percentiles).
    NOTE(review): these coordinates index the volumes after cropping around the two masks, not the input volumes;
    confirm whether callers expect an offset to be added.
    :return: max_dist, mean_dist(, coordinate_max_distance)
    max_dist: scalar with HD computed for the given percentile (or list if several percentiles were given).
    mean_dist: scalar with average surface distance
    coordinate_max_distance: only returned if return_coordinate_max_distance is True and was computed.
    If the cropping of either mask is undefined, or no surface voxel is found, distances are set to the largest
    dimension of the volume."""

    assert x.shape == y.shape, 'both inputs should have same size, had {} and {}'.format(x.shape, y.shape)
    n_dims = len(x.shape)

    hausdorff_percentile = 100 if hausdorff_percentile is None else hausdorff_percentile
    hausdorff_percentile = utils.reformat_to_list(hausdorff_percentile)

    # crop x and y around ROI
    _, crop_x = edit_volumes.crop_volume_around_region(x)
    _, crop_y = edit_volumes.crop_volume_around_region(y)

    # set distances to maximum volume shape if they are not defined
    if (crop_x is None) | (crop_y is None):
        fallback = max(x.shape)
        # keep the same output format as the regular path: list for several percentiles, scalar otherwise
        if len(hausdorff_percentile) > 1:
            return [fallback] * len(hausdorff_percentile), fallback
        return fallback, fallback

    # crop both masks with the union of the two bounding boxes (first n_dims values: min, last n_dims values: max)
    crop = np.concatenate([np.minimum(crop_x, crop_y)[:n_dims], np.maximum(crop_x, crop_y)[n_dims:]])
    x = edit_volumes.crop_volume_with_idx(x, crop)
    y = edit_volumes.crop_volume_with_idx(y, crop)

    # detect edge: foreground voxels at distance exactly 1 from the background
    x_dist_int = distance_transform_edt(x * 1)
    x_edge = (x_dist_int == 1) * 1
    y_dist_int = distance_transform_edt(y * 1)
    y_edge = (y_dist_int == 1) * 1

    # calculate distance from edge
    x_dist = distance_transform_edt(np.logical_not(x_edge))
    y_dist = distance_transform_edt(np.logical_not(y_edge))

    # find distances from the 2 surfaces
    x_dists_to_y = y_dist[x_edge == 1]
    y_dists_to_x = x_dist[y_edge == 1]

    # all surface-to-surface distances, computed once for every requested percentile
    all_dists = np.concatenate([x_dists_to_y, y_dists_to_x])

    max_dist = list()
    coordinate_max_distance = None
    for hd_percentile in hausdorff_percentile:

        # no surface voxel detected in either mask: same fallback as for the mean distance below,
        # rather than letting np.max / np.percentile raise on an empty array
        if all_dists.size == 0:
            max_dist.append(max(x.shape))

        # find max distance from the 2 surfaces
        elif hd_percentile == 100:
            current_max = np.max(all_dists)
            max_dist.append(current_max)

            if return_coordinate_max_distance:
                # compare with the scalar just computed (not with the list of results accumulated so far)
                indices_x_surface = np.where(x_edge == 1)
                idx_max_distance_x = np.where(x_dists_to_y == current_max)[0]
                if idx_max_distance_x.size != 0:
                    coordinate_max_distance = np.stack(indices_x_surface).transpose()[idx_max_distance_x]
                else:
                    indices_y_surface = np.where(y_edge == 1)
                    idx_max_distance_y = np.where(y_dists_to_x == current_max)[0]
                    coordinate_max_distance = np.stack(indices_y_surface).transpose()[idx_max_distance_y]

        # find percentile of max distance
        else:
            max_dist.append(np.percentile(all_dists, hd_percentile))

    # find average distance between 2 surfaces
    if x_dists_to_y.shape[0] > 0:
        x_mean_dist_to_y = np.mean(x_dists_to_y)
    else:
        x_mean_dist_to_y = max(x.shape)
    if y_dists_to_x.shape[0] > 0:
        y_mean_dist_to_x = np.mean(y_dists_to_x)
    else:
        y_mean_dist_to_x = max(x.shape)
    mean_dist = (x_mean_dist_to_y + y_mean_dist_to_x) / 2

    # convert max dist back to scalar if HD only computed for 1 percentile
    if len(max_dist) == 1:
        max_dist = max_dist[0]

    # return coordinate of max distance if necessary
    if coordinate_max_distance is not None:
        return max_dist, mean_dist, coordinate_max_distance
    else:
        return max_dist, mean_dist
151
152
153
def compute_non_parametric_paired_test(dice_ref, dice_compare, eval_indices=None, alternative='two-sided'):
    """Compute non-parametric paired tests (Wilcoxon signed-rank tests) between two sets of Dice scores.
    :param dice_ref: numpy array with Dice scores, rows represent structures, and columns represent subjects.
    Taken as reference for one-sided tests. Can also be a 1d array with the scores of a single structure.
    :param dice_compare: numpy array of the same format as dice_ref.
    :param eval_indices: (optional) list or 1d array indicating the row indices of structures to run the tests for.
    Default is None, for which p-values are computed for all rows.
    :param alternative: (optional) The alternative hypothesis to be tested, can be 'two-sided', 'greater', 'less'.
    :return: 1d numpy array, with p-values for all tests on evaluated structures, as well as an additional test for
    average scores (last value of the array). The average score is computed only on the evaluation structures.
    When a single structure is evaluated (1d inputs, or only one evaluated row), the array contains a single p-value,
    as the structure and the average coincide.
    """

    is_2d = len(dice_ref.shape) > 1

    # take all rows if indices not specified
    if eval_indices is None:
        eval_indices = np.arange(dice_ref.shape[0]) if is_2d else []

    pvalues = list()
    if len(eval_indices) > 1:

        # loop over all evaluation label values
        for idx in eval_indices:
            _, p = wilcoxon(dice_ref[idx, :], dice_compare[idx, :], alternative=alternative)
            pvalues.append(p)

        # average score
        x = np.mean(dice_ref[eval_indices, :], axis=0)
        y = np.mean(dice_compare[eval_indices, :], axis=0)
        _, p = wilcoxon(x, y, alternative=alternative)
        pvalues.append(p)

    elif is_2d and len(eval_indices) == 1:
        # single evaluated row of a 2d array: test that row, instead of passing the whole 2d arrays to wilcoxon
        row = eval_indices[0]
        _, p = wilcoxon(dice_ref[row, :], dice_compare[row, :], alternative=alternative)
        pvalues.append(p)

    else:
        # inputs directly hold the scores of a single structure
        _, p = wilcoxon(dice_ref, dice_compare, alternative=alternative)
        pvalues.append(p)

    return np.array(pvalues)
194
195
196
def cohens_d(volumes_x, volumes_y):
    """Compute Cohen's d effect size between two groups of measurements, using the pooled standard deviation.
    :param volumes_x: array-like with the measurements of the first group, observations are along the first axis.
    :param volumes_y: array-like with the measurements of the second group, same format as volumes_x (the number of
    observations can differ).
    :return: Cohen's d, i.e. (mean_x - mean_y) / pooled_std, computed independently for each entry of the
    remaining axes (scalar for 1d inputs).
    """

    means_x = np.mean(volumes_x, axis=0)
    means_y = np.mean(volumes_y, axis=0)

    # the pooled formula below weights each variance by (n - 1), so it expects unbiased sample variances (ddof=1)
    var_x = np.var(volumes_x, axis=0, ddof=1)
    var_y = np.var(volumes_y, axis=0, ddof=1)

    n_x = np.shape(volumes_x)[0]
    n_y = np.shape(volumes_y)[0]

    # pooled standard deviation of the two groups
    std = np.sqrt(((n_x - 1) * var_x + (n_y - 1) * var_y) / (n_x + n_y - 2))
    cohensd = (means_x - means_y) / std

    return cohensd
211
212
213
def evaluation(gt_dir,
               seg_dir,
               label_list,
               mask_dir=None,
               compute_score_whole_structure=False,
               path_dice=None,
               path_hausdorff=None,
               path_hausdorff_99=None,
               path_hausdorff_95=None,
               path_mean_distance=None,
               crop_margin_around_gt=10,
               list_incorrect_labels=None,
               list_correct_labels=None,
               use_nearest_label=False,
               recompute=True,
               verbose=True):
    """This function computes Dice scores, as well as surface distances, between two sets of labels maps in gt_dir
    (ground truth) and seg_dir (typically predictions). Label maps in both folders are matched by sorting order.
    The resulting scores are saved at the specified locations.
    :param gt_dir: path of directory with gt label maps
    :param seg_dir: path of directory with label maps to compare to gt_dir. Matched to gt label maps by sorting order.
    :param label_list: list of label values for which to compute evaluation metrics. Can be a sequence, a 1d numpy
    array, or the path to such array.
    :param mask_dir: (optional) path of directory with masks of areas to ignore for each evaluated segmentation.
    Matched to gt label maps by sorting order. Default is None, where nothing is masked.
    Masked voxels are set, in both the gt and the segmentation, to a value higher than all the evaluated labels.
    :param compute_score_whole_structure: (optional) whether to also compute the selected scores for the whole segmented
    structure (i.e. scores are computed for a single structure obtained by regrouping all non-zero values). If True, the
    resulting scores are added as an extra row to the result matrices. Default is False.
    NOTE(review): masked voxels are non-zero in both maps, so they seem to be counted in the whole structure; confirm.
    :param path_dice: path where the resulting Dice will be written as numpy array.
    Default is None, where the array is not saved.
    :param path_hausdorff: path where the resulting Hausdorff distances will be written as numpy array.
    Default is None, where the array is not saved. Surface distances are only computed if at least one of
    path_hausdorff, path_hausdorff_99, path_hausdorff_95, or path_mean_distance is given.
    :param path_hausdorff_99: same as for path_hausdorff but for the 99th percentile of the boundary distance.
    :param path_hausdorff_95: same as for path_hausdorff but for the 95th percentile of the boundary distance.
    :param path_mean_distance: path where the resulting mean distances will be written as numpy array.
    Default is None, where the array is not saved.
    :param crop_margin_around_gt: (optional) margin by which to crop around the gt volumes, in order to compute the
    scores more efficiently. If 0, no cropping is performed.
    :param list_incorrect_labels: (optional) this option enables to replace some label values in the maps in seg_dir by
    other label values. Can be a list, a 1d numpy array, or the path to such an array.
    The incorrect labels can then be replaced either by specified values, or by the nearest value (see below).
    :param list_correct_labels: (optional) list of values to correct the labels specified in list_incorrect_labels.
    Correct values must have the same order as their corresponding value in list_incorrect_labels.
    :param use_nearest_label: (optional) whether to correct the incorrect label values with the nearest labels.
    :param recompute: (optional) whether to recompute the already existing results. Default is True.
    If False, scores are only computed when at least one of the requested result files is missing.
    :param verbose: (optional) whether to print out info about the remaining number of cases.
    """

    # check whether to recompute
    compute_dice = not os.path.isfile(path_dice) if (path_dice is not None) else True
    distance_paths = [path_hausdorff, path_hausdorff_99, path_hausdorff_95, path_mean_distance]
    distances_requested = any(path is not None for path in distance_paths)
    distances_missing = any((path is not None) and (not os.path.isfile(path)) for path in distance_paths)

    # Distances are computed if a requested file is missing, or if recompute is set and at least one distance file is
    # requested. The same flag guards the writing below, so that existing distance files are never overwritten with
    # the zero-initialised matrices when the distance computation was skipped.
    compute_distances = distances_requested and (recompute or distances_missing)

    if compute_dice or compute_distances or recompute:

        # get list label maps to compare
        path_gt_labels = utils.list_images_in_folder(gt_dir)
        path_segs = utils.list_images_in_folder(seg_dir)
        path_gt_labels = utils.reformat_to_list(path_gt_labels, length=len(path_segs))
        if len(path_gt_labels) != len(path_segs):
            print('gt and segmentation folders must have the same amount of label maps.')
        if mask_dir is not None:
            path_masks = utils.list_images_in_folder(mask_dir)
            if len(path_masks) != len(path_segs):
                print('not the same amount of masks and segmentations.')
        else:
            path_masks = [None] * len(path_segs)

        # load labels list
        label_list, _ = utils.get_list_labels(label_list=label_list, labels_dir=gt_dir)
        n_labels = len(label_list)
        max_label = np.max(label_list) + 1  # value given to masked voxels, outside of the evaluated labels

        # initialise result matrices (one extra row for the whole structure if necessary)
        n_rows = n_labels + 1 if compute_score_whole_structure else n_labels
        max_dists = np.zeros((n_rows, len(path_segs), 3))
        mean_dists = np.zeros((n_rows, len(path_segs)))
        dice_coefs = np.zeros((n_rows, len(path_segs)))

        # loop over segmentations
        loop_info = utils.LoopInfo(len(path_segs), 10, 'evaluating', print_time=True)
        for idx, (path_gt, path_seg, path_mask) in enumerate(zip(path_gt_labels, path_segs, path_masks)):
            if verbose:
                loop_info.update(idx)

            # load gt labels and segmentation
            gt_labels = utils.load_volume(path_gt, dtype='int', aff_ref=np.eye(4))
            seg = utils.load_volume(path_seg, dtype='int', aff_ref=np.eye(4))
            if path_mask is not None:
                mask = utils.load_volume(path_mask, dtype='bool', aff_ref=np.eye(4))
                gt_labels[mask] = max_label
                seg[mask] = max_label

            # crop images
            if crop_margin_around_gt > 0:
                gt_labels, cropping = edit_volumes.crop_volume_around_region(gt_labels, margin=crop_margin_around_gt)
                seg = edit_volumes.crop_volume_with_idx(seg, cropping)

            if list_incorrect_labels is not None:
                seg = edit_volumes.correct_label_map(seg, list_incorrect_labels, list_correct_labels, use_nearest_label)

            # compute Dice scores
            dice_coefs[:n_labels, idx] = fast_dice(gt_labels, seg, label_list)

            # compute Dice scores for whole structures
            if compute_score_whole_structure:
                temp_gt = (gt_labels > 0) * 1
                temp_seg = (seg > 0) * 1
                dice_coefs[-1, idx] = dice(temp_gt, temp_seg)
            else:
                temp_gt = temp_seg = None

            # compute average and Hausdorff distances
            if compute_distances:

                # compute unique label values
                unique_gt_labels = np.unique(gt_labels)
                unique_seg_labels = np.unique(seg)

                # compute max/mean surface distances for all labels
                for index, label in enumerate(label_list):
                    if (label in unique_gt_labels) & (label in unique_seg_labels):
                        mask_gt = np.where(gt_labels == label, True, False)
                        mask_seg = np.where(seg == label, True, False)
                        tmp_max_dists, mean_dists[index, idx] = surface_distances(mask_gt, mask_seg, [100, 99, 95])
                        max_dists[index, idx, :] = np.array(tmp_max_dists)
                    else:
                        # label missing from one of the maps: distances set to the largest dimension of the volume
                        mean_dists[index, idx] = max(gt_labels.shape)
                        max_dists[index, idx, :] = np.array([max(gt_labels.shape)] * 3)

                # compute max/mean distances for whole structure
                if compute_score_whole_structure:
                    tmp_max_dists, mean_dists[-1, idx] = surface_distances(temp_gt, temp_seg, [100, 99, 95])
                    max_dists[-1, idx, :] = np.array(tmp_max_dists)

        # write results
        if path_dice is not None:
            utils.mkdir(os.path.dirname(path_dice))
            np.save(path_dice, dice_coefs)

        # only write distances that were actually computed (last axis of max_dists: percentiles 100, 99, 95)
        if compute_distances:
            if path_hausdorff is not None:
                utils.mkdir(os.path.dirname(path_hausdorff))
                np.save(path_hausdorff, max_dists[..., 0])
            if path_hausdorff_99 is not None:
                utils.mkdir(os.path.dirname(path_hausdorff_99))
                np.save(path_hausdorff_99, max_dists[..., 1])
            if path_hausdorff_95 is not None:
                utils.mkdir(os.path.dirname(path_hausdorff_95))
                np.save(path_hausdorff_95, max_dists[..., 2])
            if path_mean_distance is not None:
                utils.mkdir(os.path.dirname(path_mean_distance))
                np.save(path_mean_distance, mean_dists)