Switch to side-by-side view

--- a
+++ b/.dev/benchmark_inference.py
@@ -0,0 +1,149 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import hashlib
+import logging
+import os
+import os.path as osp
+import warnings
+from argparse import ArgumentParser
+
+import requests
+from mmcv import Config
+
+from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
+from mmseg.utils import get_root_logger
+
+# ignore warnings when segmentors inference
+warnings.filterwarnings('ignore')
+
+
+def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
+    """Download checkpoint and check if hash code is true."""
+    url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}'  # noqa
+
+    r = requests.get(url)
+    assert r.status_code != 403, f'{url} Access denied.'
+
+    with open(osp.join(collect_dir, checkpoint_name), 'wb') as code:
+        code.write(r.content)
+
+    true_hash_code = osp.splitext(checkpoint_name)[0].split('-')[1]
+
+    # check hash code
+    with open(osp.join(collect_dir, checkpoint_name), 'rb') as fp:
+        sha256_cal = hashlib.sha256()
+        sha256_cal.update(fp.read())
+        cur_hash_code = sha256_cal.hexdigest()[:8]
+
+    assert true_hash_code == cur_hash_code, f'{url} download failed, '
+    'incomplete downloaded file or url invalid.'
+
+    if cur_hash_code != true_hash_code:
+        os.remove(osp.join(collect_dir, checkpoint_name))
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('checkpoint_root', help='Checkpoint file root path')
+    parser.add_argument(
+        '-i', '--img', default='demo/demo.png', help='Image file')
+    parser.add_argument('-a', '--aug', action='store_true', help='aug test')
+    parser.add_argument('-m', '--model-name', help='model name to inference')
+    parser.add_argument(
+        '-s', '--show', action='store_true', help='show results')
+    parser.add_argument(
+        '-d', '--device', default='cuda:0', help='Device used for inference')
+    args = parser.parse_args()
+    return args
+
+
+def inference_model(config_name, checkpoint, args, logger=None):
+    cfg = Config.fromfile(config_name)
+    if args.aug:
+        if 'flip' in cfg.data.test.pipeline[
+                1] and 'img_scale' in cfg.data.test.pipeline[1]:
+            cfg.data.test.pipeline[1].img_ratios = [
+                0.5, 0.75, 1.0, 1.25, 1.5, 1.75
+            ]
+            cfg.data.test.pipeline[1].flip = True
+        else:
+            if logger is not None:
+                logger.error(f'{config_name}: unable to start aug test')
+            else:
+                print(f'{config_name}: unable to start aug test', flush=True)
+
+    model = init_segmentor(cfg, checkpoint, device=args.device)
+    # test a single image
+    result = inference_segmentor(model, args.img)
+
+    # show the results
+    if args.show:
+        show_result_pyplot(model, args.img, result)
+    return result
+
+
+# Sample test whether the inference code is correct
+def main(args):
+    config = Config.fromfile(args.config)
+
+    if not os.path.exists(args.checkpoint_root):
+        os.makedirs(args.checkpoint_root, 0o775)
+
+    # test single model
+    if args.model_name:
+        if args.model_name in config:
+            model_infos = config[args.model_name]
+            if not isinstance(model_infos, list):
+                model_infos = [model_infos]
+            for model_info in model_infos:
+                config_name = model_info['config'].strip()
+                print(f'processing: {config_name}', flush=True)
+                checkpoint = osp.join(args.checkpoint_root,
+                                      model_info['checkpoint'].strip())
+                try:
+                    # build the model from a config file and a checkpoint file
+                    inference_model(config_name, checkpoint, args)
+                except Exception:
+                    print(f'{config_name} test failed!')
+                    continue
+                return
+        else:
+            raise RuntimeError('model name input error.')
+
+    # test all model
+    logger = get_root_logger(
+        log_file='benchmark_inference_image.log', log_level=logging.ERROR)
+
+    for model_name in config:
+        model_infos = config[model_name]
+
+        if not isinstance(model_infos, list):
+            model_infos = [model_infos]
+        for model_info in model_infos:
+            print('processing: ', model_info['config'], flush=True)
+            config_path = model_info['config'].strip()
+            config_name = osp.splitext(osp.basename(config_path))[0]
+            checkpoint_name = model_info['checkpoint'].strip()
+            checkpoint = osp.join(args.checkpoint_root, checkpoint_name)
+
+            # ensure checkpoint exists
+            try:
+                if not osp.exists(checkpoint):
+                    download_checkpoint(checkpoint_name, model_name,
+                                        config_name.rstrip('.py'),
+                                        args.checkpoint_root)
+            except Exception:
+                logger.error(f'{checkpoint_name} download error')
+                continue
+
+            # test model inference with checkpoint
+            try:
+                # build the model from a config file and a checkpoint file
+                inference_model(config_path, checkpoint, args, logger)
+            except Exception as e:
+                logger.error(f'{config_path} " : {repr(e)}')
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)