Diff of /tests/test_config.py [000000] .. [4e96d3]

Switch to unified view

a b/tests/test_config.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import glob
3
import os
4
from os.path import dirname, exists, isdir, join, relpath
5
6
from mmcv import Config
7
from torch import nn
8
9
from mmseg.models import build_segmentor
10
11
12
def _get_config_directory():
13
    """Find the predefined segmentor config directory."""
14
    try:
15
        # Assume we are running in the source mmsegmentation repo
16
        repo_dpath = dirname(dirname(__file__))
17
    except NameError:
18
        # For IPython development when this __file__ is not defined
19
        import mmseg
20
        repo_dpath = dirname(dirname(mmseg.__file__))
21
    config_dpath = join(repo_dpath, 'configs')
22
    if not exists(config_dpath):
23
        raise Exception('Cannot find config path')
24
    return config_dpath
25
26
27
def test_config_build_segmentor():
28
    """Test that all segmentation models defined in the configs can be
29
    initialized."""
30
    config_dpath = _get_config_directory()
31
    print('Found config_dpath = {!r}'.format(config_dpath))
32
33
    config_fpaths = []
34
    # one config each sub folder
35
    for sub_folder in os.listdir(config_dpath):
36
        if isdir(sub_folder):
37
            config_fpaths.append(
38
                list(glob.glob(join(config_dpath, sub_folder, '*.py')))[0])
39
    config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
40
    config_names = [relpath(p, config_dpath) for p in config_fpaths]
41
42
    print('Using {} config files'.format(len(config_names)))
43
44
    for config_fname in config_names:
45
        config_fpath = join(config_dpath, config_fname)
46
        config_mod = Config.fromfile(config_fpath)
47
48
        config_mod.model
49
        print('Building segmentor, config_fpath = {!r}'.format(config_fpath))
50
51
        # Remove pretrained keys to allow for testing in an offline environment
52
        if 'pretrained' in config_mod.model:
53
            config_mod.model['pretrained'] = None
54
55
        print('building {}'.format(config_fname))
56
        segmentor = build_segmentor(config_mod.model)
57
        assert segmentor is not None
58
59
        head_config = config_mod.model['decode_head']
60
        _check_decode_head(head_config, segmentor.decode_head)
61
62
63
def test_config_data_pipeline():
64
    """Test whether the data pipeline is valid and can process corner cases.
65
66
    CommandLine:
67
        xdoctest -m tests/test_config.py test_config_build_data_pipeline
68
    """
69
    from mmcv import Config
70
    from mmseg.datasets.pipelines import Compose
71
    import numpy as np
72
73
    config_dpath = _get_config_directory()
74
    print('Found config_dpath = {!r}'.format(config_dpath))
75
76
    import glob
77
    config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
78
    config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
79
    config_names = [relpath(p, config_dpath) for p in config_fpaths]
80
81
    print('Using {} config files'.format(len(config_names)))
82
83
    for config_fname in config_names:
84
        config_fpath = join(config_dpath, config_fname)
85
        print(
86
            'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
87
        config_mod = Config.fromfile(config_fpath)
88
89
        # remove loading pipeline
90
        load_img_pipeline = config_mod.train_pipeline.pop(0)
91
        to_float32 = load_img_pipeline.get('to_float32', False)
92
        config_mod.train_pipeline.pop(0)
93
        config_mod.test_pipeline.pop(0)
94
95
        train_pipeline = Compose(config_mod.train_pipeline)
96
        test_pipeline = Compose(config_mod.test_pipeline)
97
98
        img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
99
        if to_float32:
100
            img = img.astype(np.float32)
101
        seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
102
103
        results = dict(
104
            filename='test_img.png',
105
            ori_filename='test_img.png',
106
            img=img,
107
            img_shape=img.shape,
108
            ori_shape=img.shape,
109
            gt_semantic_seg=seg)
110
        results['seg_fields'] = ['gt_semantic_seg']
111
112
        print('Test training data pipeline: \n{!r}'.format(train_pipeline))
113
        output_results = train_pipeline(results)
114
        assert output_results is not None
115
116
        results = dict(
117
            filename='test_img.png',
118
            ori_filename='test_img.png',
119
            img=img,
120
            img_shape=img.shape,
121
            ori_shape=img.shape,
122
        )
123
        print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
124
        output_results = test_pipeline(results)
125
        assert output_results is not None
126
127
128
def _check_decode_head(decode_head_cfg, decode_head):
129
    if isinstance(decode_head_cfg, list):
130
        assert isinstance(decode_head, nn.ModuleList)
131
        assert len(decode_head_cfg) == len(decode_head)
132
        num_heads = len(decode_head)
133
        for i in range(num_heads):
134
            _check_decode_head(decode_head_cfg[i], decode_head[i])
135
        return
136
    # check consistency between head_config and roi_head
137
    assert decode_head_cfg['type'] == decode_head.__class__.__name__
138
139
    assert decode_head_cfg['type'] == decode_head.__class__.__name__
140
141
    in_channels = decode_head_cfg.in_channels
142
    input_transform = decode_head.input_transform
143
    assert input_transform in ['resize_concat', 'multiple_select', None]
144
    if input_transform is not None:
145
        assert isinstance(in_channels, (list, tuple))
146
        assert isinstance(decode_head.in_index, (list, tuple))
147
        assert len(in_channels) == len(decode_head.in_index)
148
    elif input_transform == 'resize_concat':
149
        assert sum(in_channels) == decode_head.in_channels
150
    else:
151
        assert isinstance(in_channels, int)
152
        assert in_channels == decode_head.in_channels
153
        assert isinstance(decode_head.in_index, int)
154
155
    if decode_head_cfg['type'] == 'PointHead':
156
        assert decode_head_cfg.channels+decode_head_cfg.num_classes == \
157
               decode_head.fc_seg.in_channels
158
        assert decode_head.fc_seg.out_channels == decode_head_cfg.num_classes
159
    else:
160
        assert decode_head_cfg.channels == decode_head.conv_seg.in_channels
161
        assert decode_head.conv_seg.out_channels == decode_head_cfg.num_classes