--- a +++ b/tools/deployment/pytorch2onnx.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import mmcv +import numpy as np +import torch +from mmcv.runner import load_checkpoint + +from mmaction.models import build_model + +try: + import onnx + import onnxruntime as rt +except ImportError as e: + raise ImportError(f'Please install onnx and onnxruntime first. {e}') + +try: + from mmcv.onnx.symbolic import register_extra_symbolics +except ModuleNotFoundError: + raise NotImplementedError('please update mmcv to version>=1.0.4') + + +def _convert_batchnorm(module): + """Convert the syncBNs into normal BN3ds.""" + module_output = module + if isinstance(module, torch.nn.SyncBatchNorm): + module_output = torch.nn.BatchNorm3d(module.num_features, module.eps, + module.momentum, module.affine, + module.track_running_stats) + if module.affine: + module_output.weight.data = module.weight.data.clone().detach() + module_output.bias.data = module.bias.data.clone().detach() + # keep requires_grad unchanged + module_output.weight.requires_grad = module.weight.requires_grad + module_output.bias.requires_grad = module.bias.requires_grad + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + for name, child in module.named_children(): + module_output.add_module(name, _convert_batchnorm(child)) + del module + return module_output + + +def pytorch2onnx(model, + input_shape, + opset_version=11, + show=False, + output_file='tmp.onnx', + verify=False): + """Convert pytorch model to onnx model. + + Args: + model (:obj:`nn.Module`): The pytorch model to be exported. + input_shape (tuple[int]): The input tensor shape of the model. + opset_version (int): Opset version of onnx used. Default: 11. + show (bool): Determines whether to print the onnx model architecture. + Default: False. + output_file (str): Output onnx model name. Default: 'tmp.onnx'. + verify (bool): Determines whether to verify the onnx model. + Default: False. + """ + model.cpu().eval() + + input_tensor = torch.randn(input_shape) + + register_extra_symbolics(opset_version) + torch.onnx.export( + model, + input_tensor, + output_file, + export_params=True, + keep_initializers_as_inputs=True, + verbose=show, + opset_version=opset_version) + + print(f'Successfully exported ONNX model: {output_file}') + if verify: + # check by onnx + onnx_model = onnx.load(output_file) + onnx.checker.check_model(onnx_model) + + # check the numerical value + # get pytorch output + pytorch_result = model(input_tensor)[0].detach().numpy() + + # get onnx output + input_all = [node.name for node in onnx_model.graph.input] + input_initializer = [ + node.name for node in onnx_model.graph.initializer + ] + net_feed_input = list(set(input_all) - set(input_initializer)) + assert len(net_feed_input) == 1 + sess = rt.InferenceSession(output_file) + onnx_result = sess.run( + None, {net_feed_input[0]: input_tensor.detach().numpy()})[0] + # only compare part of results + random_class = np.random.randint(pytorch_result.shape[1]) + assert np.allclose( + pytorch_result[:, random_class], onnx_result[:, random_class] + ), 'The outputs are different between Pytorch and ONNX' + print('The numerical values are same between Pytorch and ONNX') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert MMAction2 models to ONNX') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('--show', action='store_true', help='show onnx graph') + parser.add_argument('--output-file', type=str, default='tmp.onnx') + parser.add_argument('--opset-version', type=int, default=11) + parser.add_argument( + '--verify', + action='store_true', + help='verify the onnx model output against pytorch output') + parser.add_argument( + '--is-localizer', + action='store_true', + help='whether it is a localizer') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[1, 3, 8, 224, 224], + help='input video size') + parser.add_argument( + '--softmax', + action='store_true', + help='wheter to add softmax layer at the end of recognizers') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + + assert args.opset_version == 11, 'MMAction2 only supports opset 11 now' + + cfg = mmcv.Config.fromfile(args.config) + # import modules from string list. + + if not args.is_localizer: + cfg.model.backbone.pretrained = None + + # build the model + model = build_model( + cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg')) + model = _convert_batchnorm(model) + + # onnx.export does not support kwargs + if hasattr(model, 'forward_dummy'): + from functools import partial + model.forward = partial(model.forward_dummy, softmax=args.softmax) + elif hasattr(model, '_forward') and args.is_localizer: + model.forward = model._forward + else: + raise NotImplementedError( + 'Please implement the forward method for exporting.') + + checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') + + # convert model to onnx file + pytorch2onnx( + model, + args.shape, + opset_version=args.opset_version, + show=args.show, + output_file=args.output_file, + verify=args.verify)