tests/test_data/test_dataset.py
# Copyright (c) OpenMMLab. All rights reserved.
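"""Unit tests for mmseg dataset classes: CustomDataset construction and
evaluation, ConcatDataset/RepeatDataset wrappers, and the format_results /
evaluate behaviour of the ADE20K, Cityscapes and LoveDA datasets."""
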
import os
import os.path as osp
import shutil
import tempfile
from typing import Generator
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
import torch
from PIL import Image

from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
                            ConcatDataset, CustomDataset, LoveDADataset,
                            PascalVOCDataset, RepeatDataset, build_dataset)


def test_classes():
    assert list(CityscapesDataset.CLASSES) == get_classes('cityscapes')
    assert list(PascalVOCDataset.CLASSES) == get_classes('voc') == get_classes(
        'pascal_voc')
    assert list(
        ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k')

    with pytest.raises(ValueError):
        get_classes('unsupported')


def test_classes_file_path():
    tmp_file = tempfile.NamedTemporaryFile()
    classes_path = f'{tmp_file.name}.txt'
    train_pipeline = [dict(type='LoadImageFromFile')]
    kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path)

    # classes.txt with full categories
    categories = get_classes('cityscapes')
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))
    assert list(CityscapesDataset(**kwargs).CLASSES) == categories

    # classes.txt with sub categories
    categories = ['road', 'sidewalk', 'building']
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))
    assert list(CityscapesDataset(**kwargs).CLASSES) == categories

    # classes.txt with unknown categories
    categories = ['road', 'sidewalk', 'unknown']
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))

    with pytest.raises(ValueError):
        CityscapesDataset(**kwargs)

    tmp_file.close()
    os.remove(classes_path)
    assert not osp.exists(classes_path)


def test_palette():
    assert CityscapesDataset.PALETTE == get_palette('cityscapes')
    assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(
        'pascal_voc')
    assert ADE20KDataset.PALETTE == get_palette('ade') == get_palette('ade20k')

    with pytest.raises(ValueError):
        get_palette('unsupported')


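# ConcatDataset chains datasets, so flat index 25 reaches sample 15 of
# dataset_b (len_a == 10); RepeatDataset maps index i to i % len(dataset).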
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_dataset_wrapper():
    # CustomDataset.load_annotations = MagicMock()
    # CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset_a = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_a = 10
    dataset_a.img_infos = MagicMock()
    dataset_a.img_infos.__len__.return_value = len_a
    dataset_b = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_b = 20
    dataset_b.img_infos = MagicMock()
    dataset_b.img_infos.__len__.return_value = len_b

    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    repeat_dataset = RepeatDataset(dataset_a, 10)
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert len(repeat_dataset) == 10 * len(dataset_a)


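# The tests below use the pseudo_dataset fixture under ../data: 5 image/gt
# pairs with suffixes 'img.jpg'/'gt.png', of which 4 are listed in
# splits/train.txt.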
def test_custom_dataset():
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    crop_size = (512, 1024)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(128, 256),
            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]

    # with img_dir and ann_dir
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        img_dir='imgs/',
        ann_dir='gts/',
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with img_dir, ann_dir, split
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        img_dir='imgs/',
        ann_dir='gts/',
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        split='splits/train.txt')
    assert len(train_dataset) == 4

    # no data_root
    train_dataset = CustomDataset(
        train_pipeline,
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
        ann_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts'),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with data_root but img_dir/ann_dir are abs paths
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        img_dir=osp.abspath(
            osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
        ann_dir=osp.abspath(
            osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts')),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # test_mode=True
    test_dataset = CustomDataset(
        test_pipeline,
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
        img_suffix='img.jpg',
        test_mode=True,
        classes=('pseudo_class', ))
    assert len(test_dataset) == 5

    # get training data
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)

    # get test data
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)

    # get gt seg maps
    gt_seg_maps = train_dataset.get_gt_seg_maps(efficient_test=True)
    assert isinstance(gt_seg_maps, Generator)
    gt_seg_maps = list(gt_seg_maps)
    assert len(gt_seg_maps) == 5

    # format_results not implemented
    with pytest.raises(NotImplementedError):
        test_dataset.format_results([], '')

    pseudo_results = []
    for gt_seg_map in gt_seg_maps:
        h, w = gt_seg_map.shape
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))

    # test past evaluation without CLASSES
    with pytest.raises(TypeError):
        eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])

    with pytest.raises(TypeError):
        eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')

    with pytest.raises(TypeError):
        eval_results = train_dataset.evaluate(
            pseudo_results, metric=['mDice', 'mIoU'])

    # test past evaluation with CLASSES
    train_dataset.CLASSES = tuple(['a'] * 7)
    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
    assert isinstance(eval_results, dict)
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
    assert isinstance(eval_results, dict)
    assert 'mRecall' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mFscore' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(
        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results
    assert 'mFscore' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mRecall' in eval_results

    assert not np.isnan(eval_results['mIoU'])
    assert not np.isnan(eval_results['mDice'])
    assert not np.isnan(eval_results['mAcc'])
    assert not np.isnan(eval_results['aAcc'])
    assert not np.isnan(eval_results['mFscore'])
    assert not np.isnan(eval_results['mPrecision'])
    assert not np.isnan(eval_results['mRecall'])

    # test evaluation with pre-eval; dataset.CLASSES is required
    train_dataset.CLASSES = tuple(['a'] * 7)
    pseudo_results = []
    for idx in range(len(train_dataset)):
        h, w = gt_seg_maps[idx].shape
        pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
        pseudo_results.extend(train_dataset.pre_eval(pseudo_result, idx))
    eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
    assert isinstance(eval_results, dict)
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
    assert isinstance(eval_results, dict)
    assert 'mRecall' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mFscore' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(
        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results
    assert 'mFscore' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mRecall' in eval_results

    assert not np.isnan(eval_results['mIoU'])
    assert not np.isnan(eval_results['mDice'])
    assert not np.isnan(eval_results['mAcc'])
    assert not np.isnan(eval_results['aAcc'])
    assert not np.isnan(eval_results['mFscore'])
    assert not np.isnan(eval_results['mPrecision'])
    assert not np.isnan(eval_results['mRecall'])


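# Concatenating the same sub-dataset twice should reproduce the single-dataset
# metrics: per sub-dataset with keys prefixed '0_'/'1_' when
# separate_eval=True, or as one joint result when separate_eval=False.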
@pytest.mark.parametrize('separate_eval', [True, False])
def test_eval_concat_custom_dataset(separate_eval):
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(128, 256),
            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]
    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    img_dir = 'imgs/'
    ann_dir = 'gts/'

    cfg1 = dict(
        type='CustomDataset',
        pipeline=test_pipeline,
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=ann_dir,
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        classes=tuple(['a'] * 7))
    dataset1 = build_dataset(cfg1)
    assert len(dataset1) == 5
    # get gt seg maps
    gt_seg_maps = dataset1.get_gt_seg_maps(efficient_test=True)
    assert isinstance(gt_seg_maps, Generator)
    gt_seg_maps = list(gt_seg_maps)
    assert len(gt_seg_maps) == 5

    # test past evaluation
    pseudo_results = []
    for gt_seg_map in gt_seg_maps:
        h, w = gt_seg_map.shape
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
    eval_results1 = dataset1.evaluate(
        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])

    # We use the same dir twice for simplicity
    # with ann_dir
    cfg2 = dict(
        type='CustomDataset',
        pipeline=test_pipeline,
        data_root=data_root,
        img_dir=[img_dir, img_dir],
        ann_dir=[ann_dir, ann_dir],
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        classes=tuple(['a'] * 7),
        separate_eval=separate_eval)
    dataset2 = build_dataset(cfg2)
    assert isinstance(dataset2, ConcatDataset)
    assert len(dataset2) == 10

    eval_results2 = dataset2.evaluate(
        pseudo_results * 2, metric=['mIoU', 'mDice', 'mFscore'])

    if separate_eval:
        assert eval_results1['mIoU'] == eval_results2[
            '0_mIoU'] == eval_results2['1_mIoU']
        assert eval_results1['mDice'] == eval_results2[
            '0_mDice'] == eval_results2['1_mDice']
        assert eval_results1['mAcc'] == eval_results2[
            '0_mAcc'] == eval_results2['1_mAcc']
        assert eval_results1['aAcc'] == eval_results2[
            '0_aAcc'] == eval_results2['1_aAcc']
        assert eval_results1['mFscore'] == eval_results2[
            '0_mFscore'] == eval_results2['1_mFscore']
        assert eval_results1['mPrecision'] == eval_results2[
            '0_mPrecision'] == eval_results2['1_mPrecision']
        assert eval_results1['mRecall'] == eval_results2[
            '0_mRecall'] == eval_results2['1_mRecall']
    else:
        assert eval_results1['mIoU'] == eval_results2['mIoU']
        assert eval_results1['mDice'] == eval_results2['mDice']
        assert eval_results1['mAcc'] == eval_results2['mAcc']
        assert eval_results1['aAcc'] == eval_results2['aAcc']
        assert eval_results1['mFscore'] == eval_results2['mFscore']
        assert eval_results1['mPrecision'] == eval_results2['mPrecision']
        assert eval_results1['mRecall'] == eval_results2['mRecall']

    # test get dataset_idx and sample_idx from ConcatDataset
    dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(3)
    assert dataset_idx == 0
    assert sample_idx == 3

    dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(7)
    assert dataset_idx == 1
    assert sample_idx == 2

    dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-7)
    assert dataset_idx == 0
    assert sample_idx == 3

    # test negative index exceeding the length of the dataset
    with pytest.raises(ValueError):
        dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-11)

    # test negative index value
    indice = -6
    dataset_idx1, sample_idx1 = dataset2.get_dataset_idx_and_sample_idx(indice)
    dataset_idx2, sample_idx2 = dataset2.get_dataset_idx_and_sample_idx(
        len(dataset2) + indice)
    assert dataset_idx1 == dataset_idx2
    assert sample_idx1 == sample_idx2

    # test evaluation with pre-eval; dataset.CLASSES is required
    pseudo_results = []
    eval_results1 = []
    for idx in range(len(dataset1)):
        h, w = gt_seg_maps[idx].shape
        pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
        pseudo_results.append(pseudo_result)
        eval_results1.extend(dataset1.pre_eval(pseudo_result, idx))

    assert len(eval_results1) == len(dataset1)
    assert isinstance(eval_results1[0], tuple)
    assert len(eval_results1[0]) == 4
    assert isinstance(eval_results1[0][0], torch.Tensor)

    eval_results1 = dataset1.evaluate(
        eval_results1, metric=['mIoU', 'mDice', 'mFscore'])

    pseudo_results = pseudo_results * 2
    eval_results2 = []
    for idx in range(len(dataset2)):
        eval_results2.extend(dataset2.pre_eval(pseudo_results[idx], idx))

    assert len(eval_results2) == len(dataset2)
    assert isinstance(eval_results2[0], tuple)
    assert len(eval_results2[0]) == 4
    assert isinstance(eval_results2[0][0], torch.Tensor)

    eval_results2 = dataset2.evaluate(
        eval_results2, metric=['mIoU', 'mDice', 'mFscore'])

    if separate_eval:
        assert eval_results1['mIoU'] == eval_results2[
            '0_mIoU'] == eval_results2['1_mIoU']
        assert eval_results1['mDice'] == eval_results2[
            '0_mDice'] == eval_results2['1_mDice']
        assert eval_results1['mAcc'] == eval_results2[
            '0_mAcc'] == eval_results2['1_mAcc']
        assert eval_results1['aAcc'] == eval_results2[
            '0_aAcc'] == eval_results2['1_aAcc']
        assert eval_results1['mFscore'] == eval_results2[
            '0_mFscore'] == eval_results2['1_mFscore']
        assert eval_results1['mPrecision'] == eval_results2[
            '0_mPrecision'] == eval_results2['1_mPrecision']
        assert eval_results1['mRecall'] == eval_results2[
            '0_mRecall'] == eval_results2['1_mRecall']
    else:
        assert eval_results1['mIoU'] == eval_results2['mIoU']
        assert eval_results1['mDice'] == eval_results2['mDice']
        assert eval_results1['mAcc'] == eval_results2['mAcc']
        assert eval_results1['aAcc'] == eval_results2['aAcc']
        assert eval_results1['mFscore'] == eval_results2['mFscore']
        assert eval_results1['mPrecision'] == eval_results2['mPrecision']
        assert eval_results1['mRecall'] == eval_results2['mRecall']

    # test batch_indices for pre eval
    eval_results2 = dataset2.pre_eval(pseudo_results,
                                      list(range(len(pseudo_results))))

    assert len(eval_results2) == len(dataset2)
    assert isinstance(eval_results2[0], tuple)
    assert len(eval_results2[0]) == 4
    assert isinstance(eval_results2[0][0], torch.Tensor)

    eval_results2 = dataset2.evaluate(
        eval_results2, metric=['mIoU', 'mDice', 'mFscore'])

    if separate_eval:
        assert eval_results1['mIoU'] == eval_results2[
            '0_mIoU'] == eval_results2['1_mIoU']
        assert eval_results1['mDice'] == eval_results2[
            '0_mDice'] == eval_results2['1_mDice']
        assert eval_results1['mAcc'] == eval_results2[
            '0_mAcc'] == eval_results2['1_mAcc']
        assert eval_results1['aAcc'] == eval_results2[
            '0_aAcc'] == eval_results2['1_aAcc']
        assert eval_results1['mFscore'] == eval_results2[
            '0_mFscore'] == eval_results2['1_mFscore']
        assert eval_results1['mPrecision'] == eval_results2[
            '0_mPrecision'] == eval_results2['1_mPrecision']
        assert eval_results1['mRecall'] == eval_results2[
            '0_mRecall'] == eval_results2['1_mRecall']
    else:
        assert eval_results1['mIoU'] == eval_results2['mIoU']
        assert eval_results1['mDice'] == eval_results2['mDice']
        assert eval_results1['mAcc'] == eval_results2['mAcc']
        assert eval_results1['aAcc'] == eval_results2['aAcc']
        assert eval_results1['mFscore'] == eval_results2['mFscore']
        assert eval_results1['mPrecision'] == eval_results2['mPrecision']
        assert eval_results1['mRecall'] == eval_results2['mRecall']


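# ADE20K predictions are saved with labels shifted by +1, matching the
# dataset's 1-based annotation convention (0 is reserved for 'ignore').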
def test_ade():
    test_dataset = ADE20KDataset(
        pipeline=[],
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
    assert len(test_dataset) == 5

    # Test format_results
    pseudo_results = []
    for _ in range(len(test_dataset)):
        h, w = (2, 2)
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))

    file_paths = test_dataset.format_results(pseudo_results, '.format_ade')
    assert len(file_paths) == len(test_dataset)
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp, pseudo_results[0] + 1)

    shutil.rmtree('.format_ade')


@pytest.mark.parametrize('separate_eval', [True, False])
def test_concat_ade(separate_eval):
    test_dataset = ADE20KDataset(
        pipeline=[],
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
    assert len(test_dataset) == 5

    concat_dataset = ConcatDataset([test_dataset, test_dataset],
                                   separate_eval=separate_eval)
    assert len(concat_dataset) == 10
    # Test format_results
    pseudo_results = []
    for _ in range(len(concat_dataset)):
        h, w = (2, 2)
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))

    # test format per image
    file_paths = []
    for i in range(len(pseudo_results)):
        file_paths.extend(
            concat_dataset.format_results([pseudo_results[i]],
                                          '.format_ade',
                                          indices=[i]))
    assert len(file_paths) == len(concat_dataset)
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp, pseudo_results[0] + 1)

    shutil.rmtree('.format_ade')

    # test default argument
    file_paths = concat_dataset.format_results(pseudo_results, '.format_ade')
    assert len(file_paths) == len(concat_dataset)
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp, pseudo_results[0] + 1)

    shutil.rmtree('.format_ade')


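# Cityscapes results are written after _convert_to_label_id maps train IDs
# back to the official label IDs expected by the cityscapes evaluation.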
def test_cityscapes():
    test_dataset = CityscapesDataset(
        pipeline=[],
        img_dir=osp.join(
            osp.dirname(__file__),
            '../data/pseudo_cityscapes_dataset/leftImg8bit'),
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
    assert len(test_dataset) == 1

    gt_seg_maps = list(test_dataset.get_gt_seg_maps())

    # Test format_results
    pseudo_results = []
    for idx in range(len(test_dataset)):
        h, w = gt_seg_maps[idx].shape
        pseudo_results.append(np.random.randint(low=0, high=19, size=(h, w)))

    file_paths = test_dataset.format_results(pseudo_results, '.format_city')
    assert len(file_paths) == len(test_dataset)
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp,
                       test_dataset._convert_to_label_id(pseudo_results[0]))

    # Test cityscapes evaluate
    test_dataset.evaluate(
        pseudo_results, metric='cityscapes', imgfile_prefix='.format_city')

    shutil.rmtree('.format_city')


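# ConcatDataset refuses, at construction time, any combination that includes a
# CityscapesDataset, since cityscapes-style evaluation of concatenated
# datasets is not supported.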
@pytest.mark.parametrize('separate_eval', [True, False])
def test_concat_cityscapes(separate_eval):
    cityscape_dataset = CityscapesDataset(
        pipeline=[],
        img_dir=osp.join(
            osp.dirname(__file__),
            '../data/pseudo_cityscapes_dataset/leftImg8bit'),
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
    assert len(cityscape_dataset) == 1
    with pytest.raises(NotImplementedError):
        _ = ConcatDataset([cityscape_dataset, cityscape_dataset],
                          separate_eval=separate_eval)
    ade_dataset = ADE20KDataset(
        pipeline=[],
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
    assert len(ade_dataset) == 5
    with pytest.raises(NotImplementedError):
        _ = ConcatDataset([cityscape_dataset, ade_dataset],
                          separate_eval=separate_eval)


def test_loveda():
    test_dataset = LoveDADataset(
        pipeline=[],
        img_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_loveda_dataset/img_dir'),
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_loveda_dataset/ann_dir'))
    assert len(test_dataset) == 3

    gt_seg_maps = list(test_dataset.get_gt_seg_maps())

    # Test format_results
    pseudo_results = []
    for idx in range(len(test_dataset)):
        h, w = gt_seg_maps[idx].shape
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
    file_paths = test_dataset.format_results(pseudo_results, '.format_loveda')
    assert len(file_paths) == len(test_dataset)

    # Test loveda evaluate
    test_dataset.evaluate(
        pseudo_results, metric='mIoU', imgfile_prefix='.format_loveda')

    shutil.rmtree('.format_loveda')


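# load_annotations and __getitem__ are patched below so the dataset classes
# can be instantiated without real files; only the CLASSES/PALETTE handling is
# exercised.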
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('CustomDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):

    dataset_class = DATASETS.get(dataset)

    original_classes = dataset_class.CLASSES

    # Test setting classes as a tuple
    custom_dataset = dataset_class(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=classes,
        test_mode=True)

    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == classes

    # Test setting classes as a list
    custom_dataset = dataset_class(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=list(classes),
        test_mode=True)

    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == list(classes)

    # Test overriding not a subset
    custom_dataset = dataset_class(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=[classes[0]],
        test_mode=True)

    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == [classes[0]]

    # Test default behavior
    if dataset_class is CustomDataset:
        with pytest.raises(AssertionError):
            custom_dataset = dataset_class(
                pipeline=[],
                img_dir=MagicMock(),
                split=MagicMock(),
                classes=None,
                test_mode=True)
    else:
        custom_dataset = dataset_class(
            pipeline=[],
            img_dir=MagicMock(),
            split=MagicMock(),
            classes=None,
            test_mode=True)

        assert custom_dataset.CLASSES == original_classes


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_random_palette_is_generated():
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        test_mode=True)
    assert len(dataset.PALETTE) == 2
    for class_color in dataset.PALETTE:
        assert len(class_color) == 3
        assert all(x >= 0 and x <= 255 for x in class_color)


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_custom_palette():
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        palette=[[100, 100, 100], [200, 200, 200]],
        test_mode=True)
    assert tuple(dataset.PALETTE) == tuple([[100, 100, 100], [200, 200, 200]])