tools/pytorch2onnx.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from functools import partial

import mmcv
import numpy as np
import onnxruntime as rt
import torch
import torch._C
import torch.serialization
from mmcv import DictAction
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint
from torch import nn

from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
from mmseg.ops import resize

torch.manual_seed(3)


def _convert_batchnorm(module):
    """Recursively replace ``SyncBatchNorm`` layers with ``BatchNorm2d``.

    SyncBatchNorm only works inside a distributed process group and is not
    supported by ONNX export, so it is swapped out before tracing.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): input batch dimensions.
        num_classes (int): number of semantic classes.
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    # note: `high` is exclusive in np.random.randint, so the generated
    # labels fall in [0, num_classes - 2]
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs
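
# For reference, _demo_mm_inputs((1, 3, 256, 256), num_classes=19) returns a
# dict with a (1, 3, 256, 256) float image batch, one meta dict per image and
# a (1, 1, 256, 256) integer label map, i.e. the structure mmseg models
# consume for a forward pass.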


def _prepare_input_img(img_path,
                       test_pipeline,
                       shape=None,
                       rescale_shape=None):
    """Load and preprocess a real image with the test pipeline of the config.

    When ``shape`` is given, the pipeline's img_scale is overridden so the
    image is resized to exactly that (height, width); ``rescale_shape``
    overrides ``ori_shape`` in the metas so predictions are rescaled to that
    size instead.
    """
    # build the data pipeline
    if shape is not None:
        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
    test_pipeline[1]['transforms'][0]['keep_ratio'] = False
    test_pipeline = [LoadImage()] + test_pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img_path)
    data = test_pipeline(data)
    imgs = data['img']
    img_metas = [i.data for i in data['img_metas']]

    if rescale_shape is not None:
        for img_meta in img_metas:
            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )

    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

    return mm_inputs


def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
    # update the images and their meta dicts to match the current shape
    N, C, H, W = img_list[0].shape
    img_meta = img_meta_list[0][0]
    img_shape = (H, W, C)
    if update_ori_shape:
        ori_shape = img_shape
    else:
        ori_shape = img_meta['ori_shape']
    pad_shape = img_shape
    new_img_meta_list = [[{
        'img_shape': img_shape,
        'ori_shape': ori_shape,
        'pad_shape': pad_shape,
        'filename': img_meta['filename'],
        # repeating (w_scale, h_scale) gives the 4-tuple
        # (w_scale, h_scale, w_scale, h_scale) that mmseg expects
        'scale_factor': (img_shape[1] / ori_shape[1],
                         img_shape[0] / ori_shape[0]) * 2,
        'flip': False,
    } for _ in range(N)]]

    return img_list, new_img_meta_list


def pytorch2onnx(model,
                 mm_inputs,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export a PyTorch model to ONNX and optionally verify that the two
    models produce the same outputs.

    Args:
        model (nn.Module): PyTorch model to export.
        mm_inputs (dict): Contains the input tensors and img_metas
            information.
        opset_version (int): The ONNX op set version. Default: 11.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (string): The path where the output ONNX model is stored.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs between PyTorch and
            ONNX. Default: False.
        dynamic_export (bool): Whether to export ONNX with dynamic axes.
            Default: False.
    """
    model.cpu().eval()
    test_mode = model.test_cfg.mode

    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]
    # update img_meta
    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

    # replace the original forward function so that only the image tensor
    # remains a traced input
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=True)
    dynamic_axes = None
    if dynamic_export:
        if test_mode == 'slide':
            dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
        else:
            dynamic_axes = {
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'output': {
                    1: 'batch',
                    2: 'height',
                    3: 'width'
                }
            }

    register_extra_symbolics(opset_version)
    with torch.no_grad():
        torch.onnx.export(
            model, (img_list, ),
            output_file,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=show,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes)
        print(f'Successfully exported ONNX model: {output_file}')
    model.forward = origin_forward

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        if dynamic_export and test_mode == 'whole':
            # scale the image for the dynamic shape test
            img_list = [resize(_, scale_factor=1.5) for _ in img_list]
            # concatenate flipped images for the batch test
            flip_img_list = [_.flip(-1) for _ in img_list]
            img_list = [
                torch.cat((ori_img, flip_img), 0)
                for ori_img, flip_img in zip(img_list, flip_img_list)
            ]

            # update img_meta
            img_list, img_meta_list = _update_input_img(
                img_list, img_meta_list, test_mode == 'whole')

        # check the numerical value
        # get pytorch output
        with torch.no_grad():
            pytorch_result = model(img_list, img_meta_list, return_loss=False)
            pytorch_result = np.stack(pytorch_result, 0)

        # get onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert len(net_feed_input) == 1
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
        # show segmentation results
        if show:
            import cv2
            import os.path as osp
            img = img_meta_list[0][0]['filename']
            if not osp.exists(img):
                img = imgs[0][:3, ...].permute(1, 2, 0) * 255
                img = img.detach().numpy().astype(np.uint8)
                ori_shape = img.shape[:2]
            else:
                ori_shape = LoadImage()({'img': img})['ori_shape']

            # resize onnx_result to ori_shape
            onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
                                      (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (onnx_result_, ),
                palette=model.PALETTE,
                block=False,
                title='ONNXRuntime',
                opacity=0.5)

            # resize pytorch_result to ori_shape
            pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
                                         (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (pytorch_result_, ),
                title='PyTorch',
                palette=model.PALETTE,
                opacity=0.5)
        # compare results
        np.testing.assert_allclose(
            pytorch_result.astype(np.float32) / num_classes,
            onnx_result.astype(np.float32) / num_classes,
            rtol=1e-5,
            atol=1e-5,
            err_msg='The outputs are different between PyTorch and ONNX')
        print('The outputs are the same between PyTorch and ONNX')
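
# Once exported, the ONNX model can also be run standalone with ONNX Runtime.
# A minimal sketch ('input' matches input_names above; `img_batch` is a
# hypothetical float32 NCHW numpy array preprocessed the same way as above):
#
#   sess = rt.InferenceSession('tmp.onnx')
#   seg = sess.run(None, {'input': img_batch})[0]
#   # seg holds per-pixel class indices, the same values the verify step
#   # compares against the PyTorch output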


def parse_args():
    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument(
        '--input-img', type=str, help='image used as input', default=None)
    parser.add_argument(
        '--show',
        action='store_true',
        help='show onnx graph and segmentation results')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=None,
        help='input image height and width.')
    parser.add_argument(
        '--rescale_shape',
        type=int,
        nargs='+',
        default=None,
        help='output image rescale height and width, works for slide mode.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether to export ONNX with dynamic axes.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.model.pretrained = None

    if args.shape is None:
        img_scale = cfg.test_pipeline[1]['img_scale']
        input_shape = (1, 3, img_scale[1], img_scale[0])
    elif len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    test_mode = cfg.model.test_cfg.mode

    # build the model and load the checkpoint
    cfg.model.train_cfg = None
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    if args.checkpoint:
        checkpoint = load_checkpoint(
            segmentor, args.checkpoint, map_location='cpu')
        segmentor.CLASSES = checkpoint['meta']['CLASSES']
        segmentor.PALETTE = checkpoint['meta']['PALETTE']

    # read the input image or create a dummy input
    if args.input_img is not None:
        preprocess_shape = (input_shape[2], input_shape[3])
        rescale_shape = None
        if args.rescale_shape is not None:
            rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]]
        mm_inputs = _prepare_input_img(
            args.input_img,
            cfg.data.test.pipeline,
            shape=preprocess_shape,
            rescale_shape=rescale_shape)
    else:
        if isinstance(segmentor.decode_head, nn.ModuleList):
            num_classes = segmentor.decode_head[-1].num_classes
        else:
            num_classes = segmentor.decode_head.num_classes
        mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    # convert the model to an ONNX file
    pytorch2onnx(
        segmentor,
        mm_inputs,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        dynamic_export=args.dynamic_export)
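
# Example invocation (config and checkpoint paths are illustrative):
#
#   python tools/pytorch2onnx.py configs/<model>/<config>.py \
#       --checkpoint <checkpoint>.pth \
#       --output-file model.onnx \
#       --shape 512 512 \
#       --verify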