a b/.dev/benchmark_inference.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import hashlib
3
import logging
4
import os
5
import os.path as osp
6
import warnings
7
from argparse import ArgumentParser
8
9
import requests
10
from mmcv import Config
11
12
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
13
from mmseg.utils import get_root_logger
14
15
# Ignore all warnings, e.g. those emitted during segmentor inference.
# NOTE: the filter is installed process-wide at import time.
warnings.simplefilter('ignore')
17
18
19
def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
    """Download a checkpoint and check that its hash code is correct.

    The file is fetched from
    ``https://download.openmmlab.com/mmsegmentation/v0.5/``
    ``{model_name}/{config_name}/{checkpoint_name}``.

    The expected hash code is the part of the checkpoint's file stem after
    its last ``-``. For example, ``..._20200604_192608-efe53f0d.pth`` gives
    ``efe53f0d``. It is compared with the first 8 hex digits of the SHA-256
    digest of the downloaded bytes.

    The checkpoint is written into ``collect_dir`` only when the hashes
    match. A corrupt download therefore never lands on disk, where it would
    make later runs skip the download.

    Args:
        checkpoint_name (str): File name of the checkpoint.
        model_name (str): Model directory on the download server.
        config_name (str): Config directory (config file stem) on the server.
        collect_dir (str): Existing local directory to save the file into.

    Raises:
        AssertionError: If the server does not answer with HTTP 200, or if
            the downloaded content does not match the expected hash code.
    """
    url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}'  # noqa

    r = requests.get(url)
    # Explicit raises (instead of ``assert``) keep these checks active under
    # ``python -O`` while preserving the original AssertionError type.
    if r.status_code == 403:
        raise AssertionError(f'{url} Access denied.')
    if r.status_code != 200:
        raise AssertionError(
            f'{url} request failed with HTTP status {r.status_code}.')

    # The hash code is the last '-'-separated field of the file stem. The
    # stem itself may contain '-' (e.g. 'r50-d8'), so index from the end.
    true_hash_code = osp.splitext(checkpoint_name)[0].rsplit('-', 1)[-1]
    cur_hash_code = hashlib.sha256(r.content).hexdigest()[:8]
    if cur_hash_code != true_hash_code:
        raise AssertionError(f'{url} download failed, '
                             'incomplete downloaded file or url invalid.')

    with open(osp.join(collect_dir, checkpoint_name), 'wb') as f:
        f.write(r.content)
42
43
44
def parse_args():
    """Parse the command line of the inference benchmark.

    Returns:
        argparse.Namespace: Parsed options, with the attributes ``config``,
        ``checkpoint_root``, ``img``, ``aug``, ``model_name``, ``show`` and
        ``device``.
    """
    parser = ArgumentParser()

    # Positional arguments.
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint_root', help='Checkpoint file root path')

    # Optional switches.
    parser.add_argument(
        '-i', '--img',
        default='demo/demo.png',
        help='Image file')
    parser.add_argument(
        '-a', '--aug',
        action='store_true',
        help='aug test')
    parser.add_argument(
        '-m', '--model-name',
        help='model name to inference')
    parser.add_argument(
        '-s', '--show',
        action='store_true',
        help='show results')
    parser.add_argument(
        '-d', '--device',
        default='cuda:0',
        help='Device used for inference')

    return parser.parse_args()
