Diff of /tools/browse_dataset.py [000000] .. [4e96d3]

Switch to unified view

a b/tools/browse_dataset.py
1
import argparse
2
import os
3
import warnings
4
from pathlib import Path
5
6
import mmcv
7
import numpy as np
8
from mmcv import Config
9
10
from mmseg.datasets.builder import build_dataset
11
12
13
def parse_args():
14
    parser = argparse.ArgumentParser(description='Browse a dataset')
15
    parser.add_argument('config', help='train config file path')
16
    parser.add_argument(
17
        '--show-origin',
18
        default=False,
19
        action='store_true',
20
        help='if True, omit all augmentation in pipeline,'
21
        ' show origin image and seg map')
22
    parser.add_argument(
23
        '--skip-type',
24
        type=str,
25
        nargs='+',
26
        default=['DefaultFormatBundle', 'Normalize', 'Collect'],
27
        help='skip some useless pipeline,if `show-origin` is true, '
28
        'all pipeline except `Load` will be skipped')
29
    parser.add_argument(
30
        '--output-dir',
31
        default='./output',
32
        type=str,
33
        help='If there is no display interface, you can save it')
34
    parser.add_argument('--show', default=False, action='store_true')
35
    parser.add_argument(
36
        '--show-interval',
37
        type=int,
38
        default=999,
39
        help='the interval of show (ms)')
40
    parser.add_argument(
41
        '--opacity',
42
        type=float,
43
        default=0.5,
44
        help='the opacity of semantic map')
45
    args = parser.parse_args()
46
    return args
47
48
49
def imshow_semantic(img,
50
                    seg,
51
                    class_names,
52
                    palette=None,
53
                    win_name='',
54
                    show=False,
55
                    wait_time=0,
56
                    out_file=None,
57
                    opacity=0.5):
58
    """Draw `result` over `img`.
59
60
    Args:
61
        img (str or Tensor): The image to be displayed.
62
        seg (Tensor): The semantic segmentation results to draw over
63
            `img`.
64
        class_names (list[str]): Names of each classes.
65
        palette (list[list[int]]] | np.ndarray | None): The palette of
66
            segmentation map. If None is given, random palette will be
67
            generated. Default: None
68
        win_name (str): The window name.
69
        wait_time (int): Value of waitKey param.
70
            Default: 0.
71
        show (bool): Whether to show the image.
72
            Default: False.
73
        out_file (str or None): The filename to write the image.
74
            Default: None.
75
        opacity(float): Opacity of painted segmentation map.
76
            Default 0.5.
77
            Must be in (0, 1] range.
78
    Returns:
79
        img (Tensor): Only if not `show` or `out_file`
80
    """
81
    img = mmcv.imread(img)
82
    img = img.copy()
83
    if palette is None:
84
        palette = np.random.randint(0, 255, size=(len(class_names), 3))
85
    palette = np.array(palette)
86
    assert palette.shape[0] == len(class_names)
87
    assert palette.shape[1] == 3
88
    assert len(palette.shape) == 2
89
    assert 0 < opacity <= 1.0
90
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
91
    for label, color in enumerate(palette):
92
        color_seg[seg == label, :] = color
93
    # convert to BGR
94
    color_seg = color_seg[..., ::-1]
95
96
    img = img * (1 - opacity) + color_seg * opacity
97
    img = img.astype(np.uint8)
98
    # if out_file specified, do not show image in window
99
    if out_file is not None:
100
        show = False
101
102
    if show:
103
        mmcv.imshow(img, win_name, wait_time)
104
    if out_file is not None:
105
        mmcv.imwrite(img, out_file)
106
107
    if not (show or out_file):
108
        warnings.warn('show==False and out_file is not specified, only '
109
                      'result image will be returned')
110
        return img
111
112
113
def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
114
    if show_origin is True:
115
        # only keep pipeline of Loading data and ann
116
        _data_cfg['pipeline'] = [
117
            x for x in _data_cfg.pipeline if 'Load' in x['type']
118
        ]
119
    else:
120
        _data_cfg['pipeline'] = [
121
            x for x in _data_cfg.pipeline if x['type'] not in skip_type
122
        ]
123
124
125
def retrieve_data_cfg(config_path, skip_type, show_origin=False):
126
    cfg = Config.fromfile(config_path)
127
    train_data_cfg = cfg.data.train
128
    if isinstance(train_data_cfg, list):
129
        for _data_cfg in train_data_cfg:
130
            if 'pipeline' in _data_cfg:
131
                _retrieve_data_cfg(_data_cfg, skip_type, show_origin)
132
            elif 'dataset' in _data_cfg:
133
                _retrieve_data_cfg(_data_cfg['dataset'], skip_type,
134
                                   show_origin)
135
            else:
136
                raise ValueError
137
    elif 'dataset' in train_data_cfg:
138
        _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
139
    else:
140
        _retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
141
    return cfg
142
143
144
def main():
145
    args = parse_args()
146
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
147
    dataset = build_dataset(cfg.data.train)
148
    progress_bar = mmcv.ProgressBar(len(dataset))
149
    for item in dataset:
150
        filename = os.path.join(args.output_dir,
151
                                Path(item['filename']).name
152
                                ) if args.output_dir is not None else None
153
        imshow_semantic(
154
            item['img'],
155
            item['gt_semantic_seg'],
156
            dataset.CLASSES,
157
            dataset.PALETTE,
158
            show=args.show,
159
            wait_time=args.show_interval,
160
            out_file=filename,
161
            opacity=args.opacity,
162
        )
163
        progress_bar.update()
164
165
166
if __name__ == '__main__':
167
    main()