|
a |
|
b/.dev/benchmark_inference.py |
|
|
1 |
# Copyright (c) OpenMMLab. All rights reserved. |
|
|
2 |
import hashlib |
|
|
3 |
import logging |
|
|
4 |
import os |
|
|
5 |
import os.path as osp |
|
|
6 |
import warnings |
|
|
7 |
from argparse import ArgumentParser |
|
|
8 |
|
|
|
9 |
import requests |
|
|
10 |
from mmcv import Config |
|
|
11 |
|
|
|
12 |
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot |
|
|
13 |
from mmseg.utils import get_root_logger |
|
|
14 |
|
|
|
15 |
# Ignore every warning category for the whole process; segmentor inference
# would otherwise print warnings in between the benchmark's own messages.
warnings.simplefilter('ignore')
|
|
17 |
|
|
|
18 |
|
|
|
19 |
def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
    """Download a checkpoint and check that its hash code is correct.

    The file is fetched from the OpenMMLab download server. It is written to
    ``collect_dir/checkpoint_name`` only when the first 8 hex characters of its
    SHA-256 digest match the hash code embedded in the file name (the text
    after the last ``-`` in the file stem). A failed download therefore never
    leaves a corrupt file behind that later runs would mistake for a valid
    checkpoint.

    Args:
        checkpoint_name (str): Checkpoint file name, ending in ``-<hash>.pth``.
        model_name (str): Model directory on the download server.
        config_name (str): Config directory on the download server.
        collect_dir (str): Local directory to save the checkpoint into.

    Raises:
        RuntimeError: If access is denied (HTTP 403) or the hash check fails.
        requests.HTTPError: For any other HTTP error status.
    """
    url = ('https://download.openmmlab.com/mmsegmentation/v0.5/'
           f'{model_name}/{config_name}/{checkpoint_name}')

    response = requests.get(url)
    if response.status_code == 403:
        raise RuntimeError(f'{url} Access denied.')
    # Do not save an error page (404, 5xx, ...) as if it were a checkpoint.
    response.raise_for_status()

    # The stem may contain other '-' characters (e.g. backbone tags such as
    # 'r50-d8'), so the hash is the LAST '-'-separated field, not the second.
    true_hash_code = osp.splitext(checkpoint_name)[0].rsplit('-', 1)[-1]

    # check hash code on the downloaded bytes before touching the disk
    cur_hash_code = hashlib.sha256(response.content).hexdigest()[:8]
    if cur_hash_code != true_hash_code:
        raise RuntimeError(f'{url} download failed, '
                           'incomplete downloaded file or url invalid.')

    with open(osp.join(collect_dir, checkpoint_name), 'wb') as code:
        code.write(response.content)
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def parse_args():
    """Define the command-line interface and return the parsed namespace.

    Returns:
        argparse.Namespace: Has ``config``, ``checkpoint_root``, ``img``,
        ``aug``, ``model_name``, ``show`` and ``device`` attributes.
    """
    cli = ArgumentParser()

    # Positional arguments, in the order they must appear on the command line.
    positionals = (
        ('config', 'test config file path'),
        ('checkpoint_root', 'Checkpoint file root path'),
    )
    for dest, text in positionals:
        cli.add_argument(dest, help=text)

    # Optional arguments.
    cli.add_argument('-i', '--img', help='Image file', default='demo/demo.png')
    cli.add_argument('-a', '--aug', help='aug test', action='store_true')
    cli.add_argument('-m', '--model-name', help='model name to inference')
    cli.add_argument('-s', '--show', help='show results', action='store_true')
    cli.add_argument(
        '-d', '--device', help='Device used for inference', default='cuda:0')

    return cli.parse_args()
