Diff of /tools/pytorch2onnx.py [000000] .. [4e96d3]

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from functools import partial

import mmcv
import numpy as np
import onnxruntime as rt
import torch
import torch._C
import torch.serialization
from mmcv import DictAction
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint
from torch import nn

from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
from mmseg.ops import resize

torch.manual_seed(3)


def _convert_batchnorm(module):
    """Recursively replace SyncBatchNorm with an equivalent BatchNorm2d so
    the model can be traced and exported on CPU."""
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output
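
# A minimal sketch of the conversion on a toy module (illustrative only,
# not part of the tool):
#
#   m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
#   m = _convert_batchnorm(m)
#   assert isinstance(m[1], nn.BatchNorm2d)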


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions
        num_classes (int):
            number of semantic classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs
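
# For illustration, _demo_mm_inputs((1, 3, 64, 64), num_classes=19) yields
# (shapes follow directly from the code above):
#
#   mm_inputs['imgs']             -> FloatTensor of shape (1, 3, 64, 64)
#   mm_inputs['gt_semantic_seg']  -> LongTensor of shape (1, 1, 64, 64)
#   mm_inputs['img_metas']        -> one meta dict per image in the batch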


def _prepare_input_img(img_path,
                       test_pipeline,
                       shape=None,
                       rescale_shape=None):
    # build the data pipeline; test_pipeline[1] is expected to be the
    # MultiScaleFlipAug step of an mmseg test config, so we pin its scale
    # and disable keep_ratio to get a deterministic input shape
    if shape is not None:
        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
    test_pipeline[1]['transforms'][0]['keep_ratio'] = False
    test_pipeline = [LoadImage()] + test_pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img_path)
    data = test_pipeline(data)
    imgs = data['img']
    img_metas = [i.data for i in data['img_metas']]

    if rescale_shape is not None:
        for img_meta in img_metas:
            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )

    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

    return mm_inputs


def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
    # update img and its meta list
    N, C, H, W = img_list[0].shape
    img_meta = img_meta_list[0][0]
    img_shape = (H, W, C)
    if update_ori_shape:
        ori_shape = img_shape
    else:
        ori_shape = img_meta['ori_shape']
    pad_shape = img_shape
    new_img_meta_list = [[{
        'img_shape': img_shape,
        'ori_shape': ori_shape,
        'pad_shape': pad_shape,
        'filename': img_meta['filename'],
        'scale_factor':
        (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
        'flip': False,
    } for _ in range(N)]]

    return img_list, new_img_meta_list
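
# The tuple repetition in `scale_factor` above expands (w_scale, h_scale)
# into the 4-tuple format used elsewhere in mmseg, e.g.:
#
#   >>> (2.0, 1.5) * 2
#   (2.0, 1.5, 2.0, 1.5)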


def pytorch2onnx(model,
                 mm_inputs,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export a PyTorch model to ONNX and verify that the outputs match
    between PyTorch and ONNX.

    Args:
        model (nn.Module): PyTorch model to export.
        mm_inputs (dict): Contains the input tensors and img_metas
            information.
        opset_version (int): The ONNX opset version. Default: 11.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (str): The path where the output ONNX model is stored.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs of PyTorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether to export ONNX with dynamic axes.
            Default: False.
    """
    model.cpu().eval()
    test_mode = model.test_cfg.mode

    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]
    # update img_meta
    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

    # replace the original forward with one that carries the metas, so the
    # exported graph only takes image tensors as input
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=True)
    dynamic_axes = None
    if dynamic_export:
        if test_mode == 'slide':
            dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
        else:
            dynamic_axes = {
                'input': {0: 'batch', 2: 'height', 3: 'width'},
                'output': {1: 'batch', 2: 'height', 3: 'width'}
            }

    register_extra_symbolics(opset_version)
    with torch.no_grad():
        torch.onnx.export(
            model, (img_list, ),
            output_file,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=show,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes)
        print(f'Successfully exported ONNX model: {output_file}')
    model.forward = origin_forward

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        if dynamic_export and test_mode == 'whole':
            # scale image for dynamic shape test
            img_list = [resize(img, scale_factor=1.5) for img in img_list]
            # concatenate flipped images to test batching
            flip_img_list = [img.flip(-1) for img in img_list]
            img_list = [
                torch.cat((ori_img, flip_img), 0)
                for ori_img, flip_img in zip(img_list, flip_img_list)
            ]

            # update img_meta
            img_list, img_meta_list = _update_input_img(
                img_list, img_meta_list, test_mode == 'whole')

        # check the numerical values
        # get pytorch output
        with torch.no_grad():
            pytorch_result = model(img_list, img_meta_list, return_loss=False)
            pytorch_result = np.stack(pytorch_result, 0)

        # get onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert len(net_feed_input) == 1
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
        # show segmentation results
        if show:
            import cv2
            import os.path as osp
            img = img_meta_list[0][0]['filename']
            if not osp.exists(img):
                img = imgs[0][:3, ...].permute(1, 2, 0) * 255
                img = img.detach().numpy().astype(np.uint8)
                ori_shape = img.shape[:2]
            else:
                ori_shape = LoadImage()({'img': img})['ori_shape']

            # resize onnx_result to ori_shape
            onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
                                      (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (onnx_result_, ),
                palette=model.PALETTE,
                block=False,
                title='ONNXRuntime',
                opacity=0.5)

            # resize pytorch_result to ori_shape
            pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
                                         (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (pytorch_result_, ),
                title='PyTorch',
                palette=model.PALETTE,
                opacity=0.5)
        # compare results
        np.testing.assert_allclose(
            pytorch_result.astype(np.float32) / num_classes,
            onnx_result.astype(np.float32) / num_classes,
            rtol=1e-5,
            atol=1e-5,
            err_msg='The outputs are different between PyTorch and ONNX')
        print('The outputs are the same between PyTorch and ONNX')
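
# A rough usage sketch (output name is illustrative): exporting with
# verification and the default opset looks like
#
#   pytorch2onnx(segmentor, mm_inputs, opset_version=11,
#                output_file='fcn.onnx', verify=True)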


def parse_args():
    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument(
        '--input-img', type=str, help='image used as the input', default=None)
    parser.add_argument(
        '--show',
        action='store_true',
        help='show onnx graph and segmentation results')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=None,
        help='input image height and width')
    parser.add_argument(
        '--rescale_shape',
        type=int,
        nargs='+',
        default=None,
        help='output image rescale height and width, works for slide mode')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='whether to export onnx with dynamic axes')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.model.pretrained = None

    if args.shape is None:
        img_scale = cfg.test_pipeline[1]['img_scale']
        input_shape = (1, 3, img_scale[1], img_scale[0])
    elif len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    test_mode = cfg.model.test_cfg.mode

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    if args.checkpoint:
        checkpoint = load_checkpoint(
            segmentor, args.checkpoint, map_location='cpu')
        segmentor.CLASSES = checkpoint['meta']['CLASSES']
        segmentor.PALETTE = checkpoint['meta']['PALETTE']

    # read input or create dummy input
    if args.input_img is not None:
        preprocess_shape = (input_shape[2], input_shape[3])
        rescale_shape = None
        if args.rescale_shape is not None:
            rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]]
        mm_inputs = _prepare_input_img(
            args.input_img,
            cfg.data.test.pipeline,
            shape=preprocess_shape,
            rescale_shape=rescale_shape)
    else:
        if isinstance(segmentor.decode_head, nn.ModuleList):
            num_classes = segmentor.decode_head[-1].num_classes
        else:
            num_classes = segmentor.decode_head.num_classes
        mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    # convert model to onnx file
    pytorch2onnx(
        segmentor,
        mm_inputs,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        dynamic_export=args.dynamic_export)
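
# Example invocation (config and checkpoint names are illustrative):
#
#   python tools/pytorch2onnx.py \
#       configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py \
#       --checkpoint fcn_r50-d8.pth \
#       --output-file fcn_r50-d8.onnx \
#       --shape 512 1024 \
#       --verify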