Diff of /tools/deploy_test.py [000000] .. [4e96d3]

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import shutil
import warnings
from typing import Any, Iterable

import mmcv
import numpy as np
import torch
from mmcv.parallel import MMDataParallel
from mmcv.runner import get_dist_info
from mmcv.utils import DictAction

from mmseg.apis import single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models.segmentors.base import BaseSegmentor
from mmseg.ops import resize


class ONNXRuntimeSegmentor(BaseSegmentor):
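    """Segmentor that runs an exported ONNX model with ONNX Runtime."""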

    def __init__(self, onnx_file: str, cfg: Any, device_id: int):
        super(ONNXRuntimeSegmentor, self).__init__()
        import onnxruntime as ort

        # get the custom op path
        ort_custom_op_path = ''
        try:
            from mmcv.ops import get_onnxruntime_op_path
            ort_custom_op_path = get_onnxruntime_op_path()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, '
                          'you may have to build mmcv with ONNXRuntime '
                          'from source.')
        session_options = ort.SessionOptions()
        # register custom op for onnxruntime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        providers = ['CPUExecutionProvider']
        options = [{}]
        is_cuda_available = ort.get_device() == 'GPU'
        if is_cuda_available:
            providers.insert(0, 'CUDAExecutionProvider')
            options.insert(0, {'device_id': device_id})

        sess.set_providers(providers, options)

        self.sess = sess
        self.device_id = device_id
        self.io_binding = sess.io_binding()
        self.output_names = [_.name for _ in sess.get_outputs()]
        for name in self.output_names:
            self.io_binding.bind_output(name)
        self.cfg = cfg
        self.test_mode = cfg.model.test_cfg.mode
        self.is_cuda_available = is_cuda_available

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def encode_decode(self, img, img_metas):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self, img: torch.Tensor, img_meta: Iterable,
                    **kwargs) -> list:
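        """Run inference with ONNX Runtime and return the seg map(s)."""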
        if not self.is_cuda_available:
            img = img.detach().cpu()
        elif self.device_id >= 0:
            img = img.cuda(self.device_id)
        device_type = img.device.type
        self.io_binding.bind_input(
            name='input',
            device_type=device_type,
            device_id=self.device_id,
            element_type=np.float32,
            shape=img.shape,
            buffer_ptr=img.data_ptr())
        self.sess.run_with_iobinding(self.io_binding)
        seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
        # whole might support dynamic reshape
        ori_shape = img_meta[0]['ori_shape']
        if not (ori_shape[0] == seg_pred.shape[-2]
                and ori_shape[1] == seg_pred.shape[-1]):
            seg_pred = torch.from_numpy(seg_pred).float()
            seg_pred = resize(
                seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
            seg_pred = seg_pred.long().detach().cpu().numpy()
        seg_pred = seg_pred[0]
        seg_pred = list(seg_pred)
        return seg_pred

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')


class TensorRTSegmentor(BaseSegmentor):
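    """Segmentor that runs a serialized TensorRT engine through mmcv."""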

    def __init__(self, trt_file: str, cfg: Any, device_id: int):
        super(TensorRTSegmentor, self).__init__()
        from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom op from mmcv, '
                          'you may have to build mmcv with TensorRT '
                          'from source.')
        model = TRTWraper(
            trt_file, input_names=['input'], output_names=['output'])

        self.model = model
        self.device_id = device_id
        self.cfg = cfg
        self.test_mode = cfg.model.test_cfg.mode

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def encode_decode(self, img, img_metas):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self, img: torch.Tensor, img_meta: Iterable,
                    **kwargs) -> list:
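        """Run the TensorRT engine and return the predicted seg map(s)."""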
        with torch.cuda.device(self.device_id), torch.no_grad():
            seg_pred = self.model({'input': img})['output']
        seg_pred = seg_pred.detach().cpu().numpy()
        # whole might support dynamic reshape
        ori_shape = img_meta[0]['ori_shape']
        if not (ori_shape[0] == seg_pred.shape[-2]
                and ori_shape[1] == seg_pred.shape[-1]):
            seg_pred = torch.from_numpy(seg_pred).float()
            seg_pred = resize(
                seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
            seg_pred = seg_pred.long().detach().cpu().numpy()
        seg_pred = seg_pred[0]
        seg_pred = list(seg_pred)
        return seg_pred

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')


def parse_args() -> argparse.Namespace:
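    """Parse command line arguments."""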
    parser = argparse.ArgumentParser(
        description='mmseg backend test (and eval)')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('model', help='Input model file')
    parser.add_argument(
        '--backend',
        help='Backend of the model.',
        choices=['onnxruntime', 'tensorrt'])
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It '
        'is useful when you want to format the result to a specific format '
        'and submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depend on the dataset, e.g., "mIoU"'
        ' for generic datasets, and "cityscapes" for Cityscapes')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='--options is deprecated in favor of --cfg-options and it will '
        'not be supported in version v0.22.0. Override some settings in the '
        'used config, the key-value pair in xxx=yyy format will be merged '
        'into config file. If the value to be overwritten is a list, it '
        'should be like key="[a,b]" or key=a,b. It also allows nested '
        'list/tuple values, e.g. key="[(a,b),(c,d)]". Note that the '
        'quotation marks are necessary and that no white space is allowed.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='Opacity of painted segmentation map. In (0, 1] range.')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options. '
            '--options will not be supported in version v0.22.0.')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options. '
                      '--options will not be supported in version v0.22.0.')
        args.cfg_options = args.options

    return args


def main():
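    """Build the dataset and backend model, then run single-GPU test."""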
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format-only cannot both be specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    distributed = False

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # load onnx config and meta
    cfg.model.train_cfg = None

    if args.backend == 'onnxruntime':
        model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
    elif args.backend == 'tensorrt':
        model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0)
    else:
        raise ValueError(f'Unsupported backend: {args.backend}')

    model.CLASSES = dataset.CLASSES
    model.PALETTE = dataset.PALETTE

    # clean gpu memory when starting a new evaluation.
    torch.cuda.empty_cache()
    eval_kwargs = {} if args.eval_options is None else args.eval_options

    # Deprecated
    efficient_test = eval_kwargs.get('efficient_test', False)
    if efficient_test:
        warnings.warn(
            '``efficient_test=True`` has no effect in tools/deploy_test.py, '
            'the evaluation and format results are CPU memory efficient by '
            'default')

    eval_on_format_results = (
        args.eval is not None and 'cityscapes' in args.eval)
    if eval_on_format_results:
        assert len(args.eval) == 1, 'eval on format results is not ' \
                                    'applicable for metrics other than ' \
                                    'cityscapes'
    if args.format_only or eval_on_format_results:
        if 'imgfile_prefix' in eval_kwargs:
            tmpdir = eval_kwargs['imgfile_prefix']
        else:
            tmpdir = '.format_cityscapes'
            eval_kwargs.setdefault('imgfile_prefix', tmpdir)
        mmcv.mkdir_or_exist(tmpdir)
    else:
        tmpdir = None

    model = MMDataParallel(model, device_ids=[0])
    results = single_gpu_test(
        model,
        data_loader,
        args.show,
        args.show_dir,
        False,
        args.opacity,
        pre_eval=args.eval is not None and not eval_on_format_results,
        format_only=args.format_only or eval_on_format_results,
        format_args=eval_kwargs)

    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            warnings.warn(
                'The behavior of ``args.out`` has been changed since MMSeg '
                'v0.16: the pickled outputs could be segmentation maps as '
                'np.ndarray, pre-eval results, or file paths for '
                '``dataset.format_results()``.')
            print(f'\nwriting results to {args.out}')
            mmcv.dump(results, args.out)
        if args.eval:
            dataset.evaluate(results, args.eval, **eval_kwargs)
        if tmpdir is not None and eval_on_format_results:
            # remove the tmp dir after cityscapes evaluation
            shutil.rmtree(tmpdir)


if __name__ == '__main__':
    main()