# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings
import numpy as np
import torch
from mmpose.apis import init_pose_model
try:
import onnx
import onnxruntime as rt
except ImportError as e:
raise ImportError(f'Please install onnx and onnxruntime first. {e}')
try:
from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError:
raise NotImplementedError('please update mmcv to version>=1.0.4')
def _convert_batchnorm(module):
"""Convert the syncBNs into normal BN3ds."""
module_output = module
if isinstance(module, torch.nn.SyncBatchNorm):
module_output = torch.nn.BatchNorm3d(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
# keep requires_grad unchanged
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, _convert_batchnorm(child))
del module
return module_output
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export a pytorch model to an ONNX file and optionally validate it.

    The model is moved to CPU, switched to eval mode and exported with a
    random tensor of shape ``input_shape``. With ``verify`` enabled the
    written file is checked by ``onnx.checker`` and its onnxruntime outputs
    are compared against the pytorch outputs for that same random tensor.

    Args:
        model (:obj:`nn.Module`): The pytorch model to be exported.
        input_shape (tuple[int]): The input tensor shape of the model.
        opset_version (int): Opset version of onnx used. Default: 11.
        show (bool): Passed to the exporter as ``verbose`` so the onnx model
            architecture gets printed. Default: False.
        output_file (str): Output onnx model name. Default: 'tmp.onnx'.
        verify (bool): Determines whether to verify the onnx model.
            Default: False.

    Raises:
        AssertionError: If verification is requested and the exported graph
            does not have exactly one non-initializer input, or the pytorch
            and ONNX outputs differ in count or in value (``atol=1e-5``).
    """
    model.cpu().eval()
    dummy_input = torch.randn(input_shape)

    register_extra_symbolics(opset_version)
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        export_params=True,
        keep_initializers_as_inputs=True,
        verbose=show,
        opset_version=opset_version)
    print(f'Successfully exported ONNX model: {output_file}')

    if not verify:
        return

    # Structural validation of the file that was just written.
    exported = onnx.load(output_file)
    onnx.checker.check_model(exported)

    # Reference outputs from pytorch, normalised to a sequence of tensors.
    torch_outputs = model(dummy_input)
    if not isinstance(torch_outputs, (list, tuple)):
        assert isinstance(torch_outputs, torch.Tensor)
        torch_outputs = [torch_outputs]

    # Initializers are also listed as graph inputs (see
    # ``keep_initializers_as_inputs`` above); whatever is left over is the
    # single tensor that has to be fed at run time.
    graph_input_names = {entry.name for entry in exported.graph.input}
    initializer_names = {entry.name for entry in exported.graph.initializer}
    feed_names = list(graph_input_names - initializer_names)
    assert len(feed_names) == 1

    session = rt.InferenceSession(output_file)
    feed = {feed_names[0]: dummy_input.detach().numpy()}
    ort_outputs = session.run(None, feed)

    # Both backends must agree on the number of outputs and on every value.
    assert len(torch_outputs) == len(ort_outputs)
    for torch_out, ort_out in zip(torch_outputs, ort_outputs):
        assert np.allclose(
            torch_out.detach().cpu(), ort_out, atol=1.e-5
        ), 'The outputs are different between Pytorch and ONNX'
    print('The numerical values are same between Pytorch and ONNX')
def parse_args():
    """Parse the command line options of the ONNX conversion script.

    Returns:
        argparse.Namespace: Holds the positional ``config`` and
        ``checkpoint`` values plus the options ``show``, ``output_file``,
        ``opset_version``, ``verify`` and ``shape``.
    """
    cli = argparse.ArgumentParser(description='Convert MMPose models to ONNX')

    # Positional arguments.
    cli.add_argument('config', help='test config file path')
    cli.add_argument('checkpoint', help='checkpoint file')

    # Optional switches and values.
    cli.add_argument('--show', action='store_true', help='show onnx graph')
    cli.add_argument('--output-file', type=str, default='tmp.onnx')
    cli.add_argument('--opset-version', type=int, default=11)
    cli.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against pytorch output')
    cli.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1, 3, 256, 192],
        help='input size')

    return cli.parse_args()
if __name__ == '__main__':
    args = parse_args()

    assert args.opset_version == 11, 'MMPose only supports opset 11 now'

    # Terminal text-style escape sequences, taken from the colorama package.
    bright_style = '\x1b[1m'
    reset_style = '\x1b[0m'
    red_text = '\x1b[31m'
    blue_text = '\x1b[34m'
    white_background = '\x1b[107m'

    # Highlighted deprecation notice that points users to MMDeploy.
    msg = ''.join((
        white_background,
        bright_style,
        red_text,
        'DeprecationWarning: This tool will be deprecated in future. ',
        blue_text,
        'Welcome to use the unified model deployment toolbox ',
        'MMDeploy: https://github.com/open-mmlab/mmdeploy',
        reset_style,
    ))
    warnings.warn(msg)

    # Build the model on CPU and replace its SyncBN layers before exporting.
    model = _convert_batchnorm(
        init_pose_model(args.config, args.checkpoint, device='cpu'))

    # onnx.export does not support kwargs, so the export has to go through
    # the model's ``forward_dummy`` entry point.
    if not hasattr(model, 'forward_dummy'):
        raise NotImplementedError(
            'Please implement the forward method for exporting.')
    model.forward = model.forward_dummy

    # convert model to onnx file
    pytorch2onnx(
        model,
        args.shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)