|
a |
|
b/ViTPose/tools/test.py |
|
|
1 |
# Copyright (c) OpenMMLab. All rights reserved. |
|
|
2 |
import argparse |
|
|
3 |
import os |
|
|
4 |
import os.path as osp |
|
|
5 |
import warnings |
|
|
6 |
|
|
|
7 |
import mmcv |
|
|
8 |
import torch |
|
|
9 |
from mmcv import Config, DictAction |
|
|
10 |
from mmcv.cnn import fuse_conv_bn |
|
|
11 |
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel |
|
|
12 |
from mmcv.runner import get_dist_info, init_dist, load_checkpoint |
|
|
13 |
|
|
|
14 |
from mmpose.apis import multi_gpu_test, single_gpu_test |
|
|
15 |
from mmpose.datasets import build_dataloader, build_dataset |
|
|
16 |
from mmpose.models import build_posenet |
|
|
17 |
from mmpose.utils import setup_multi_processes |
|
|
18 |
|
|
|
19 |
try: |
|
|
20 |
from mmcv.runner import wrap_fp16_model |
|
|
21 |
except ImportError: |
|
|
22 |
warnings.warn('auto_fp16 from mmpose will be deprecated from v0.15.0' |
|
|
23 |
'Please install mmcv>=1.1.4') |
|
|
24 |
from mmpose.core import wrap_fp16_model |
|
|
25 |
|
|
|
26 |
|
|
|
27 |
def parse_args(): |
|
|
28 |
parser = argparse.ArgumentParser(description='mmpose test model') |
|
|
29 |
parser.add_argument('config', help='test config file path') |
|
|
30 |
parser.add_argument('checkpoint', help='checkpoint file') |
|
|
31 |
parser.add_argument('--out', help='output result file') |
|
|
32 |
parser.add_argument( |
|
|
33 |
'--work-dir', help='the dir to save evaluation results') |
|
|
34 |
parser.add_argument( |
|
|
35 |
'--fuse-conv-bn', |
|
|
36 |
action='store_true', |
|
|
37 |
help='Whether to fuse conv and bn, this will slightly increase' |
|
|
38 |
'the inference speed') |
|
|
39 |
parser.add_argument( |
|
|
40 |
'--gpu-id', |
|
|
41 |
type=int, |
|
|
42 |
default=0, |
|
|
43 |
help='id of gpu to use ' |
|
|
44 |
'(only applicable to non-distributed testing)') |
|
|
45 |
parser.add_argument( |
|
|
46 |
'--eval', |
|
|
47 |
default=None, |
|
|
48 |
nargs='+', |
|
|
49 |
help='evaluation metric, which depends on the dataset,' |
|
|
50 |
' e.g., "mAP" for MSCOCO') |
|
|
51 |
parser.add_argument( |
|
|
52 |
'--gpu_collect', |
|
|
53 |
action='store_true', |
|
|
54 |
help='whether to use gpu to collect results') |
|
|
55 |
parser.add_argument('--tmpdir', help='tmp dir for writing some results') |
|
|
56 |
parser.add_argument( |
|
|
57 |
'--cfg-options', |
|
|
58 |
nargs='+', |
|
|
59 |
action=DictAction, |
|
|
60 |
default={}, |
|
|
61 |
help='override some settings in the used config, the key-value pair ' |
|
|
62 |
'in xxx=yyy format will be merged into config file. For example, ' |
|
|
63 |
"'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") |
|
|
64 |
parser.add_argument( |
|
|
65 |
'--launcher', |
|
|
66 |
choices=['none', 'pytorch', 'slurm', 'mpi'], |
|
|
67 |
default='none', |
|
|
68 |
help='job launcher') |
|
|
69 |
parser.add_argument('--local_rank', type=int, default=0) |
|
|
70 |
args = parser.parse_args() |
|
|
71 |
if 'LOCAL_RANK' not in os.environ: |
|
|
72 |
os.environ['LOCAL_RANK'] = str(args.local_rank) |
|
|
73 |
return args |
|
|
74 |
|
|
|
75 |
|
|
|
76 |
def merge_configs(cfg1, cfg2): |
|
|
77 |
# Merge cfg2 into cfg1 |
|
|
78 |
# Overwrite cfg1 if repeated, ignore if value is None. |
|
|
79 |
cfg1 = {} if cfg1 is None else cfg1.copy() |
|
|
80 |
cfg2 = {} if cfg2 is None else cfg2 |
|
|
81 |
for k, v in cfg2.items(): |
|
|
82 |
if v: |
|
|
83 |
cfg1[k] = v |
|
|
84 |
return cfg1 |
|
|
85 |
|
|
|
86 |
|
|
|
87 |
def main(): |
|
|
88 |
args = parse_args() |
|
|
89 |
|
|
|
90 |
cfg = Config.fromfile(args.config) |
|
|
91 |
|
|
|
92 |
if args.cfg_options is not None: |
|
|
93 |
cfg.merge_from_dict(args.cfg_options) |
|
|
94 |
|
|
|
95 |
# set multi-process settings |
|
|
96 |
setup_multi_processes(cfg) |
|
|
97 |
|
|
|
98 |
# set cudnn_benchmark |
|
|
99 |
if cfg.get('cudnn_benchmark', False): |
|
|
100 |
torch.backends.cudnn.benchmark = True |
|
|
101 |
cfg.model.pretrained = None |
|
|
102 |
cfg.data.test.test_mode = True |
|
|
103 |
|
|
|
104 |
# work_dir is determined in this priority: CLI > segment in file > filename |
|
|
105 |
if args.work_dir is not None: |
|
|
106 |
# update configs according to CLI args if args.work_dir is not None |
|
|
107 |
cfg.work_dir = args.work_dir |
|
|
108 |
elif cfg.get('work_dir', None) is None: |
|
|
109 |
# use config filename as default work_dir if cfg.work_dir is None |
|
|
110 |
cfg.work_dir = osp.join('./work_dirs', |
|
|
111 |
osp.splitext(osp.basename(args.config))[0]) |
|
|
112 |
|
|
|
113 |
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) |
|
|
114 |
|
|
|
115 |
# init distributed env first, since logger depends on the dist info. |
|
|
116 |
if args.launcher == 'none': |
|
|
117 |
distributed = False |
|
|
118 |
else: |
|
|
119 |
distributed = True |
|
|
120 |
init_dist(args.launcher, **cfg.dist_params) |
|
|
121 |
|
|
|
122 |
# build the dataloader |
|
|
123 |
dataset = build_dataset(cfg.data.test, dict(test_mode=True)) |
|
|
124 |
# step 1: give default values and override (if exist) from cfg.data |
|
|
125 |
loader_cfg = { |
|
|
126 |
**dict(seed=cfg.get('seed'), drop_last=False, dist=distributed), |
|
|
127 |
**({} if torch.__version__ != 'parrots' else dict( |
|
|
128 |
prefetch_num=2, |
|
|
129 |
pin_memory=False, |
|
|
130 |
)), |
|
|
131 |
**dict((k, cfg.data[k]) for k in [ |
|
|
132 |
'seed', |
|
|
133 |
'prefetch_num', |
|
|
134 |
'pin_memory', |
|
|
135 |
'persistent_workers', |
|
|
136 |
] if k in cfg.data) |
|
|
137 |
} |
|
|
138 |
# step2: cfg.data.test_dataloader has higher priority |
|
|
139 |
test_loader_cfg = { |
|
|
140 |
**loader_cfg, |
|
|
141 |
**dict(shuffle=False, drop_last=False), |
|
|
142 |
**dict(workers_per_gpu=cfg.data.get('workers_per_gpu', 1)), |
|
|
143 |
**dict(samples_per_gpu=cfg.data.get('samples_per_gpu', 1)), |
|
|
144 |
**cfg.data.get('test_dataloader', {}) |
|
|
145 |
} |
|
|
146 |
data_loader = build_dataloader(dataset, **test_loader_cfg) |
|
|
147 |
|
|
|
148 |
# build the model and load checkpoint |
|
|
149 |
model = build_posenet(cfg.model) |
|
|
150 |
fp16_cfg = cfg.get('fp16', None) |
|
|
151 |
if fp16_cfg is not None: |
|
|
152 |
wrap_fp16_model(model) |
|
|
153 |
load_checkpoint(model, args.checkpoint, map_location='cpu') |
|
|
154 |
|
|
|
155 |
if args.fuse_conv_bn: |
|
|
156 |
model = fuse_conv_bn(model) |
|
|
157 |
|
|
|
158 |
if not distributed: |
|
|
159 |
model = MMDataParallel(model, device_ids=[args.gpu_id]) |
|
|
160 |
outputs = single_gpu_test(model, data_loader) |
|
|
161 |
else: |
|
|
162 |
model = MMDistributedDataParallel( |
|
|
163 |
model.cuda(), |
|
|
164 |
device_ids=[torch.cuda.current_device()], |
|
|
165 |
broadcast_buffers=False) |
|
|
166 |
outputs = multi_gpu_test(model, data_loader, args.tmpdir, |
|
|
167 |
args.gpu_collect) |
|
|
168 |
|
|
|
169 |
rank, _ = get_dist_info() |
|
|
170 |
eval_config = cfg.get('evaluation', {}) |
|
|
171 |
eval_config = merge_configs(eval_config, dict(metric=args.eval)) |
|
|
172 |
|
|
|
173 |
if rank == 0: |
|
|
174 |
if args.out: |
|
|
175 |
print(f'\nwriting results to {args.out}') |
|
|
176 |
mmcv.dump(outputs, args.out) |
|
|
177 |
|
|
|
178 |
results = dataset.evaluate(outputs, cfg.work_dir, **eval_config) |
|
|
179 |
for k, v in sorted(results.items()): |
|
|
180 |
print(f'{k}: {v}') |
|
|
181 |
|
|
|
182 |
|
|
|
183 |
if __name__ == '__main__': |
|
|
184 |
main() |