--- /dev/null
+++ b/tools/pytorch2onnx.py
@@ -0,0 +1,397 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+from functools import partial
+
+import mmcv
+import numpy as np
+import onnxruntime as rt
+import torch
+import torch._C
+import torch.serialization
+from mmcv import DictAction
+from mmcv.onnx import register_extra_symbolics
+from mmcv.runner import load_checkpoint
+from torch import nn
+
+from mmseg.apis import show_result_pyplot
+from mmseg.apis.inference import LoadImage
+from mmseg.datasets.pipelines import Compose
+from mmseg.models import build_segmentor
+from mmseg.ops import resize
+
+torch.manual_seed(3)
+
+
+def _convert_batchnorm(module):
+    module_output = module
+    if isinstance(module, torch.nn.SyncBatchNorm):
+        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
+                                             module.momentum, module.affine,
+                                             module.track_running_stats)
+        if module.affine:
+            module_output.weight.data = module.weight.data.clone().detach()
+            module_output.bias.data = module.bias.data.clone().detach()
+            # keep requires_grad unchanged
+            module_output.weight.requires_grad = module.weight.requires_grad
+            module_output.bias.requires_grad = module.bias.requires_grad
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+    for name, child in module.named_children():
+        module_output.add_module(name, _convert_batchnorm(child))
+    del module
+    return module_output
+
+
+def _demo_mm_inputs(input_shape, num_classes):
+    """Create a superset of inputs needed to run test or train batches.
+
+    Args:
+        input_shape (tuple):
+            input batch dimensions
+        num_classes (int):
+            number of semantic classes
+    """
+    (N, C, H, W) = input_shape
+    rng = np.random.RandomState(0)
+    imgs = rng.rand(*input_shape)
+    segs = rng.randint(
+        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
+    img_metas = [{
+        'img_shape': (H, W, C),
+        'ori_shape': (H, W, C),
+        'pad_shape': (H, W, C),
+        'filename': '<demo>.png',
+        'scale_factor': 1.0,
+        'flip': False,
+    } for _ in range(N)]
+    mm_inputs = {
+        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
+        'img_metas': img_metas,
+        'gt_semantic_seg': torch.LongTensor(segs)
+    }
+    return mm_inputs
+
+
+def _prepare_input_img(img_path,
+                       test_pipeline,
+                       shape=None,
+                       rescale_shape=None):
+    # build the data pipeline
+    if shape is not None:
+        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
+    test_pipeline[1]['transforms'][0]['keep_ratio'] = False
+    test_pipeline = [LoadImage()] + test_pipeline[1:]
+    test_pipeline = Compose(test_pipeline)
+    # prepare data
+    data = dict(img=img_path)
+    data = test_pipeline(data)
+    imgs = data['img']
+    img_metas = [i.data for i in data['img_metas']]
+
+    if rescale_shape is not None:
+        for img_meta in img_metas:
+            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
+
+    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
+
+    return mm_inputs
+
+
+def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
+    # update img and its meta list
+    N, C, H, W = img_list[0].shape
+    img_meta = img_meta_list[0][0]
+    img_shape = (H, W, C)
+    if update_ori_shape:
+        ori_shape = img_shape
+    else:
+        ori_shape = img_meta['ori_shape']
+    pad_shape = img_shape
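+    # NOTE: '* 2' below repeats (w_ratio, h_ratio), giving scale_factor the
+    # 4-element (w, h, w, h) layout used by the mmcv test pipelines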
+    new_img_meta_list = [[{
+        'img_shape':
+        img_shape,
+        'ori_shape':
+        ori_shape,
+        'pad_shape':
+        pad_shape,
+        'filename':
+        img_meta['filename'],
+        'scale_factor':
+        (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
+        'flip':
+        False,
+    } for _ in range(N)]]
+
+    return img_list, new_img_meta_list
+
+
+def pytorch2onnx(model,
+                 mm_inputs,
+                 opset_version=11,
+                 show=False,
+                 output_file='tmp.onnx',
+                 verify=False,
+                 dynamic_export=False):
+    """Export a PyTorch model to ONNX and verify that the outputs are the
+    same between PyTorch and ONNX.
+
+    Args:
+        model (nn.Module): The PyTorch model to export.
+        mm_inputs (dict): Contains input tensors and img_metas information.
+        opset_version (int): The ONNX opset version. Default: 11.
+        show (bool): Whether to print the computation graph. Default: False.
+        output_file (string): The path where the output ONNX model is stored.
+            Default: `tmp.onnx`.
+        verify (bool): Whether to compare the outputs of PyTorch and ONNX.
+            Default: False.
+        dynamic_export (bool): Whether to export ONNX with dynamic axes.
+            Default: False.
+    """
+    model.cpu().eval()
+    test_mode = model.test_cfg.mode
+
+    if isinstance(model.decode_head, nn.ModuleList):
+        num_classes = model.decode_head[-1].num_classes
+    else:
+        num_classes = model.decode_head.num_classes
+
+    imgs = mm_inputs.pop('imgs')
+    img_metas = mm_inputs.pop('img_metas')
+
+    img_list = [img[None, :] for img in imgs]
+    img_meta_list = [[img_meta] for img_meta in img_metas]
+    # update img_meta
+    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
+
+    # replace original forward function
+    origin_forward = model.forward
+    model.forward = partial(
+        model.forward,
+        img_metas=img_meta_list,
+        return_loss=False,
+        rescale=True)
+    dynamic_axes = None
+    if dynamic_export:
+        if test_mode == 'slide':
+            dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
+        else:
+            dynamic_axes = {
+                'input': {
+                    0: 'batch',
+                    2: 'height',
+                    3: 'width'
+                },
+                'output': {
+                    1: 'batch',
+                    2: 'height',
+                    3: 'width'
+                }
+            }
+
+    register_extra_symbolics(opset_version)
+    with torch.no_grad():
+        torch.onnx.export(
+            model, (img_list, ),
+            output_file,
+            input_names=['input'],
+            output_names=['output'],
+            export_params=True,
+            keep_initializers_as_inputs=False,
+            verbose=show,
+            opset_version=opset_version,
+            dynamic_axes=dynamic_axes)
+        print(f'Successfully exported ONNX model: {output_file}')
+    model.forward = origin_forward
+
+    if verify:
+        # check by onnx
+        import onnx
+        onnx_model = onnx.load(output_file)
+        onnx.checker.check_model(onnx_model)
+
+        if dynamic_export and test_mode == 'whole':
+            # scale image for dynamic shape test
+            img_list = [resize(_, scale_factor=1.5) for _ in img_list]
+            # concatenate flipped images for batch test
+            flip_img_list = [_.flip(-1) for _ in img_list]
+            img_list = [
+                torch.cat((ori_img, flip_img), 0)
+                for ori_img, flip_img in zip(img_list, flip_img_list)
+            ]
+
+            # update img_meta
+            img_list, img_meta_list = _update_input_img(
+                img_list, img_meta_list, test_mode == 'whole')
+
+        # check the numerical value
+        # get pytorch output
+        with torch.no_grad():
+            pytorch_result = model(img_list, img_meta_list, return_loss=False)
+            pytorch_result = np.stack(pytorch_result, 0)
+
+        # get onnx output
+        input_all = [node.name for node in onnx_model.graph.input]
+        input_initializer = [
+            node.name for node in onnx_model.graph.initializer
+        ]
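+        # graph inputs can also include the weight initializers, depending
+        # on the exporter settings, so subtract them to get the real input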
+        net_feed_input = list(set(input_all) - set(input_initializer))
+        assert (len(net_feed_input) == 1)
+        sess = rt.InferenceSession(output_file)
+        onnx_result = sess.run(
+            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
+        # show segmentation results
+        if show:
+            import cv2
+            import os.path as osp
+            img = img_meta_list[0][0]['filename']
+            if not osp.exists(img):
+                img = imgs[0][:3, ...].permute(1, 2, 0) * 255
+                img = img.detach().numpy().astype(np.uint8)
+                ori_shape = img.shape[:2]
+            else:
+                ori_shape = LoadImage()({'img': img})['ori_shape']
+
+            # resize onnx_result to ori_shape
+            onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
+                                      (ori_shape[1], ori_shape[0]))
+            show_result_pyplot(
+                model,
+                img, (onnx_result_, ),
+                palette=model.PALETTE,
+                block=False,
+                title='ONNXRuntime',
+                opacity=0.5)
+
+            # resize pytorch_result to ori_shape
+            pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
+                                         (ori_shape[1], ori_shape[0]))
+            show_result_pyplot(
+                model,
+                img, (pytorch_result_, ),
+                title='PyTorch',
+                palette=model.PALETTE,
+                opacity=0.5)
+        # compare results
+        np.testing.assert_allclose(
+            pytorch_result.astype(np.float32) / num_classes,
+            onnx_result.astype(np.float32) / num_classes,
+            rtol=1e-5,
+            atol=1e-5,
+            err_msg='The outputs are different between PyTorch and ONNX')
+        print('The outputs are the same between PyTorch and ONNX')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
+    parser.add_argument(
+        '--input-img', type=str, help='Image for input', default=None)
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        help='show onnx graph and segmentation results')
+    parser.add_argument(
+        '--verify', action='store_true', help='verify the onnx model')
+    parser.add_argument('--output-file', type=str, default='tmp.onnx')
+    parser.add_argument('--opset-version', type=int, default=11)
+    parser.add_argument(
+        '--shape',
+        type=int,
+        nargs='+',
+        default=None,
+        help='input image height and width.')
+    parser.add_argument(
+        '--rescale_shape',
+        type=int,
+        nargs='+',
+        default=None,
+        help='output image rescale height and width; works in slide mode.')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='Override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into the config file. If the value '
+        'to be overwritten is a list, it should be like key="[a,b]" or '
+        'key=a,b. It also allows nested list/tuple values, e.g. '
+        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
+        'and that no white space is allowed.')
+    parser.add_argument(
+        '--dynamic-export',
+        action='store_true',
+        help='Whether to export onnx with dynamic axes.')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    cfg = mmcv.Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+    cfg.model.pretrained = None
+
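+    # img_scale in the test pipeline is stored as (width, height), while the
+    # dummy input tensor is (N, C, H, W), hence the swapped indices below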
key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--dynamic-export', + action='store_true', + help='Whether to export onnx with dynamic axis.') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + + cfg = mmcv.Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + cfg.model.pretrained = None + + if args.shape is None: + img_scale = cfg.test_pipeline[1]['img_scale'] + input_shape = (1, 3, img_scale[1], img_scale[0]) + elif len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = ( + 1, + 3, + ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + test_mode = cfg.model.test_cfg.mode + + # build the model and load checkpoint + cfg.model.train_cfg = None + segmentor = build_segmentor( + cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg')) + # convert SyncBN to BN + segmentor = _convert_batchnorm(segmentor) + + if args.checkpoint: + checkpoint = load_checkpoint( + segmentor, args.checkpoint, map_location='cpu') + segmentor.CLASSES = checkpoint['meta']['CLASSES'] + segmentor.PALETTE = checkpoint['meta']['PALETTE'] + + # read input or create dummpy input + if args.input_img is not None: + preprocess_shape = (input_shape[2], input_shape[3]) + rescale_shape = None + if args.rescale_shape is not None: + rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]] + mm_inputs = _prepare_input_img( + args.input_img, + cfg.data.test.pipeline, + shape=preprocess_shape, + rescale_shape=rescale_shape) + else: + if isinstance(segmentor.decode_head, nn.ModuleList): + num_classes = segmentor.decode_head[-1].num_classes + else: + num_classes = segmentor.decode_head.num_classes + mm_inputs = _demo_mm_inputs(input_shape, num_classes) + + # convert model to onnx file + pytorch2onnx( + segmentor, + mm_inputs, + opset_version=args.opset_version, + show=args.show, + output_file=args.output_file, + verify=args.verify, + dynamic_export=args.dynamic_export)