Diff of /tools/pytorch2onnx.py [000000] .. [4e96d3]

--- a
+++ b/tools/pytorch2onnx.py
@@ -0,0 +1,391 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+from functools import partial
+
+import mmcv
+import numpy as np
+import onnxruntime as rt
+import torch
+import torch._C
+import torch.serialization
+from mmcv import DictAction
+from mmcv.onnx import register_extra_symbolics
+from mmcv.runner import load_checkpoint
+from torch import nn
+
+from mmseg.apis import show_result_pyplot
+from mmseg.apis.inference import LoadImage
+from mmseg.datasets.pipelines import Compose
+from mmseg.models import build_segmentor
+from mmseg.ops import resize
+
+torch.manual_seed(3)
+
+
+def _convert_batchnorm(module):
+    module_output = module
+    if isinstance(module, torch.nn.SyncBatchNorm):
+        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
+                                             module.momentum, module.affine,
+                                             module.track_running_stats)
+        if module.affine:
+            module_output.weight.data = module.weight.data.clone().detach()
+            module_output.bias.data = module.bias.data.clone().detach()
+            # keep requires_grad unchanged
+            module_output.weight.requires_grad = module.weight.requires_grad
+            module_output.bias.requires_grad = module.bias.requires_grad
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+    for name, child in module.named_children():
+        module_output.add_module(name, _convert_batchnorm(child))
+    del module
+    return module_output
+
+
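+# A minimal usage sketch for _convert_batchnorm (the Sequential model below
+# is hypothetical, for illustration only): converting a module tree that
+# contains SyncBatchNorm layers yields a copy exportable on a single CPU.
+#
+#   model = torch.nn.Sequential(
+#       torch.nn.Conv2d(3, 8, 3), torch.nn.SyncBatchNorm(8))
+#   model = _convert_batchnorm(model)  # SyncBatchNorm -> BatchNorm2d
+
+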
+def _demo_mm_inputs(input_shape, num_classes):
+    """Create a superset of inputs needed to run test or train batches.
+
+    Args:
+        input_shape (tuple): Input batch dimensions (N, C, H, W).
+        num_classes (int): Number of semantic classes.
+    """
+    (N, C, H, W) = input_shape
+    rng = np.random.RandomState(0)
+    imgs = rng.rand(*input_shape)
+    # note: np.random.randint's `high` bound is exclusive, so the dummy
+    # labels fall in [0, num_classes - 2]
+    segs = rng.randint(
+        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
+    img_metas = [{
+        'img_shape': (H, W, C),
+        'ori_shape': (H, W, C),
+        'pad_shape': (H, W, C),
+        'filename': '<demo>.png',
+        'scale_factor': 1.0,
+        'flip': False,
+    } for _ in range(N)]
+    mm_inputs = {
+        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
+        'img_metas': img_metas,
+        'gt_semantic_seg': torch.LongTensor(segs)
+    }
+    return mm_inputs
+
+
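+# Shape sanity check (follows directly from the code above): calling
+# _demo_mm_inputs((1, 3, 64, 64), num_classes=19) yields 'imgs' of shape
+# (1, 3, 64, 64), 'gt_semantic_seg' of shape (1, 1, 64, 64) and a one-dict
+# 'img_metas' list.
+
+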
+def _prepare_input_img(img_path,
+                       test_pipeline,
+                       shape=None,
+                       rescale_shape=None):
+    # build the data pipeline
+    if shape is not None:
+        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
+    test_pipeline[1]['transforms'][0]['keep_ratio'] = False
+    test_pipeline = [LoadImage()] + test_pipeline[1:]
+    test_pipeline = Compose(test_pipeline)
+    # prepare data
+    data = dict(img=img_path)
+    data = test_pipeline(data)
+    imgs = data['img']
+    img_metas = [i.data for i in data['img_metas']]
+
+    if rescale_shape is not None:
+        for img_meta in img_metas:
+            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
+
+    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
+
+    return mm_inputs
+
+
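+# A hedged example (the image path is a placeholder): with a standard mmseg
+# test config, _prepare_input_img('demo.png', cfg.data.test.pipeline,
+# shape=(512, 1024)) forces an exact resize (keep_ratio disabled) and
+# returns a dict with 'imgs' and 'img_metas' ready for export.
+
+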
+def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
+    # update img and its meta list
+    N, C, H, W = img_list[0].shape
+    img_meta = img_meta_list[0][0]
+    img_shape = (H, W, C)
+    if update_ori_shape:
+        ori_shape = img_shape
+    else:
+        ori_shape = img_meta['ori_shape']
+    pad_shape = img_shape
+    new_img_meta_list = [[{
+        'img_shape': img_shape,
+        'ori_shape': ori_shape,
+        'pad_shape': pad_shape,
+        'filename': img_meta['filename'],
+        # mmcv convention: a 4-tuple (w_scale, h_scale, w_scale, h_scale)
+        'scale_factor':
+        (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
+        'flip': False,
+    } for _ in range(N)]]
+
+    return img_list, new_img_meta_list
+
+
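+# For example (derived from the code above): given a (1, 3, 512, 512)
+# tensor and update_ori_shape=True, img_shape, ori_shape and pad_shape all
+# become (512, 512, 3) and scale_factor is (1.0, 1.0, 1.0, 1.0).
+
+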
+def pytorch2onnx(model,
+                 mm_inputs,
+                 opset_version=11,
+                 show=False,
+                 output_file='tmp.onnx',
+                 verify=False,
+                 dynamic_export=False):
+    """Export Pytorch model to ONNX model and verify the outputs are same
+    between Pytorch and ONNX.
+
+    Args:
+        model (nn.Module): Pytorch model we want to export.
+        mm_inputs (dict): Contain the input tensors and img_metas information.
+        opset_version (int): The onnx op version. Default: 11.
+        show (bool): Whether print the computation graph. Default: False.
+        output_file (string): The path to where we store the output ONNX model.
+            Default: `tmp.onnx`.
+        verify (bool): Whether compare the outputs between Pytorch and ONNX.
+            Default: False.
+        dynamic_export (bool): Whether to export ONNX with dynamic axis.
+            Default: False.
+    """
+    model.cpu().eval()
+    test_mode = model.test_cfg.mode
+
+    if isinstance(model.decode_head, nn.ModuleList):
+        num_classes = model.decode_head[-1].num_classes
+    else:
+        num_classes = model.decode_head.num_classes
+
+    imgs = mm_inputs.pop('imgs')
+    img_metas = mm_inputs.pop('img_metas')
+
+    img_list = [img[None, :] for img in imgs]
+    img_meta_list = [[img_meta] for img_meta in img_metas]
+    # update img_meta
+    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
+
+    # replace original forward function
+    origin_forward = model.forward
+    model.forward = partial(
+        model.forward,
+        img_metas=img_meta_list,
+        return_loss=False,
+        rescale=True)
+    dynamic_axes = None
+    if dynamic_export:
+        if test_mode == 'slide':
+            dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
+        else:
+            dynamic_axes = {
+                'input': {
+                    0: 'batch',
+                    2: 'height',
+                    3: 'width'
+                },
+                'output': {
+                    1: 'batch',
+                    2: 'height',
+                    3: 'width'
+                }
+            }
+
+    register_extra_symbolics(opset_version)
+    with torch.no_grad():
+        torch.onnx.export(
+            model, (img_list, ),
+            output_file,
+            input_names=['input'],
+            output_names=['output'],
+            export_params=True,
+            keep_initializers_as_inputs=False,
+            verbose=show,
+            opset_version=opset_version,
+            dynamic_axes=dynamic_axes)
+        print(f'Successfully exported ONNX model: {output_file}')
+    model.forward = origin_forward
+
+    if verify:
+        # check by onnx
+        import onnx
+        onnx_model = onnx.load(output_file)
+        onnx.checker.check_model(onnx_model)
+
+        if dynamic_export and test_mode == 'whole':
+            # scale image for dynamic shape test
+            img_list = [resize(_, scale_factor=1.5) for _ in img_list]
+            # concatenate flipped images for the batch test
+            flip_img_list = [_.flip(-1) for _ in img_list]
+            img_list = [
+                torch.cat((ori_img, flip_img), 0)
+                for ori_img, flip_img in zip(img_list, flip_img_list)
+            ]
+
+            # update img_meta
+            img_list, img_meta_list = _update_input_img(
+                img_list, img_meta_list, test_mode == 'whole')
+
+        # check the numerical value
+        # get pytorch output
+        with torch.no_grad():
+            pytorch_result = model(img_list, img_meta_list, return_loss=False)
+            pytorch_result = np.stack(pytorch_result, 0)
+
+        # get onnx output
+        input_all = [node.name for node in onnx_model.graph.input]
+        input_initializer = [
+            node.name for node in onnx_model.graph.initializer
+        ]
+        net_feed_input = list(set(input_all) - set(input_initializer))
+        assert len(net_feed_input) == 1
+        sess = rt.InferenceSession(output_file)
+        onnx_result = sess.run(
+            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
+        # show segmentation results
+        if show:
+            import cv2
+            import os.path as osp
+            img = img_meta_list[0][0]['filename']
+            if not osp.exists(img):
+                img = imgs[0][:3, ...].permute(1, 2, 0) * 255
+                img = img.detach().numpy().astype(np.uint8)
+                ori_shape = img.shape[:2]
+            else:
+                ori_shape = LoadImage()({'img': img})['ori_shape']
+
+            # resize onnx_result to ori_shape
+            onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
+                                      (ori_shape[1], ori_shape[0]))
+            show_result_pyplot(
+                model,
+                img, (onnx_result_, ),
+                palette=model.PALETTE,
+                block=False,
+                title='ONNXRuntime',
+                opacity=0.5)
+
+            # resize pytorch_result to ori_shape
+            pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
+                                         (ori_shape[1], ori_shape[0]))
+            show_result_pyplot(
+                model,
+                img, (pytorch_result_, ),
+                title='PyTorch',
+                palette=model.PALETTE,
+                opacity=0.5)
+        # compare results
+        np.testing.assert_allclose(
+            pytorch_result.astype(np.float32) / num_classes,
+            onnx_result.astype(np.float32) / num_classes,
+            rtol=1e-5,
+            atol=1e-5,
+            err_msg='The outputs are different between PyTorch and ONNX')
+        print('The outputs are the same between PyTorch and ONNX')
+
+
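+# A hedged usage sketch (`segmentor` is assumed to be a built and
+# CPU-loaded mmseg model, as in __main__ below): export a dummy batch and
+# verify the result.
+#
+#   inputs = _demo_mm_inputs((1, 3, 256, 256), num_classes=19)
+#   pytorch2onnx(segmentor, inputs, opset_version=11,
+#                output_file='tmp.onnx', verify=True)
+
+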
+def parse_args():
+    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
+    parser.add_argument(
+        '--input-img', type=str, help='Images for input', default=None)
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        help='show ONNX graph and segmentation results')
+    parser.add_argument(
+        '--verify', action='store_true', help='verify the ONNX model')
+    parser.add_argument('--output-file', type=str, default='tmp.onnx')
+    parser.add_argument('--opset-version', type=int, default=11)
+    parser.add_argument(
+        '--shape',
+        type=int,
+        nargs='+',
+        default=None,
+        help='input image height and width.')
+    parser.add_argument(
+        '--rescale_shape',
+        type=int,
+        nargs='+',
+        default=None,
+        help='output image rescale height and width; only effective in '
+        'slide mode.')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='Override some settings in the used config; key-value pairs '
+        'in the xxx=yyy format will be merged into the config file. If the '
+        'value to be overwritten is a list, it should be of the form '
+        'key="[a,b]" or key=a,b. Nested list/tuple values are also allowed, '
+        'e.g. key="[(a,b),(c,d)]". Note that the quotation marks are '
+        'necessary and that no white space is allowed.')
+    parser.add_argument(
+        '--dynamic-export',
+        action='store_true',
+        help='Whether to export ONNX with dynamic axes.')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    cfg = mmcv.Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+    cfg.model.pretrained = None
+
+    if args.shape is None:
+        img_scale = cfg.test_pipeline[1]['img_scale']
+        input_shape = (1, 3, img_scale[1], img_scale[0])
+    elif len(args.shape) == 1:
+        input_shape = (1, 3, args.shape[0], args.shape[0])
+    elif len(args.shape) == 2:
+        input_shape = (1, 3) + tuple(args.shape)
+    else:
+        raise ValueError('invalid input shape')
+
+    test_mode = cfg.model.test_cfg.mode
+
+    # build the model and load checkpoint
+    cfg.model.train_cfg = None
+    segmentor = build_segmentor(
+        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
+    # convert SyncBN to BN
+    segmentor = _convert_batchnorm(segmentor)
+
+    if args.checkpoint:
+        checkpoint = load_checkpoint(
+            segmentor, args.checkpoint, map_location='cpu')
+        segmentor.CLASSES = checkpoint['meta']['CLASSES']
+        segmentor.PALETTE = checkpoint['meta']['PALETTE']
+
+    # read the input image or create a dummy input
+    if args.input_img is not None:
+        preprocess_shape = (input_shape[2], input_shape[3])
+        rescale_shape = None
+        if args.rescale_shape is not None:
+            rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]]
+        mm_inputs = _prepare_input_img(
+            args.input_img,
+            cfg.data.test.pipeline,
+            shape=preprocess_shape,
+            rescale_shape=rescale_shape)
+    else:
+        if isinstance(segmentor.decode_head, nn.ModuleList):
+            num_classes = segmentor.decode_head[-1].num_classes
+        else:
+            num_classes = segmentor.decode_head.num_classes
+        mm_inputs = _demo_mm_inputs(input_shape, num_classes)
+
+    # convert model to onnx file
+    pytorch2onnx(
+        segmentor,
+        mm_inputs,
+        opset_version=args.opset_version,
+        show=args.show,
+        output_file=args.output_file,
+        verify=args.verify,
+        dynamic_export=args.dynamic_export)
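+
+# Example invocation (a sketch; the config and checkpoint paths are
+# placeholders, substitute your own):
+#
+#   python tools/pytorch2onnx.py configs/my_config.py \
+#       --checkpoint work_dirs/latest.pth --output-file model.onnx \
+#       --shape 512 1024 --verify --dynamic-export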