Diff of /tests/test_metrics.py [000000] .. [4e96d3]

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

from mmseg.core.evaluation import (eval_metrics, mean_dice, mean_fscore,
                                   mean_iou)
from mmseg.core.evaluation.metrics import f_score


def get_confusion_matrix(pred_label, label, num_classes, ignore_index):
    """Compute the confusion matrix of a prediction and a label map.

    Args:
        pred_label (np.ndarray): 2D prediction map.
        label (np.ndarray): 2D label map.
        num_classes (int): number of categories.
        ignore_index (int): index ignored in evaluation.

    Returns:
        np.ndarray: confusion matrix of shape (num_classes, num_classes),
            with ground truth on the rows and predictions on the columns.
    """

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    n = num_classes
    inds = n * label + pred_label

    mat = np.bincount(inds, minlength=n**2).reshape(n, n)

    return mat


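# Illustrative sanity check of the bincount trick used above: flattening to
# `n * label + pred_label` maps every (gt, pred) pair to a unique bin, so
# `np.bincount(..., minlength=n**2).reshape(n, n)` yields the confusion matrix
# with ground truth on the rows and predictions on the columns. The toy values
# and the test name below are purely illustrative.
def test_get_confusion_matrix_toy_example():
    pred = np.array([[0, 1], [1, 1]])
    gt = np.array([[0, 0], [1, 255]])
    mat = get_confusion_matrix(pred, gt, num_classes=2, ignore_index=255)
    # gt == 0: one pixel predicted 0 and one predicted 1;
    # gt == 1: one pixel predicted 1; the ignored pixel is dropped.
    assert np.array_equal(mat, np.array([[1, 1], [0, 1]]))

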
# This func is deprecated since it's not memory efficient
def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index):
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_mat = np.zeros((num_classes, num_classes), dtype=float)
    for i in range(num_imgs):
        mat = get_confusion_matrix(
            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
        total_mat += mat
    all_acc = np.diag(total_mat).sum() / total_mat.sum()
    acc = np.diag(total_mat) / total_mat.sum(axis=1)
    iou = np.diag(total_mat) / (
        total_mat.sum(axis=1) + total_mat.sum(axis=0) - np.diag(total_mat))

    return all_acc, acc, iou


# This func is deprecated since it's not memory efficient
def legacy_mean_dice(results, gt_seg_maps, num_classes, ignore_index):
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_mat = np.zeros((num_classes, num_classes), dtype=float)
    for i in range(num_imgs):
        mat = get_confusion_matrix(
            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
        total_mat += mat
    all_acc = np.diag(total_mat).sum() / total_mat.sum()
    acc = np.diag(total_mat) / total_mat.sum(axis=1)
    dice = 2 * np.diag(total_mat) / (
        total_mat.sum(axis=1) + total_mat.sum(axis=0))

    return all_acc, acc, dice


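# Illustrative check of the closed-form relation between the two legacy
# helpers above: computed from the same confusion matrix, the per-class scores
# satisfy Dice = 2 * IoU / (1 + IoU). The toy shapes, class count and the test
# name below are arbitrary choices for illustration only.
def test_legacy_dice_iou_relation():
    pred = np.random.randint(0, 3, size=(2, 8, 8))
    gt = np.random.randint(0, 3, size=(2, 8, 8))
    _, _, iou = legacy_mean_iou(pred, gt, num_classes=3, ignore_index=255)
    _, _, dice = legacy_mean_dice(pred, gt, num_classes=3, ignore_index=255)
    # equal_nan covers the (unlikely) case of a class absent from both maps.
    assert np.allclose(dice, 2 * iou / (1 + iou), equal_nan=True)

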
# This func is deprecated since it's not memory efficient
def legacy_mean_fscore(results,
                       gt_seg_maps,
                       num_classes,
                       ignore_index,
                       beta=1):
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_mat = np.zeros((num_classes, num_classes), dtype=float)
    for i in range(num_imgs):
        mat = get_confusion_matrix(
            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
        total_mat += mat
    all_acc = np.diag(total_mat).sum() / total_mat.sum()
    recall = np.diag(total_mat) / total_mat.sum(axis=1)
    precision = np.diag(total_mat) / total_mat.sum(axis=0)
    fv = np.vectorize(f_score)
    fscore = fv(precision, recall, beta=beta)

    return all_acc, recall, precision, fscore


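# Sanity check of the F-beta formula the legacy helper above relies on,
# assuming `f_score` follows the standard definition
# F_beta = (1 + beta**2) * P * R / (beta**2 * P + R); the sample precision,
# recall and the test name below are arbitrary illustrative values.
def test_f_score_matches_fbeta_formula():
    precision, recall = 0.5, 0.25
    for beta in (1, 2):
        expected = (1 + beta**2) * precision * recall / (
            beta**2 * precision + recall)
        assert np.isclose(f_score(precision, recall, beta=beta), expected)

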
def test_metrics():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)

    # Test the availability of arg: ignore_index.
    label[:, 2, 5:10] = ignore_index

    # Test the correctness of the implementation of mIoU calculation.
    ret_metrics = eval_metrics(
        results, label, num_classes, ignore_index, metrics='mIoU')
    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'IoU']
    all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                              ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)
    # Test the correctness of the implementation of mDice calculation.
    ret_metrics = eval_metrics(
        results, label, num_classes, ignore_index, metrics='mDice')
    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'Dice']
    all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
                                                ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(dice, dice_l)
    # Test the correctness of the implementation of mFscore calculation.
    ret_metrics = eval_metrics(
        results, label, num_classes, ignore_index, metrics='mFscore')
    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
        results, label, num_classes, ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(recall, recall_l)
    assert np.allclose(precision, precision_l)
    assert np.allclose(fscore, fscore_l)
    # Test the correctness of the implementation of joint calculation.
    ret_metrics = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index,
        metrics=['mIoU', 'mDice', 'mFscore'])
    all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[
        'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[
            'Dice'], ret_metrics['Precision'], ret_metrics[
                'Recall'], ret_metrics['Fscore']
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)
    assert np.allclose(dice, dice_l)
    assert np.allclose(precision, precision_l)
    assert np.allclose(recall, recall_l)
    assert np.allclose(fscore, fscore_l)

    # Test the correctness of calculation when arg: num_classes is larger
    # than the maximum value of input maps.
    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    ret_metrics = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics='mIoU',
        nan_to_num=-1)
    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'IoU']
    assert acc[-1] == -1
    assert iou[-1] == -1

    ret_metrics = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics='mDice',
        nan_to_num=-1)
    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'Dice']
    assert acc[-1] == -1
    assert dice[-1] == -1

    ret_metrics = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics='mFscore',
        nan_to_num=-1)
    all_acc, precision, recall, fscore = ret_metrics['aAcc'], ret_metrics[
        'Precision'], ret_metrics['Recall'], ret_metrics['Fscore']
    assert precision[-1] == -1
    assert recall[-1] == -1
    assert fscore[-1] == -1

    ret_metrics = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics=['mDice', 'mIoU', 'mFscore'],
        nan_to_num=-1)
    all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[
        'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[
            'Dice'], ret_metrics['Precision'], ret_metrics[
                'Recall'], ret_metrics['Fscore']
    assert acc[-1] == -1
    assert dice[-1] == -1
    assert iou[-1] == -1
    assert precision[-1] == -1
    assert recall[-1] == -1
    assert fscore[-1] == -1

    # Regression test for a bug caused by torch.histc.
    # torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
    # When the arg: bins is set to the same value as arg: max,
    # some entries of the per-class IoU may be nan.
    results = np.array([np.repeat(31, 59)])
    label = np.array([np.arange(59)])
    num_classes = 59
    ret_metrics = eval_metrics(
        results, label, num_classes, ignore_index=255, metrics='mIoU')
    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'IoU']
    assert not np.any(np.isnan(iou))


def test_mean_iou():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)
    label[:, 2, 5:10] = ignore_index
    ret_metrics = mean_iou(results, label, num_classes, ignore_index)
    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'IoU']
    all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                              ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)

    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    ret_metrics = mean_iou(
        results, label, num_classes, ignore_index=255, nan_to_num=-1)
    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'IoU']
    assert acc[-1] == -1
    assert iou[-1] == -1


def test_mean_dice():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)
    label[:, 2, 5:10] = ignore_index
    ret_metrics = mean_dice(results, label, num_classes, ignore_index)
    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'Dice']
    all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
                                                ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(dice, dice_l)

    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    ret_metrics = mean_dice(
        results, label, num_classes, ignore_index=255, nan_to_num=-1)
    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'Dice']
    assert acc[-1] == -1
    assert dice[-1] == -1


def test_mean_fscore():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)
    label[:, 2, 5:10] = ignore_index
    ret_metrics = mean_fscore(results, label, num_classes, ignore_index)
    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
        results, label, num_classes, ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(recall, recall_l)
    assert np.allclose(precision, precision_l)
    assert np.allclose(fscore, fscore_l)

    ret_metrics = mean_fscore(
        results, label, num_classes, ignore_index, beta=2)
    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
        results, label, num_classes, ignore_index, beta=2)
    assert all_acc == all_acc_l
    assert np.allclose(recall, recall_l)
    assert np.allclose(precision, precision_l)
    assert np.allclose(fscore, fscore_l)

    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    ret_metrics = mean_fscore(
        results, label, num_classes, ignore_index=255, nan_to_num=-1)
    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
    assert recall[-1] == -1
    assert precision[-1] == -1
    assert fscore[-1] == -1


def test_filename_inputs():
    import cv2
    import tempfile

    def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
        filenames = []
        SUFFIX = '.png' if is_image else '.npy'
        for idx, arr in enumerate(input_arrays):
            filename = '{}/{}-{}{}'.format(dir, title, idx, SUFFIX)
            if is_image:
                cv2.imwrite(filename, arr)
            else:
                np.save(filename, arr)
            filenames.append(filename)
        return filenames

    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    labels = np.random.randint(0, num_classes, size=pred_size)
    labels[:, 2, 5:10] = ignore_index

    with tempfile.TemporaryDirectory() as temp_dir:

        result_files = save_arr(results, 'pred', False, temp_dir)
        label_files = save_arr(labels, 'label', True, temp_dir)

        ret_metrics = eval_metrics(
            result_files,
            label_files,
            num_classes,
            ignore_index,
            metrics='mIoU')
        all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics[
            'Acc'], ret_metrics['IoU']
        all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes,
                                                  ignore_index)
        assert all_acc == all_acc_l
        assert np.allclose(acc, acc_l)
        assert np.allclose(iou, iou_l)