|
|
58 |
|
|
|
59 |
|
|
|
60 |
def inference_model(config_name, checkpoint, args, logger=None):
    """Run single-image inference for one config/checkpoint pair.

    Args:
        config_name (str): Path of the model config, loaded via
            ``Config.fromfile``.
        checkpoint (str): Checkpoint path handed to ``init_segmentor``.
        args: Parsed CLI namespace; this reads ``aug``, ``device``, ``img``
            and ``show``.
        logger (optional): Receives an error when aug test cannot be enabled;
            if None, the message is printed instead.

    Returns:
        Whatever ``inference_segmentor`` returns for ``args.img``.
    """
    cfg = Config.fromfile(config_name)

    if args.aug:
        # Pipeline step 1 is only patched for multi-scale/flip testing when it
        # already exposes both 'flip' and 'img_scale' keys.
        aug_step = cfg.data.test.pipeline[1]
        if 'flip' in aug_step and 'img_scale' in aug_step:
            aug_step.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
            aug_step.flip = True
        else:
            message = f'{config_name}: unable to start aug test'
            if logger is None:
                print(message, flush=True)
            else:
                logger.error(message)

    segmentor = init_segmentor(cfg, checkpoint, device=args.device)
    # test a single image
    output = inference_segmentor(segmentor, args.img)

    if args.show:
        show_result_pyplot(segmentor, args.img, output)
    return output
|
|
83 |
|
|
|
84 |
|
|
|
85 |
def _as_list(model_infos):
    """Return ``model_infos`` as a list, wrapping a single entry."""
    if isinstance(model_infos, list):
        return model_infos
    return [model_infos]


def _run_named_model(args, model_infos):
    """Run inference for every entry registered under one model name.

    Failures are printed to stdout, together with their cause, and do not
    stop the remaining entries.
    """
    for model_info in _as_list(model_infos):
        config_name = model_info['config'].strip()
        print(f'processing: {config_name}', flush=True)
        checkpoint = osp.join(args.checkpoint_root,
                              model_info['checkpoint'].strip())
        try:
            # build the model from a config file and a checkpoint file
            inference_model(config_name, checkpoint, args)
        except Exception as e:
            # Keep going, but report why this entry failed instead of
            # silently dropping the exception.
            print(f'{config_name} test failed! {e!r}', flush=True)


def _run_all_models(args, config, logger):
    """Run inference for every entry of every model in ``config``.

    Missing checkpoints are downloaded into ``args.checkpoint_root`` first.
    Download and inference failures are logged through ``logger`` and the
    loop moves on to the next entry.
    """
    for model_name in config:
        for model_info in _as_list(config[model_name]):
            print('processing: ', model_info['config'], flush=True)
            config_path = model_info['config'].strip()
            # Bare config file name (no directory, no extension); this is the
            # config component of the download URL.
            config_name = osp.splitext(osp.basename(config_path))[0]
            checkpoint_name = model_info['checkpoint'].strip()
            checkpoint = osp.join(args.checkpoint_root, checkpoint_name)

            # ensure checkpoint exists
            if not osp.exists(checkpoint):
                try:
                    # config_name already has no extension. The former
                    # ``rstrip('.py')`` stripped any trailing '.', 'p' or 'y'
                    # characters, not the suffix.
                    download_checkpoint(checkpoint_name, model_name,
                                        config_name, args.checkpoint_root)
                except Exception as e:
                    logger.error(f'{checkpoint_name} download error: {e!r}')
                    continue

            # test model inference with checkpoint
            try:
                # build the model from a config file and a checkpoint file
                inference_model(config_path, checkpoint, args, logger)
            except Exception as e:
                logger.error(f'{config_path} : {repr(e)}')


def main(args):
    """Sample test whether the inference code is correct.

    ``args.config`` is a config whose top-level keys are model names. Each
    value is a ``{'config': ..., 'checkpoint': ...}`` entry or a list of them.

    If ``args.model_name`` is given, only that model's entries are run, using
    checkpoints already present under ``args.checkpoint_root``. Otherwise
    every model is run, missing checkpoints are downloaded, and errors go to
    ``benchmark_inference_image.log``.

    Raises:
        RuntimeError: If ``args.model_name`` is not a key of the config.
    """
    config = Config.fromfile(args.config)
    os.makedirs(args.checkpoint_root, 0o775, exist_ok=True)

    # test single model
    if args.model_name:
        if args.model_name not in config:
            raise RuntimeError('model name input error.')
        _run_named_model(args, config[args.model_name])
        return

    # test all model
    logger = get_root_logger(
        log_file='benchmark_inference_image.log', log_level=logging.ERROR)
    _run_all_models(args, config, logger)
|
|
145 |
|
|
|
146 |
|
|
|
147 |
if __name__ == '__main__':
    # Script entry point: parse the command line and run the benchmark.
    main(parse_args())