|
tests/test_config.py

# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
from os.path import dirname, exists, isdir, join, relpath

from mmcv import Config
from torch import nn

from mmseg.models import build_segmentor


def _get_config_directory():
    """Find the predefined segmentor config directory."""
    try:
        # Assume we are running in the source mmsegmentation repo
        repo_dpath = dirname(dirname(__file__))
    except NameError:
        # For IPython development when this __file__ is not defined
        import mmseg
        repo_dpath = dirname(dirname(mmseg.__file__))
    config_dpath = join(repo_dpath, 'configs')
    if not exists(config_dpath):
        raise Exception('Cannot find config path')
    return config_dpath


def test_config_build_segmentor():
    """Test that all segmentation models defined in the configs can be
    initialized."""
    config_dpath = _get_config_directory()
    print('Found config_dpath = {!r}'.format(config_dpath))

    config_fpaths = []
    # one config from each sub folder
    for sub_folder in os.listdir(config_dpath):
        if isdir(join(config_dpath, sub_folder)):
            config_fpaths.append(
                list(glob.glob(join(config_dpath, sub_folder, '*.py')))[0])
    config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
    config_names = [relpath(p, config_dpath) for p in config_fpaths]

    print('Using {} config files'.format(len(config_names)))

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        config_mod = Config.fromfile(config_fpath)

        # accessing `model` raises if the config does not define one
        config_mod.model
        print('Building segmentor, config_fpath = {!r}'.format(config_fpath))

        # remove pretrained keys to allow testing in an offline environment
        if 'pretrained' in config_mod.model:
            config_mod.model['pretrained'] = None

        print('building {}'.format(config_fname))
        segmentor = build_segmentor(config_mod.model)
        assert segmentor is not None

        head_config = config_mod.model['decode_head']
        _check_decode_head(head_config, segmentor.decode_head)


def test_config_data_pipeline():
    """Test whether the data pipeline is valid and can process corner cases.

    CommandLine:
        xdoctest -m tests/test_config.py test_config_data_pipeline
    """
    import numpy as np

    from mmseg.datasets.pipelines import Compose

    config_dpath = _get_config_directory()
    print('Found config_dpath = {!r}'.format(config_dpath))

    # '**' without recursive=True matches exactly one directory level,
    # i.e. configs/<sub_folder>/*.py
    config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
    config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
    config_names = [relpath(p, config_dpath) for p in config_fpaths]

    print('Using {} config files'.format(len(config_names)))

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        print(
            'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
        config_mod = Config.fromfile(config_fpath)

        # remove the image and annotation loading steps; the fake inputs
        # built below are fed into the pipelines directly
        load_img_pipeline = config_mod.train_pipeline.pop(0)
        to_float32 = load_img_pipeline.get('to_float32', False)
        config_mod.train_pipeline.pop(0)
        config_mod.test_pipeline.pop(0)

        train_pipeline = Compose(config_mod.train_pipeline)
        test_pipeline = Compose(config_mod.test_pipeline)

        img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
        if to_float32:
            img = img.astype(np.float32)
        seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)

        results = dict(
            filename='test_img.png',
            ori_filename='test_img.png',
            img=img,
            img_shape=img.shape,
            ori_shape=img.shape,
            gt_semantic_seg=seg)
        results['seg_fields'] = ['gt_semantic_seg']

        print('Test training data pipeline: \n{!r}'.format(train_pipeline))
        output_results = train_pipeline(results)
        assert output_results is not None

        results = dict(
            filename='test_img.png',
            ori_filename='test_img.png',
            img=img,
            img_shape=img.shape,
            ori_shape=img.shape,
        )
        print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
        output_results = test_pipeline(results)
        assert output_results is not None


def _check_decode_head(decode_head_cfg, decode_head):
    if isinstance(decode_head_cfg, list):
        assert isinstance(decode_head, nn.ModuleList)
        assert len(decode_head_cfg) == len(decode_head)
        num_heads = len(decode_head)
        for i in range(num_heads):
            _check_decode_head(decode_head_cfg[i], decode_head[i])
        return
    # check consistency between decode_head_cfg and decode_head
    assert decode_head_cfg['type'] == decode_head.__class__.__name__

    in_channels = decode_head_cfg.in_channels
    input_transform = decode_head.input_transform
    assert input_transform in ['resize_concat', 'multiple_select', None]
    if input_transform is not None:
        assert isinstance(in_channels, (list, tuple))
        assert isinstance(decode_head.in_index, (list, tuple))
        assert len(in_channels) == len(decode_head.in_index)
        # with 'resize_concat' the head concatenates the selected inputs,
        # so its in_channels is the sum of the configured channels
        if input_transform == 'resize_concat':
            assert sum(in_channels) == decode_head.in_channels
    else:
        assert isinstance(in_channels, int)
        assert in_channels == decode_head.in_channels
        assert isinstance(decode_head.in_index, int)

    if decode_head_cfg['type'] == 'PointHead':
        assert decode_head_cfg.channels + decode_head_cfg.num_classes == \
            decode_head.fc_seg.in_channels
        assert decode_head.fc_seg.out_channels == decode_head_cfg.num_classes
    else:
        assert decode_head_cfg.channels == decode_head.conv_seg.in_channels
        assert decode_head.conv_seg.out_channels == \
            decode_head_cfg.num_classes
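
# Note: assuming pytest is installed and the mmsegmentation repo root is the
# working directory, the tests above can be run with
#     pytest tests/test_config.py -s
# where -s disables pytest's output capture so the print() progress lines
# are visible while each config is built.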