[36ab12]: / ViTPose / tests / test_datasets / test_dataset_wrapper.py

Download this file

68 lines (57 with data), 2.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) OpenMMLab. All rights reserved.
import copy

from mmcv import Config

from mmpose.datasets.builder import build_dataset
def test_concat_dataset():
    """Check that ``build_dataset`` yields a concatenated dataset three ways.

    A single ``TopDownCocoDataset`` is built from the COCO test fixture and
    then concatenated with itself:

    1. via an explicit ``type='ConcatDataset'`` config,
    2. by passing a list of dataset configs to ``build_dataset``,
    3. by passing one config whose ``ann_file``/``type``/``img_prefix``/
       ``dataset_info`` and ``data_cfg`` entries (``num_joints``,
       ``dataset_channel``) are two-element lists.

    Each variant must have exactly twice as many samples as the single
    dataset. File paths are relative, so this presumably expects to be run
    from the repository root.
    """
    # build COCO-like dataset config
    dataset_info = Config.fromfile(
        'configs/_base_/datasets/coco.py').dataset_info
    channel_cfg = dict(
        num_output_channels=17,
        dataset_joints=17,
        dataset_channel=[
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
        ],
        inference_channel=[
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
        ])
    data_cfg = dict(
        image_size=[192, 256],
        heatmap_size=[48, 64],
        num_output_channels=channel_cfg['num_output_channels'],
        num_joints=channel_cfg['dataset_joints'],
        dataset_channel=channel_cfg['dataset_channel'],
        inference_channel=channel_cfg['inference_channel'],
        soft_nms=False,
        nms_thr=1.0,
        oks_thr=0.9,
        vis_thr=0.2,
        use_gt_bbox=True,
        det_bbox_thr=0.0,
        bbox_file='tests/data/coco/test_coco_det_AP_H_56.json',
    )
    dataset_cfg = dict(
        type='TopDownCocoDataset',
        ann_file='tests/data/coco/test_coco.json',
        img_prefix='tests/data/coco/',
        data_cfg=data_cfg,
        pipeline=[],
        dataset_info=dataset_info)
    dataset = build_dataset(dataset_cfg)
    # Guard against a vacuous pass: with an empty dataset, 2 * 0 == 0 would
    # satisfy every length check below.
    assert len(dataset) > 0

    # Case 1: build ConcatDataset explicitly
    concat_dataset_cfg = dict(
        type='ConcatDataset', datasets=[dataset_cfg, dataset_cfg])
    concat_dataset = build_dataset(concat_dataset_cfg)
    assert len(concat_dataset) == 2 * len(dataset)

    # Case 2: build ConcatDataset from cfg sequence
    concat_dataset = build_dataset([dataset_cfg, dataset_cfg])
    assert len(concat_dataset) == 2 * len(dataset)

    # Case 3: build ConcatDataset from ann_file sequence.
    # Deep-copy: a shallow ``dict.copy()`` would alias the nested
    # ``data_cfg`` dict, so list-ifying its entries below would also rewrite
    # ``data_cfg`` / ``dataset_cfg`` used by the single-dataset cases.
    concat_dataset_cfg = copy.deepcopy(dataset_cfg)
    for key in ['ann_file', 'type', 'img_prefix', 'dataset_info']:
        concat_dataset_cfg[key] = [concat_dataset_cfg[key]] * 2
    for key in ['num_joints', 'dataset_channel']:
        nested_cfg = concat_dataset_cfg['data_cfg']
        nested_cfg[key] = [nested_cfg[key]] * 2
    concat_dataset = build_dataset(concat_dataset_cfg)
    assert len(concat_dataset) == 2 * len(dataset)

    # The single-dataset config must be left untouched by Case 3.
    assert data_cfg['num_joints'] == channel_cfg['dataset_joints']
    assert data_cfg['dataset_channel'] == channel_cfg['dataset_channel']