58
59
60
def inference_model(config_name, checkpoint, args, logger=None):
    """Build a segmentor from a config and checkpoint and run it on one image.

    When ``args.aug`` is set and the second step of ``cfg.data.test.pipeline``
    has both ``flip`` and ``img_scale`` keys, test-time augmentation is
    enabled: multi-scale ratios are set and flipping is turned on. Otherwise
    an error is reported (to ``logger`` when given, else printed to stdout),
    and inference still runs without augmentation.

    Args:
        config_name (str): Path of the config file to load.
        checkpoint (str): Path of the checkpoint passed to ``init_segmentor``.
        args: Parsed CLI options. Reads ``aug``, ``device``, ``img`` and
            ``show``.
        logger (optional): Receives ``error`` messages when aug test cannot
            be enabled.

    Returns:
        The value returned by ``inference_segmentor`` for ``args.img``.
    """
    cfg = Config.fromfile(config_name)

    if args.aug:
        aug_step = cfg.data.test.pipeline[1]
        if 'flip' in aug_step and 'img_scale' in aug_step:
            aug_step.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
            aug_step.flip = True
        else:
            msg = f'{config_name}: unable to start aug test'
            if logger is None:
                print(msg, flush=True)
            else:
                logger.error(msg)

    model = init_segmentor(cfg, checkpoint, device=args.device)
    result = inference_segmentor(model, args.img)

    if args.show:
        show_result_pyplot(model, args.img, result)
    return result
83
84
85
def _test_single_model(config, args):
    """Run inference for every config listed under ``args.model_name``.

    Failures are reported on stdout and do not stop the remaining configs.
    Missing checkpoints are not downloaded here.

    Raises:
        RuntimeError: If ``args.model_name`` is not a key of ``config``.
    """
    if args.model_name not in config:
        raise RuntimeError('model name input error.')

    model_infos = config[args.model_name]
    if not isinstance(model_infos, list):
        model_infos = [model_infos]

    for model_info in model_infos:
        config_name = model_info['config'].strip()
        print(f'processing: {config_name}', flush=True)
        checkpoint = osp.join(args.checkpoint_root,
                              model_info['checkpoint'].strip())
        try:
            # build the model from a config file and a checkpoint file
            inference_model(config_name, checkpoint, args)
        except Exception as e:
            print(f'{config_name} test failed! {e!r}', flush=True)


def _test_all_models(config, args):
    """Download missing checkpoints and run inference for every model.

    Errors are written to ``benchmark_inference_image.log`` via the root
    logger, and processing continues with the next config.
    """
    logger = get_root_logger(
        log_file='benchmark_inference_image.log', log_level=logging.ERROR)

    for model_name in config:
        model_infos = config[model_name]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]

        for model_info in model_infos:
            print('processing: ', model_info['config'], flush=True)
            config_path = model_info['config'].strip()
            # File stem of the config (extension already removed); it is
            # also the checkpoint's directory name on the download server.
            config_name = osp.splitext(osp.basename(config_path))[0]
            checkpoint_name = model_info['checkpoint'].strip()
            checkpoint = osp.join(args.checkpoint_root, checkpoint_name)

            # ensure checkpoint exists
            if not osp.exists(checkpoint):
                try:
                    download_checkpoint(checkpoint_name, model_name,
                                        config_name, args.checkpoint_root)
                except Exception as e:
                    logger.error(f'{checkpoint_name} download error: {e!r}')
                    continue

            # test model inference with checkpoint
            try:
                # build the model from a config file and a checkpoint file
                inference_model(config_path, checkpoint, args, logger)
            except Exception as e:
                logger.error(f'{config_path} : {e!r}')


# Smoke-test that inference works for the models listed in a benchmark config.
def main(args):
    """Run the inference smoke test described by ``args``.

    ``args.config`` maps model names to one or a list of
    ``{'config': ..., 'checkpoint': ...}`` entries. When ``args.model_name``
    is given, only that model's entries are tested and the function returns
    afterwards, whether or not they succeeded. Otherwise every model is
    tested, and missing checkpoints are downloaded into
    ``args.checkpoint_root``.
    """
    config = Config.fromfile(args.config)
    os.makedirs(args.checkpoint_root, 0o775, exist_ok=True)

    if args.model_name:
        _test_single_model(config, args)
        return

    _test_all_models(config, args)
146
147
if __name__ == '__main__':
    # Script entry point: parse the CLI and hand it straight to main().
    main(parse_args())