ViTPose/tools/analysis/benchmark_inference.py
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
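"""Benchmark the inference speed of a pose model.

The model is built from the given test config and run over the validation
set; no checkpoint needs to be loaded, since timing is independent of the
weight values.
"""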
import argparse
import time

import torch
from mmcv import Config
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel
from mmcv.runner.fp16_utils import wrap_fp16_model

from mmpose.datasets import build_dataloader, build_dataset
from mmpose.models import build_posenet


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMPose benchmark a pose model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        '--log-interval', type=int, default=10, help='interval of logging')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn; this will slightly increase '
        'the inference speed')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark so cuDNN autotunes kernels for fixed-size inputs
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # build the dataloader
    dataset = build_dataset(cfg.data.val)
    # samples_per_gpu=1 times a single image per forward pass, i.e.
    # per-image latency rather than batched throughput
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False)

    # build the model; no checkpoint is loaded, since timing does not
    # depend on the weight values
    model = build_posenet(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    model = MMDataParallel(model, device_ids=[0])
    # put the model in eval mode so dropout/BN behave as at inference time
    model.eval()

    # the first several iterations may be very slow, so skip them
    num_warmup = 5
    pure_inf_time = 0

    # benchmark over the whole dataset and report the running average
    for i, data in enumerate(data_loader):

        # synchronize before starting and stopping the timer so that
        # elapsed time covers completed GPU work, not just kernel launches
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        with torch.no_grad():
            model(return_loss=False, **data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % args.log_interval == 0:
                its = (i + 1 - num_warmup) / pure_inf_time
                print(f'Done item [{i + 1:<3}],  {its:.2f} items / s')

    # recompute the overall average after the loop: the `its` from inside
    # the loop is stale (or undefined) when the final iteration does not
    # land on a log interval
    its = (i + 1 - num_warmup) / pure_inf_time
    print(f'Overall average: {its:.2f} items / s')
    print(f'Total time: {pure_inf_time:.2f} s')


if __name__ == '__main__':
    main()
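
# Example usage (the config path is illustrative; any MMPose-style test
# config in this repository should work):
#
#   python tools/analysis/benchmark_inference.py \
#       configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py \
#       --fuse-conv-bn --log-interval 20
#
# The script then prints lines of the form (numbers are placeholders):
#
#   Done item [20 ],  12.34 items / s
#   Overall average: 12.34 items / s
#   Total time: 1.22 s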