Diff of /tools/test.py [000000] .. [6d389a]

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.fileio.io import file_handlers
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmcv.runner.fp16_utils import wrap_fp16_model

from mmaction.datasets import build_dataloader, build_dataset
from mmaction.models import build_model
from mmaction.utils import register_module_hooks

# TODO import test functions from mmcv and delete them from mmaction2
try:
    from mmcv.engine import multi_gpu_test, single_gpu_test
except (ImportError, ModuleNotFoundError):
    warnings.warn(
        'DeprecationWarning: single_gpu_test, multi_gpu_test, '
        'collect_results_cpu, collect_results_gpu from mmaction2 will be '
        'deprecated. Please install mmcv through master branch.')
    from mmaction.apis import multi_gpu_test, single_gpu_test


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMAction2 test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--out',
        default=None,
        help='output result file in pkl/yaml/json format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depend on the dataset, e.g.,'
        ' "top_k_accuracy", "mean_class_accuracy" for video dataset')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        default={},
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecated), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        default={},
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--average-clips',
        choices=['score', 'prob', None],
        default=None,
        help='average type when averaging test clips')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--onnx',
        action='store_true',
        help='Whether to test with ONNX model or not')
    parser.add_argument(
        '--tensorrt',
        action='store_true',
        help='Whether to test with TensorRT engine or not')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot both be '
            'specified, --options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args


def turn_off_pretrained(cfg):
    # recursively find all pretrained in the model config,
    # and set them to None to avoid redundant pretrain steps for testing
    if 'pretrained' in cfg:
        cfg.pretrained = None

    # recursively turn off pretrained value
    for sub_cfg in cfg.values():
        if isinstance(sub_cfg, dict):
            turn_off_pretrained(sub_cfg)


def inference_pytorch(args, cfg, distributed, data_loader):
    """Get predictions by PyTorch models."""
    if args.average_clips is not None:
        # You can set average_clips during testing; it will override the
        # original setting
        if cfg.model.get('test_cfg') is None and cfg.get('test_cfg') is None:
            cfg.model.setdefault('test_cfg',
                                 dict(average_clips=args.average_clips))
        else:
            if cfg.model.get('test_cfg') is not None:
                cfg.model.test_cfg.average_clips = args.average_clips
            else:
                cfg.test_cfg.average_clips = args.average_clips

    # remove redundant pretrain steps for testing
    turn_off_pretrained(cfg.model)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))

    if len(cfg.module_hooks) > 0:
        register_module_hooks(model, cfg.module_hooks)

    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    return outputs


def inference_tensorrt(ckpt_path, distributed, data_loader, batch_size):
    """Get predictions by TensorRT engine.

    For now, multi-gpu mode and dynamic tensor shape are not supported.
    """
    assert not distributed, \
        'TensorRT engine inference only supports single gpu mode.'
    import tensorrt as trt
    from mmcv.tensorrt.tensorrt_utils import (torch_dtype_from_trt,
                                              torch_device_from_trt)

    # load engine
    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
        with open(ckpt_path, mode='rb') as f:
            engine_bytes = f.read()
        engine = runtime.deserialize_cuda_engine(engine_bytes)

    # For now, only support fixed input tensor
    cur_batch_size = engine.get_binding_shape(0)[0]
    assert batch_size == cur_batch_size, \
        ('Dataset and TensorRT model should share the same batch size, '
         f'but got {batch_size} and {cur_batch_size}')

    context = engine.create_execution_context()

    # get output tensor
    dtype = torch_dtype_from_trt(engine.get_binding_dtype(1))
    shape = tuple(context.get_binding_shape(1))
    device = torch_device_from_trt(engine.get_location(1))
    output = torch.empty(
        size=shape, dtype=dtype, device=device, requires_grad=False)

    # get predictions
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        bindings = [
            data['imgs'].contiguous().data_ptr(),
            output.contiguous().data_ptr()
        ]
        context.execute_async_v2(bindings,
                                 torch.cuda.current_stream().cuda_stream)
        results.extend(output.cpu().numpy())
        batch_size = len(next(iter(data.values())))
        for _ in range(batch_size):
            prog_bar.update()
    return results


def inference_onnx(ckpt_path, distributed, data_loader, batch_size):
    """Get predictions by ONNX.

    For now, multi-gpu mode and dynamic tensor shape are not supported.
    """
    assert not distributed, 'ONNX inference only supports single gpu mode.'

    import onnx
    import onnxruntime as rt

    # get input tensor name
    onnx_model = onnx.load(ckpt_path)
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [node.name for node in onnx_model.graph.initializer]
    net_feed_input = list(set(input_all) - set(input_initializer))
    assert len(net_feed_input) == 1

    # For now, only support fixed tensor shape
    input_tensor = None
    for tensor in onnx_model.graph.input:
        if tensor.name == net_feed_input[0]:
            input_tensor = tensor
            break
    cur_batch_size = input_tensor.type.tensor_type.shape.dim[0].dim_value
    assert batch_size == cur_batch_size, \
        ('Dataset and ONNX model should share the same batch size, '
         f'but got {batch_size} and {cur_batch_size}')

    # get predictions
    sess = rt.InferenceSession(ckpt_path)
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        imgs = data['imgs'].cpu().numpy()
        onnx_result = sess.run(None, {net_feed_input[0]: imgs})[0]
        results.extend(onnx_result)
        batch_size = len(next(iter(data.values())))
        for _ in range(batch_size):
            prog_bar.update()
    return results


def main():
    args = parse_args()

    if args.tensorrt and args.onnx:
        raise ValueError(
            'Cannot set ONNX mode and TensorRT mode at the same time.')

    cfg = Config.fromfile(args.config)

    cfg.merge_from_dict(args.cfg_options)

    # Load output_config from cfg
    output_config = cfg.get('output_config', {})
    if args.out:
        # Overwrite output_config from args.out
        output_config = Config._merge_a_into_b(
            dict(out=args.out), output_config)

    # Load eval_config from cfg
    eval_config = cfg.get('eval_config', {})
    if args.eval:
        # Overwrite eval_config from args.eval
        eval_config = Config._merge_a_into_b(
            dict(metrics=args.eval), eval_config)
    if args.eval_options:
        # Add options from args.eval_options
        eval_config = Config._merge_a_into_b(args.eval_options, eval_config)

    assert output_config or eval_config, \
        ('Please specify at least one operation (save or eval the '
         'results) with the argument "--out" or "--eval"')

    dataset_type = cfg.data.test.type
    if output_config.get('out', None):
        if 'output_format' in output_config:
            # ugly workaround to make recognition and localization the same
            warnings.warn(
                'Skip checking `output_format` in localization task.')
        else:
            out = output_config['out']
            # make sure the dirname of the output path exists
            mmcv.mkdir_or_exist(osp.dirname(out))
            _, suffix = osp.splitext(out)
            if dataset_type == 'AVADataset':
                assert suffix[1:] == 'csv', ('For AVADataset, the format of '
                                             'the output file should be csv')
            else:
                assert suffix[1:] in file_handlers, (
                    'The format of the output '
                    'file should be json, pickle or yaml')

    # set cudnn benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # The flag is used to register module's hooks
    cfg.setdefault('module_hooks', [])

    # build the dataloader
    dataset = build_dataset(cfg.data.test, dict(test_mode=True))
    dataloader_setting = dict(
        videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
        workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
        dist=distributed,
        shuffle=False)
    dataloader_setting = dict(dataloader_setting,
                              **cfg.data.get('test_dataloader', {}))
    data_loader = build_dataloader(dataset, **dataloader_setting)

    if args.tensorrt:
        outputs = inference_tensorrt(args.checkpoint, distributed, data_loader,
                                     dataloader_setting['videos_per_gpu'])
    elif args.onnx:
        outputs = inference_onnx(args.checkpoint, distributed, data_loader,
                                 dataloader_setting['videos_per_gpu'])
    else:
        outputs = inference_pytorch(args, cfg, distributed, data_loader)

    rank, _ = get_dist_info()
    if rank == 0:
        if output_config.get('out', None):
            out = output_config['out']
            print(f'\nwriting results to {out}')
            dataset.dump_results(outputs, **output_config)
        if eval_config:
            eval_res = dataset.evaluate(outputs, **eval_config)
            for name, val in eval_res.items():
                print(f'{name}: {val:.04f}')


if __name__ == '__main__':
    main()
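
For reference, a typical single-GPU invocation of this script might look like the following, matching the arguments defined in parse_args() above; the config and checkpoint paths are illustrative placeholders, not files referenced by this diff:

    python tools/test.py \
        configs/recognition/tsn/tsn_r50_1x1x3_100e_kinetics400_rgb.py \
        checkpoints/tsn_r50.pth \
        --eval top_k_accuracy mean_class_accuracy \
        --out results.pkl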