Diff of /tools/onnx2tensorrt.py [000000] .. [4e96d3]

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
from typing import Iterable, Optional, Union

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import onnxruntime as ort
import torch
from mmcv.ops import get_onnxruntime_op_path
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
                           save_trt_engine)

from mmseg.apis.inference import LoadImage
from mmseg.datasets import DATASETS
from mmseg.datasets.pipelines import Compose


def get_GiB(x: int):
    """Return x GiB expressed in bytes, i.e. x * 2**30."""
    return x * (1 << 30)


def _prepare_input_img(img_path: str,
                       test_pipeline: Iterable[dict],
                       shape: Optional[Iterable] = None,
                       rescale_shape: Optional[Iterable] = None) -> dict:
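    """Load `img_path` through the test pipeline and collect model inputs.

    The first pipeline stage is replaced by ``LoadImage``; if ``shape`` is
    given it overrides ``img_scale`` of the test-time augmentation stage,
    and ``rescale_shape`` overrides ``ori_shape`` in the image metas.
    Returns a dict with keys ``imgs`` and ``img_metas``.
    """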
    # build the data pipeline
    if shape is not None:
        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
    test_pipeline[1]['transforms'][0]['keep_ratio'] = False
    test_pipeline = [LoadImage()] + test_pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img_path)
    data = test_pipeline(data)
    imgs = data['img']
    img_metas = [i.data for i in data['img_metas']]

    if rescale_shape is not None:
        for img_meta in img_metas:
            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )

    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

    return mm_inputs


def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
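    """Rebuild the image meta list so it matches the batched input.

    One meta dict is replicated for each of the ``N`` images in the batch,
    with ``scale_factor`` recomputed from ``img_shape`` and ``ori_shape``
    (the (w, h) pair is repeated to give the usual 4-element factor) and
    ``flip`` forced to ``False``.
    """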
    # update img and its meta list
    N = img_list[0].size(0)
    img_meta = img_meta_list[0][0]
    img_shape = img_meta['img_shape']
    ori_shape = img_meta['ori_shape']
    pad_shape = img_meta['pad_shape']
    new_img_meta_list = [[{
        'img_shape': img_shape,
        'ori_shape': ori_shape,
        'pad_shape': pad_shape,
        'filename': img_meta['filename'],
        'scale_factor':
        (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
        'flip': False,
    } for _ in range(N)]]

    return img_list, new_img_meta_list


def show_result_pyplot(img: Union[str, np.ndarray],
                       result: np.ndarray,
                       palette: Optional[Iterable] = None,
                       fig_size: Iterable[int] = (15, 10),
                       opacity: float = 0.5,
                       title: str = '',
                       block: bool = True):
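    """Overlay a segmentation ``result`` on ``img`` and show it with pyplot.

    Each label in ``result[0]`` is painted with the corresponding RGB color
    from ``palette`` and blended into the image with the given ``opacity``.
    """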
    img = mmcv.imread(img)
    img = img.copy()
    seg = result[0]
    seg = mmcv.imresize(seg, img.shape[:2][::-1])
    palette = np.array(palette)
    assert palette.ndim == 2
    assert palette.shape[1] == 3
    assert 0 < opacity <= 1.0
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # convert to BGR
    color_seg = color_seg[..., ::-1]

    img = img * (1 - opacity) + color_seg * opacity
    img = img.astype(np.uint8)

    plt.figure(figsize=fig_size)
    plt.imshow(mmcv.bgr2rgb(img))
    plt.title(title)
    plt.tight_layout()
    plt.show(block=block)


def onnx2tensorrt(onnx_file: str,
                  trt_file: str,
                  config: dict,
                  input_config: dict,
                  fp16: bool = False,
                  verify: bool = False,
                  show: bool = False,
                  dataset: str = 'CityscapesDataset',
                  workspace_size: int = 1,
                  verbose: bool = False):
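    """Convert an ONNX model to a TensorRT engine and optionally verify it.

    The engine is built with a shape profile spanning ``min_shape`` to
    ``max_shape`` and saved to ``trt_file``. With ``verify=True``, the same
    input is run through ONNXRuntime (CPU) and TensorRT, and the outputs
    are compared with ``np.testing.assert_allclose``.
    """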
    import tensorrt as trt
    min_shape = input_config['min_shape']
    max_shape = input_config['max_shape']
    # create trt engine and wrapper
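    # mmcv's onnx2trt expects, per input name, a [min, opt, max] shape
    # triple for the optimization profile; here the optimal shape is
    # simply set to the minimum shape.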
    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_file,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        inputs = _prepare_input_img(
            input_config['input_path'],
            config.data.test.pipeline,
            shape=min_shape[2:])

        imgs = inputs['imgs']
        img_metas = inputs['img_metas']
        img_list = [img[None, :] for img in imgs]
        img_meta_list = [[img_meta] for img_meta in img_metas]
        # update img_meta
        img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

        if max_shape[0] > 1:
            # concatenate flipped images for batch test
            flip_img_list = [_.flip(-1) for _ in img_list]
            img_list = [
                torch.cat((ori_img, flip_img), 0)
                for ori_img, flip_img in zip(img_list, flip_img_list)
            ]

        # Get results from ONNXRuntime
        ort_custom_op_path = get_onnxruntime_op_path()
        session_options = ort.SessionOptions()
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        sess.set_providers(['CPUExecutionProvider'], [{}])  # use cpu mode
        onnx_output = sess.run(['output'],
                               {'input': img_list[0].detach().numpy()})[0][0]

        # Get results from TensorRT
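        # Note: the wrapper class is spelled `TRTWraper` in the mmcv
        # versions this tool targets; newer mmcv releases also provide the
        # corrected `TRTWrapper` name.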
        trt_model = TRTWraper(trt_file, ['input'], ['output'])
        with torch.no_grad():
            trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
        trt_output = trt_outputs['output'][0].cpu().detach().numpy()

        if show:
            dataset = DATASETS.get(dataset)
            assert dataset is not None
            palette = dataset.PALETTE

            show_result_pyplot(
                input_config['input_path'],
                (onnx_output[0].astype(np.uint8), ),
                palette=palette,
                title='ONNXRuntime',
                block=False)
            show_result_pyplot(
                input_config['input_path'],
                (trt_output[0].astype(np.uint8), ),
                palette=palette,
                title='TensorRT')

        np.testing.assert_allclose(
            onnx_output, trt_output, rtol=1e-03, atol=1e-05)
        print('TensorRT and ONNXRuntime outputs are all close.')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMSegmentation models from ONNX to TensorRT')
    parser.add_argument('config', help='Config file of the model')
    parser.add_argument('model', help='Path to the input ONNX model')
    parser.add_argument(
        '--trt-file', type=str, help='Path to the output TensorRT engine')
    parser.add_argument(
        '--max-shape',
        type=int,
        nargs=4,
        default=[1, 3, 400, 600],
        help='Maximum shape of model input.')
    parser.add_argument(
        '--min-shape',
        type=int,
        nargs=4,
        default=[1, 3, 400, 600],
        help='Minimum shape of model input.')
    parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
    parser.add_argument(
        '--workspace-size',
        type=int,
        default=1,
        help='Max workspace size in GiB')
    parser.add_argument(
        '--input-img', type=str, default='', help='Image for test')
    parser.add_argument(
        '--show', action='store_true', help='Whether to show output results')
    parser.add_argument(
        '--dataset',
        type=str,
        default='CityscapesDataset',
        help='Dataset name')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Verify the outputs of ONNXRuntime and TensorRT')
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether to output verbose logging messages while creating '
        'the TensorRT engine.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':

    assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
    args = parse_args()

    if not args.input_img:
        args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')

    # check arguments
    assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
    assert osp.exists(args.model), \
        'ONNX model {} not found.'.format(args.model)
    assert args.workspace_size >= 0, 'Workspace size should be non-negative.'
    assert DATASETS.get(args.dataset) is not None, \
        'Dataset {} not found.'.format(args.dataset)
    for max_value, min_value in zip(args.max_shape, args.min_shape):
        assert max_value >= min_value, \
            'max_shape should not be less than min_shape'

    input_config = {
        'min_shape': args.min_shape,
        'max_shape': args.max_shape,
        'input_path': args.input_img
    }

    cfg = mmcv.Config.fromfile(args.config)
    onnx2tensorrt(
        args.model,
        args.trt_file,
        cfg,
        input_config,
        fp16=args.fp16,
        verify=args.verify,
        show=args.show,
        dataset=args.dataset,
        workspace_size=args.workspace_size,
        verbose=args.verbose)
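
# Example invocation (a sketch; the config, ONNX file and engine paths are
# illustrative, and the bundled demo image is used when --input-img is
# omitted):
#
#   python tools/onnx2tensorrt.py \
#       configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py \
#       fcn.onnx \
#       --trt-file fcn.trt \
#       --min-shape 1 3 512 1024 \
#       --max-shape 1 3 512 1024 \
#       --verify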