tools/analysis/benchmark.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import time

import torch
from mmcv import Config
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel
from mmcv.runner.fp16_utils import wrap_fp16_model

from mmaction.datasets import build_dataloader, build_dataset
from mmaction.models import build_model


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMAction2 benchmark a recognizer')
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        '--log-interval', type=int, default=10, help='interval of logging')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark if requested in the config
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.backbone.pretrained = None
    cfg.data.test.test_mode = True

    # build the dataloader
    dataset = build_dataset(cfg.data.test, dict(test_mode=True))
    data_loader = build_dataloader(
        dataset,
        videos_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        persistent_workers=cfg.data.get('persistent_workers', False),
        dist=False,
        shuffle=False)

    # build the model (trained weights are not needed to benchmark speed)
    model = build_model(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    # benchmark on a single GPU
    model = MMDataParallel(model, device_ids=[0])

    model.eval()

    # the first several iterations may be very slow, so skip them
    num_warmup = 5
    pure_inf_time = 0

    # benchmark with 200 videos and take the average
    for i, data in enumerate(data_loader):

        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(return_loss=False, **data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % args.log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(
                    f'Done video [{i + 1:<3}/ 200], fps: {fps:.1f} video / s')

        if (i + 1) == 200:
            # this iteration's elapsed time was already accumulated above,
            # so it must not be added a second time
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(f'Overall fps: {fps:.1f} video / s')
            break


if __name__ == '__main__':
    main()
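
For reference, a typical invocation might look like the following (the config path is a placeholder; substitute any MMAction2 recognizer config):

    python tools/analysis/benchmark.py path/to/recognizer_config.py --fuse-conv-bn --log-interval 20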