--- /dev/null
+++ b/tests/test_data/test_dataset.py
@@ -0,0 +1,746 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import shutil
+import tempfile
+from typing import Generator
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+
+from mmseg.core.evaluation import get_classes, get_palette
+from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
+                            ConcatDataset, CustomDataset, LoveDADataset,
+                            PascalVOCDataset, RepeatDataset, build_dataset)
+
+
+def test_classes():
+    assert list(CityscapesDataset.CLASSES) == get_classes('cityscapes')
+    assert list(PascalVOCDataset.CLASSES) == get_classes('voc') == get_classes(
+        'pascal_voc')
+    assert list(
+        ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k')
+
+    with pytest.raises(ValueError):
+        get_classes('unsupported')
+
+
+def test_classes_file_path():
+    tmp_file = tempfile.NamedTemporaryFile()
+    classes_path = f'{tmp_file.name}.txt'
+    train_pipeline = [dict(type='LoadImageFromFile')]
+    kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path)
+
+    # classes.txt with full categories
+    categories = get_classes('cityscapes')
+    with open(classes_path, 'w') as f:
+        f.write('\n'.join(categories))
+    assert list(CityscapesDataset(**kwargs).CLASSES) == categories
+
+    # classes.txt with sub categories
+    categories = ['road', 'sidewalk', 'building']
+    with open(classes_path, 'w') as f:
+        f.write('\n'.join(categories))
+    assert list(CityscapesDataset(**kwargs).CLASSES) == categories
+
+    # classes.txt with unknown categories
+    categories = ['road', 'sidewalk', 'unknown']
+    with open(classes_path, 'w') as f:
+        f.write('\n'.join(categories))
+
+    with pytest.raises(ValueError):
+        CityscapesDataset(**kwargs)
+
+    tmp_file.close()
+    os.remove(classes_path)
+    assert not osp.exists(classes_path)
+
+
+def test_palette():
+    assert CityscapesDataset.PALETTE == get_palette('cityscapes')
+    assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(
+        'pascal_voc')
+    assert ADE20KDataset.PALETTE == get_palette('ade') == get_palette('ade20k')
+
+    with pytest.raises(ValueError):
+        get_palette('unsupported')
+
+
+@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
+@patch('mmseg.datasets.CustomDataset.__getitem__',
+       MagicMock(side_effect=lambda idx: idx))
+def test_dataset_wrapper():
+    # CustomDataset.load_annotations = MagicMock()
+    # CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
+    dataset_a = CustomDataset(img_dir=MagicMock(), pipeline=[])
+    len_a = 10
+    dataset_a.img_infos = MagicMock()
+    dataset_a.img_infos.__len__.return_value = len_a
+    dataset_b = CustomDataset(img_dir=MagicMock(), pipeline=[])
+    len_b = 20
+    dataset_b.img_infos = MagicMock()
+    dataset_b.img_infos.__len__.return_value = len_b
+
+    concat_dataset = ConcatDataset([dataset_a, dataset_b])
+    assert concat_dataset[5] == 5
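+    # Indices past len(dataset_a) fall through to dataset_b: 25 - 10 = 15.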
+    assert concat_dataset[25] == 15
+    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
+
+    repeat_dataset = RepeatDataset(dataset_a, 10)
+    assert repeat_dataset[5] == 5
+    assert repeat_dataset[15] == 5
+    assert repeat_dataset[27] == 7
+    assert len(repeat_dataset) == 10 * len(dataset_a)
+
+
+def test_custom_dataset():
+    img_norm_cfg = dict(
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        to_rgb=True)
+    crop_size = (512, 1024)
+    train_pipeline = [
+        dict(type='LoadImageFromFile'),
+        dict(type='LoadAnnotations'),
+        dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
+        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+        dict(type='RandomFlip', prob=0.5),
+        dict(type='PhotoMetricDistortion'),
+        dict(type='Normalize', **img_norm_cfg),
+        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+        dict(type='DefaultFormatBundle'),
+        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+    ]
+    test_pipeline = [
+        dict(type='LoadImageFromFile'),
+        dict(
+            type='MultiScaleFlipAug',
+            img_scale=(128, 256),
+            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+            flip=False,
+            transforms=[
+                dict(type='Resize', keep_ratio=True),
+                dict(type='RandomFlip'),
+                dict(type='Normalize', **img_norm_cfg),
+                dict(type='ImageToTensor', keys=['img']),
+                dict(type='Collect', keys=['img']),
+            ])
+    ]
+
+    # with img_dir and ann_dir
+    train_dataset = CustomDataset(
+        train_pipeline,
+        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
+        img_dir='imgs/',
+        ann_dir='gts/',
+        img_suffix='img.jpg',
+        seg_map_suffix='gt.png')
+    assert len(train_dataset) == 5
+
+    # with img_dir, ann_dir and split
+    train_dataset = CustomDataset(
+        train_pipeline,
+        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
+        img_dir='imgs/',
+        ann_dir='gts/',
+        img_suffix='img.jpg',
+        seg_map_suffix='gt.png',
+        split='splits/train.txt')
+    assert len(train_dataset) == 4
+
+    # no data_root
+    train_dataset = CustomDataset(
+        train_pipeline,
+        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
+        ann_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts'),
+        img_suffix='img.jpg',
+        seg_map_suffix='gt.png')
+    assert len(train_dataset) == 5
+
+    # with data_root but img_dir/ann_dir are absolute paths
+    train_dataset = CustomDataset(
+        train_pipeline,
+        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
+        img_dir=osp.abspath(
+            osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
+        ann_dir=osp.abspath(
+            osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts')),
+        img_suffix='img.jpg',
+        seg_map_suffix='gt.png')
+    assert len(train_dataset) == 5
+
+    # test_mode=True
+    test_dataset = CustomDataset(
+        test_pipeline,
+        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
+        img_suffix='img.jpg',
+        test_mode=True,
+        classes=('pseudo_class', ))
+    assert len(test_dataset) == 5
+
+    # training data get
+    train_data = train_dataset[0]
+    assert isinstance(train_data, dict)
+
+    # test data get
+    test_data = test_dataset[0]
+    assert isinstance(test_data, dict)
+
+    # get gt seg map
+    gt_seg_maps = train_dataset.get_gt_seg_maps(efficient_test=True)
+    assert isinstance(gt_seg_maps, Generator)
+    gt_seg_maps = list(gt_seg_maps)
+    assert len(gt_seg_maps) == 5
+
+    # format_results not implemented
+    with pytest.raises(NotImplementedError):
+        test_dataset.format_results([], '')
+
+    pseudo_results = []
+    for gt_seg_map in gt_seg_maps:
+        h, w = gt_seg_map.shape
+        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
+
+    # test legacy evaluation without CLASSES
+    with pytest.raises(TypeError):
+        eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
+
+    with pytest.raises(TypeError):
+        eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
+
+    with pytest.raises(TypeError):
+        eval_results = train_dataset.evaluate(
+            pseudo_results, metric=['mDice', 'mIoU'])
+
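+    # CLASSES is still None here, so evaluate() cannot determine the number
+    # of classes; the 7 dummy class names below match the 0-6 label range.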
+    # test legacy evaluation with CLASSES
+    train_dataset.CLASSES = tuple(['a'] * 7)
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
+    assert isinstance(eval_results, dict)
+    assert 'mIoU' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
+    assert isinstance(eval_results, dict)
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
+    assert isinstance(eval_results, dict)
+    assert 'mRecall' in eval_results
+    assert 'mPrecision' in eval_results
+    assert 'mFscore' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(
+        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
+    assert isinstance(eval_results, dict)
+    assert 'mIoU' in eval_results
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+    assert 'mFscore' in eval_results
+    assert 'mPrecision' in eval_results
+    assert 'mRecall' in eval_results
+
+    assert not np.isnan(eval_results['mIoU'])
+    assert not np.isnan(eval_results['mDice'])
+    assert not np.isnan(eval_results['mAcc'])
+    assert not np.isnan(eval_results['aAcc'])
+    assert not np.isnan(eval_results['mFscore'])
+    assert not np.isnan(eval_results['mPrecision'])
+    assert not np.isnan(eval_results['mRecall'])
+
+    # test pre-eval based evaluation; dataset.CLASSES is required
+    train_dataset.CLASSES = tuple(['a'] * 7)
+    pseudo_results = []
+    for idx in range(len(train_dataset)):
+        h, w = gt_seg_maps[idx].shape
+        pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
+        pseudo_results.extend(train_dataset.pre_eval(pseudo_result, idx))
+    eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
+    assert isinstance(eval_results, dict)
+    assert 'mIoU' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
+    assert isinstance(eval_results, dict)
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
+    assert isinstance(eval_results, dict)
+    assert 'mRecall' in eval_results
+    assert 'mPrecision' in eval_results
+    assert 'mFscore' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(
+        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
+    assert isinstance(eval_results, dict)
+    assert 'mIoU' in eval_results
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+    assert 'mFscore' in eval_results
+    assert 'mPrecision' in eval_results
+    assert 'mRecall' in eval_results
+
+    assert not np.isnan(eval_results['mIoU'])
+    assert not np.isnan(eval_results['mDice'])
+    assert not np.isnan(eval_results['mAcc'])
+    assert not np.isnan(eval_results['aAcc'])
+    assert not np.isnan(eval_results['mFscore'])
+    assert not np.isnan(eval_results['mPrecision'])
+    assert not np.isnan(eval_results['mRecall'])
+
+
+@pytest.mark.parametrize('separate_eval', [True, False])
+def test_eval_concat_custom_dataset(separate_eval):
+    img_norm_cfg = dict(
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        to_rgb=True)
+    test_pipeline = [
+        dict(type='LoadImageFromFile'),
+        dict(
+            type='MultiScaleFlipAug',
+            img_scale=(128, 256),
+            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+            flip=False,
+            transforms=[
+                dict(type='Resize', keep_ratio=True),
+                dict(type='RandomFlip'),
+                dict(type='Normalize', **img_norm_cfg),
+                dict(type='ImageToTensor', keys=['img']),
+                dict(type='Collect', keys=['img']),
+            ])
+    ]
+    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
+    img_dir = 'imgs/'
+    ann_dir = 'gts/'
+
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=test_pipeline,
+        data_root=data_root,
+        img_dir=img_dir,
+        ann_dir=ann_dir,
+        img_suffix='img.jpg',
+        seg_map_suffix='gt.png',
+        classes=tuple(['a'] * 7))
+    dataset1 = build_dataset(cfg1)
+    assert len(dataset1) == 5
+    # get gt seg map
+    gt_seg_maps = dataset1.get_gt_seg_maps(efficient_test=True)
+    assert isinstance(gt_seg_maps, Generator)
+    gt_seg_maps = list(gt_seg_maps)
+    assert len(gt_seg_maps) == 5
+
+    # test legacy evaluation
+    pseudo_results = []
+    for gt_seg_map in gt_seg_maps:
+        h, w = gt_seg_map.shape
+        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
+    eval_results1 = dataset1.evaluate(
+        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
+
+    # We use the same dir twice for simplicity
+    # with ann_dir
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=test_pipeline,
+        data_root=data_root,
+        img_dir=[img_dir, img_dir],
+        ann_dir=[ann_dir, ann_dir],
+        img_suffix='img.jpg',
+        seg_map_suffix='gt.png',
+        classes=tuple(['a'] * 7),
+        separate_eval=separate_eval)
+    dataset2 = build_dataset(cfg2)
+    assert isinstance(dataset2, ConcatDataset)
+    assert len(dataset2) == 10
+
+    eval_results2 = dataset2.evaluate(
+        pseudo_results * 2, metric=['mIoU', 'mDice', 'mFscore'])
+
+    if separate_eval:
+        assert eval_results1['mIoU'] == eval_results2[
+            '0_mIoU'] == eval_results2['1_mIoU']
+        assert eval_results1['mDice'] == eval_results2[
+            '0_mDice'] == eval_results2['1_mDice']
+        assert eval_results1['mAcc'] == eval_results2[
+            '0_mAcc'] == eval_results2['1_mAcc']
+        assert eval_results1['aAcc'] == eval_results2[
+            '0_aAcc'] == eval_results2['1_aAcc']
+        assert eval_results1['mFscore'] == eval_results2[
+            '0_mFscore'] == eval_results2['1_mFscore']
+        assert eval_results1['mPrecision'] == eval_results2[
+            '0_mPrecision'] == eval_results2['1_mPrecision']
+        assert eval_results1['mRecall'] == eval_results2[
+            '0_mRecall'] == eval_results2['1_mRecall']
+    else:
+        assert eval_results1['mIoU'] == eval_results2['mIoU']
+        assert eval_results1['mDice'] == eval_results2['mDice']
+        assert eval_results1['mAcc'] == eval_results2['mAcc']
+        assert eval_results1['aAcc'] == eval_results2['aAcc']
+        assert eval_results1['mFscore'] == eval_results2['mFscore']
+        assert eval_results1['mPrecision'] == eval_results2['mPrecision']
+        assert eval_results1['mRecall'] == eval_results2['mRecall']
+
+    # test get dataset_idx and sample_idx from ConcatDataset
+    dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(3)
+    assert dataset_idx == 0
+    assert sample_idx == 3
+
+    dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(7)
+    assert dataset_idx == 1
+    assert sample_idx == 2
+
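+    # Negative indices count from the end: -7 maps to 10 - 7 = 3, which is
+    # sample 3 of the first dataset.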
+    dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-7)
+    assert dataset_idx == 0
+    assert sample_idx == 3
+
+    # test a negative index that exceeds the length of the dataset
+    with pytest.raises(ValueError):
+        dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-11)
+
+    # test that a negative index and its positive counterpart agree
+    index = -6
+    dataset_idx1, sample_idx1 = dataset2.get_dataset_idx_and_sample_idx(index)
+    dataset_idx2, sample_idx2 = dataset2.get_dataset_idx_and_sample_idx(
+        len(dataset2) + index)
+    assert dataset_idx1 == dataset_idx2
+    assert sample_idx1 == sample_idx2
+
+    # test pre-eval based evaluation; dataset.CLASSES is required
+    pseudo_results = []
+    eval_results1 = []
+    for idx in range(len(dataset1)):
+        h, w = gt_seg_maps[idx].shape
+        pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
+        pseudo_results.append(pseudo_result)
+        eval_results1.extend(dataset1.pre_eval(pseudo_result, idx))
+
+    assert len(eval_results1) == len(dataset1)
+    assert isinstance(eval_results1[0], tuple)
+    assert len(eval_results1[0]) == 4
+    assert isinstance(eval_results1[0][0], torch.Tensor)
+
+    eval_results1 = dataset1.evaluate(
+        eval_results1, metric=['mIoU', 'mDice', 'mFscore'])
+
+    pseudo_results = pseudo_results * 2
+    eval_results2 = []
+    for idx in range(len(dataset2)):
+        eval_results2.extend(dataset2.pre_eval(pseudo_results[idx], idx))
+
+    assert len(eval_results2) == len(dataset2)
+    assert isinstance(eval_results2[0], tuple)
+    assert len(eval_results2[0]) == 4
+    assert isinstance(eval_results2[0][0], torch.Tensor)
+
+    eval_results2 = dataset2.evaluate(
+        eval_results2, metric=['mIoU', 'mDice', 'mFscore'])
+
+    if separate_eval:
+        assert eval_results1['mIoU'] == eval_results2[
+            '0_mIoU'] == eval_results2['1_mIoU']
+        assert eval_results1['mDice'] == eval_results2[
+            '0_mDice'] == eval_results2['1_mDice']
+        assert eval_results1['mAcc'] == eval_results2[
+            '0_mAcc'] == eval_results2['1_mAcc']
+        assert eval_results1['aAcc'] == eval_results2[
+            '0_aAcc'] == eval_results2['1_aAcc']
+        assert eval_results1['mFscore'] == eval_results2[
+            '0_mFscore'] == eval_results2['1_mFscore']
+        assert eval_results1['mPrecision'] == eval_results2[
+            '0_mPrecision'] == eval_results2['1_mPrecision']
+        assert eval_results1['mRecall'] == eval_results2[
+            '0_mRecall'] == eval_results2['1_mRecall']
+    else:
+        assert eval_results1['mIoU'] == eval_results2['mIoU']
+        assert eval_results1['mDice'] == eval_results2['mDice']
+        assert eval_results1['mAcc'] == eval_results2['mAcc']
+        assert eval_results1['aAcc'] == eval_results2['aAcc']
+        assert eval_results1['mFscore'] == eval_results2['mFscore']
+        assert eval_results1['mPrecision'] == eval_results2['mPrecision']
+        assert eval_results1['mRecall'] == eval_results2['mRecall']
+
+    # test batch_indices for pre-eval
+    eval_results2 = dataset2.pre_eval(pseudo_results,
+                                      list(range(len(pseudo_results))))
+
+    assert len(eval_results2) == len(dataset2)
+    assert isinstance(eval_results2[0], tuple)
+    assert len(eval_results2[0]) == 4
+    assert isinstance(eval_results2[0][0], torch.Tensor)
+
+    eval_results2 = dataset2.evaluate(
+        eval_results2, metric=['mIoU', 'mDice', 'mFscore'])
+
+    if separate_eval:
+        assert eval_results1['mIoU'] == eval_results2[
+            '0_mIoU'] == eval_results2['1_mIoU']
+        assert eval_results1['mDice'] == eval_results2[
+            '0_mDice'] == eval_results2['1_mDice']
+        assert eval_results1['mAcc'] == eval_results2[
+            '0_mAcc'] == eval_results2['1_mAcc']
+        assert eval_results1['aAcc'] == eval_results2[
+            '0_aAcc'] == eval_results2['1_aAcc']
+        assert eval_results1['mFscore'] == eval_results2[
+            '0_mFscore'] == eval_results2['1_mFscore']
+        assert eval_results1['mPrecision'] == eval_results2[
+            '0_mPrecision'] == eval_results2['1_mPrecision']
+        assert eval_results1['mRecall'] == eval_results2[
+            '0_mRecall'] == eval_results2['1_mRecall']
+    else:
+        assert eval_results1['mIoU'] == eval_results2['mIoU']
+        assert eval_results1['mDice'] == eval_results2['mDice']
+        assert eval_results1['mAcc'] == eval_results2['mAcc']
+        assert eval_results1['aAcc'] == eval_results2['aAcc']
+        assert eval_results1['mFscore'] == eval_results2['mFscore']
+        assert eval_results1['mPrecision'] == eval_results2['mPrecision']
+        assert eval_results1['mRecall'] == eval_results2['mRecall']
+
+
+def test_ade():
+    test_dataset = ADE20KDataset(
+        pipeline=[],
+        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
+    assert len(test_dataset) == 5
+
+    # Test format_results
+    pseudo_results = []
+    for _ in range(len(test_dataset)):
+        h, w = (2, 2)
+        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
+
+    file_paths = test_dataset.format_results(pseudo_results, '.format_ade')
+    assert len(file_paths) == len(test_dataset)
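+    # ADE20K submissions are 1-based (label 0 is reserved), so the saved
+    # files hold prediction + 1.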
+    temp = np.array(Image.open(file_paths[0]))
+    assert np.allclose(temp, pseudo_results[0] + 1)
+
+    shutil.rmtree('.format_ade')
+
+
+@pytest.mark.parametrize('separate_eval', [True, False])
+def test_concat_ade(separate_eval):
+    test_dataset = ADE20KDataset(
+        pipeline=[],
+        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
+    assert len(test_dataset) == 5
+
+    concat_dataset = ConcatDataset([test_dataset, test_dataset],
+                                   separate_eval=separate_eval)
+    assert len(concat_dataset) == 10
+    # Test format_results
+    pseudo_results = []
+    for _ in range(len(concat_dataset)):
+        h, w = (2, 2)
+        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
+
+    # test format per image
+    file_paths = []
+    for i in range(len(pseudo_results)):
+        file_paths.extend(
+            concat_dataset.format_results([pseudo_results[i]],
+                                          '.format_ade',
+                                          indices=[i]))
+    assert len(file_paths) == len(concat_dataset)
+    temp = np.array(Image.open(file_paths[0]))
+    assert np.allclose(temp, pseudo_results[0] + 1)
+
+    shutil.rmtree('.format_ade')
+
+    # test default argument
+    file_paths = concat_dataset.format_results(pseudo_results, '.format_ade')
+    assert len(file_paths) == len(concat_dataset)
+    temp = np.array(Image.open(file_paths[0]))
+    assert np.allclose(temp, pseudo_results[0] + 1)
+
+    shutil.rmtree('.format_ade')
+
+
+def test_cityscapes():
+    test_dataset = CityscapesDataset(
+        pipeline=[],
+        img_dir=osp.join(
+            osp.dirname(__file__),
+            '../data/pseudo_cityscapes_dataset/leftImg8bit'),
+        ann_dir=osp.join(
+            osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
+    assert len(test_dataset) == 1
+
+    gt_seg_maps = list(test_dataset.get_gt_seg_maps())
+
+    # Test format_results
+    pseudo_results = []
+    for idx in range(len(test_dataset)):
+        h, w = gt_seg_maps[idx].shape
+        pseudo_results.append(np.random.randint(low=0, high=19, size=(h, w)))
+
+    file_paths = test_dataset.format_results(pseudo_results, '.format_city')
+    assert len(file_paths) == len(test_dataset)
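+    # Cityscapes submissions use label IDs rather than train IDs, so
+    # format_results converts predictions via _convert_to_label_id.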
+    temp = np.array(Image.open(file_paths[0]))
+    assert np.allclose(temp,
+                       test_dataset._convert_to_label_id(pseudo_results[0]))
+
+    # Test cityscapes evaluate
+    test_dataset.evaluate(
+        pseudo_results, metric='cityscapes', imgfile_prefix='.format_city')
+
+    shutil.rmtree('.format_city')
+
+
+@pytest.mark.parametrize('separate_eval', [True, False])
+def test_concat_cityscapes(separate_eval):
+    cityscape_dataset = CityscapesDataset(
+        pipeline=[],
+        img_dir=osp.join(
+            osp.dirname(__file__),
+            '../data/pseudo_cityscapes_dataset/leftImg8bit'),
+        ann_dir=osp.join(
+            osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
+    assert len(cityscape_dataset) == 1
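+    # ConcatDataset does not support evaluation with CityscapesDataset
+    # members, so construction is expected to fail.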
+    with pytest.raises(NotImplementedError):
+        _ = ConcatDataset([cityscape_dataset, cityscape_dataset],
+                          separate_eval=separate_eval)
+    ade_dataset = ADE20KDataset(
+        pipeline=[],
+        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
+    assert len(ade_dataset) == 5
+    with pytest.raises(NotImplementedError):
+        _ = ConcatDataset([cityscape_dataset, ade_dataset],
+                          separate_eval=separate_eval)
+
+
+def test_loveda():
+    test_dataset = LoveDADataset(
+        pipeline=[],
+        img_dir=osp.join(
+            osp.dirname(__file__), '../data/pseudo_loveda_dataset/img_dir'),
+        ann_dir=osp.join(
+            osp.dirname(__file__), '../data/pseudo_loveda_dataset/ann_dir'))
+    assert len(test_dataset) == 3
+
+    gt_seg_maps = list(test_dataset.get_gt_seg_maps())
+
+    # Test format_results
+    pseudo_results = []
+    for idx in range(len(test_dataset)):
+        h, w = gt_seg_maps[idx].shape
+        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
+    file_paths = test_dataset.format_results(pseudo_results, '.format_loveda')
+    assert len(file_paths) == len(test_dataset)
+
+    # Test loveda evaluate
+    test_dataset.evaluate(
+        pseudo_results, metric='mIoU', imgfile_prefix='.format_loveda')
+
+    shutil.rmtree('.format_loveda')
+
+
+@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
+@patch('mmseg.datasets.CustomDataset.__getitem__',
+       MagicMock(side_effect=lambda idx: idx))
+@pytest.mark.parametrize('dataset, classes', [
+    ('ADE20KDataset', ('wall', 'building')),
+    ('CityscapesDataset', ('road', 'sidewalk')),
+    ('CustomDataset', ('bus', 'car')),
+    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
+])
+def test_custom_classes_override_default(dataset, classes):
+
+    dataset_class = DATASETS.get(dataset)
+
+    original_classes = dataset_class.CLASSES
+
+    # Test setting classes as a tuple
+    custom_dataset = dataset_class(
+        pipeline=[],
+        img_dir=MagicMock(),
+        split=MagicMock(),
+        classes=classes,
+        test_mode=True)
+
+    assert custom_dataset.CLASSES != original_classes
+    assert custom_dataset.CLASSES == classes
+
+    # Test setting classes as a list
+    custom_dataset = dataset_class(
+        pipeline=[],
+        img_dir=MagicMock(),
+        split=MagicMock(),
+        classes=list(classes),
+        test_mode=True)
+
+    assert custom_dataset.CLASSES != original_classes
+    assert custom_dataset.CLASSES == list(classes)
+
+    # Test overriding with a single class
+    custom_dataset = dataset_class(
+        pipeline=[],
+        img_dir=MagicMock(),
+        split=MagicMock(),
+        classes=[classes[0]],
+        test_mode=True)
+
+    assert custom_dataset.CLASSES != original_classes
+    assert custom_dataset.CLASSES == [classes[0]]
+
+    # Test default behavior (CustomDataset defines no default CLASSES, so
+    # classes=None raises an AssertionError in test_mode)
+    if dataset_class is CustomDataset:
+        with pytest.raises(AssertionError):
+            custom_dataset = dataset_class(
+                pipeline=[],
+                img_dir=MagicMock(),
+                split=MagicMock(),
+                classes=None,
+                test_mode=True)
+    else:
+        custom_dataset = dataset_class(
+            pipeline=[],
+            img_dir=MagicMock(),
+            split=MagicMock(),
+            classes=None,
+            test_mode=True)
+
+        assert custom_dataset.CLASSES == original_classes
+
+
+@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
+@patch('mmseg.datasets.CustomDataset.__getitem__',
+       MagicMock(side_effect=lambda idx: idx))
+def test_custom_dataset_random_palette_is_generated():
+    dataset = CustomDataset(
+        pipeline=[],
+        img_dir=MagicMock(),
+        split=MagicMock(),
+        classes=('bus', 'car'),
+        test_mode=True)
+    assert len(dataset.PALETTE) == 2
+    for class_color in dataset.PALETTE:
+        assert len(class_color) == 3
+        assert all(0 <= x <= 255 for x in class_color)
+
+
+@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
+@patch('mmseg.datasets.CustomDataset.__getitem__',
+       MagicMock(side_effect=lambda idx: idx))
+def test_custom_dataset_custom_palette():
+    dataset = CustomDataset(
+        pipeline=[],
+        img_dir=MagicMock(),
+        split=MagicMock(),
+        classes=('bus', 'car'),
+        palette=[[100, 100, 100], [200, 200, 200]],
+        test_mode=True)
+    assert dataset.PALETTE == [[100, 100, 100], [200, 200, 200]]