--- a +++ b/.dev/benchmark_inference.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import hashlib +import logging +import os +import os.path as osp +import warnings +from argparse import ArgumentParser + +import requests +from mmcv import Config + +from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot +from mmseg.utils import get_root_logger + +# ignore warnings when segmentors inference +warnings.filterwarnings('ignore') + + +def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir): + """Download checkpoint and check if hash code is true.""" + url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}' # noqa + + r = requests.get(url) + assert r.status_code != 403, f'{url} Access denied.' + + with open(osp.join(collect_dir, checkpoint_name), 'wb') as code: + code.write(r.content) + + true_hash_code = osp.splitext(checkpoint_name)[0].split('-')[1] + + # check hash code + with open(osp.join(collect_dir, checkpoint_name), 'rb') as fp: + sha256_cal = hashlib.sha256() + sha256_cal.update(fp.read()) + cur_hash_code = sha256_cal.hexdigest()[:8] + + assert true_hash_code == cur_hash_code, f'{url} download failed, ' + 'incomplete downloaded file or url invalid.' + + if cur_hash_code != true_hash_code: + os.remove(osp.join(collect_dir, checkpoint_name)) + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint_root', help='Checkpoint file root path') + parser.add_argument( + '-i', '--img', default='demo/demo.png', help='Image file') + parser.add_argument('-a', '--aug', action='store_true', help='aug test') + parser.add_argument('-m', '--model-name', help='model name to inference') + parser.add_argument( + '-s', '--show', action='store_true', help='show results') + parser.add_argument( + '-d', '--device', default='cuda:0', help='Device used for inference') + args = parser.parse_args() + return args + + +def inference_model(config_name, checkpoint, args, logger=None): + cfg = Config.fromfile(config_name) + if args.aug: + if 'flip' in cfg.data.test.pipeline[ + 1] and 'img_scale' in cfg.data.test.pipeline[1]: + cfg.data.test.pipeline[1].img_ratios = [ + 0.5, 0.75, 1.0, 1.25, 1.5, 1.75 + ] + cfg.data.test.pipeline[1].flip = True + else: + if logger is not None: + logger.error(f'{config_name}: unable to start aug test') + else: + print(f'{config_name}: unable to start aug test', flush=True) + + model = init_segmentor(cfg, checkpoint, device=args.device) + # test a single image + result = inference_segmentor(model, args.img) + + # show the results + if args.show: + show_result_pyplot(model, args.img, result) + return result + + +# Sample test whether the inference code is correct +def main(args): + config = Config.fromfile(args.config) + + if not os.path.exists(args.checkpoint_root): + os.makedirs(args.checkpoint_root, 0o775) + + # test single model + if args.model_name: + if args.model_name in config: + model_infos = config[args.model_name] + if not isinstance(model_infos, list): + model_infos = [model_infos] + for model_info in model_infos: + config_name = model_info['config'].strip() + print(f'processing: {config_name}', flush=True) + checkpoint = osp.join(args.checkpoint_root, + model_info['checkpoint'].strip()) + try: + # build the model from a config file and a checkpoint file + inference_model(config_name, checkpoint, args) + except Exception: + print(f'{config_name} test failed!') + continue + return + else: + raise RuntimeError('model name input error.') + + # test all model + logger = get_root_logger( + log_file='benchmark_inference_image.log', log_level=logging.ERROR) + + for model_name in config: + model_infos = config[model_name] + + if not isinstance(model_infos, list): + model_infos = [model_infos] + for model_info in model_infos: + print('processing: ', model_info['config'], flush=True) + config_path = model_info['config'].strip() + config_name = osp.splitext(osp.basename(config_path))[0] + checkpoint_name = model_info['checkpoint'].strip() + checkpoint = osp.join(args.checkpoint_root, checkpoint_name) + + # ensure checkpoint exists + try: + if not osp.exists(checkpoint): + download_checkpoint(checkpoint_name, model_name, + config_name.rstrip('.py'), + args.checkpoint_root) + except Exception: + logger.error(f'{checkpoint_name} download error') + continue + + # test model inference with checkpoint + try: + # build the model from a config file and a checkpoint file + inference_model(config_path, checkpoint, args, logger) + except Exception as e: + logger.error(f'{config_path} " : {repr(e)}') + + +if __name__ == '__main__': + args = parse_args() + main(args)