|
a |
|
b/tools/convert_datasets/pascal_context.py |
|
|
1 |
# Copyright (c) OpenMMLab. All rights reserved. |
|
|
2 |
import argparse |
|
|
3 |
import os.path as osp |
|
|
4 |
from functools import partial |
|
|
5 |
|
|
|
6 |
import mmcv |
|
|
7 |
import numpy as np |
|
|
8 |
from detail import Detail |
|
|
9 |
from PIL import Image |
|
|
10 |
|
|
|
11 |
_mapping = np.sort( |
|
|
12 |
np.array([ |
|
|
13 |
0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284, |
|
|
14 |
158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59, |
|
|
15 |
440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355, |
|
|
16 |
85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115 |
|
|
17 |
])) |
|
|
18 |
_key = np.array(range(len(_mapping))).astype('uint8') |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
def generate_labels(img_id, detail, out_dir): |
|
|
22 |
|
|
|
23 |
def _class_to_index(mask, _mapping, _key): |
|
|
24 |
# assert the values |
|
|
25 |
values = np.unique(mask) |
|
|
26 |
for i in range(len(values)): |
|
|
27 |
assert (values[i] in _mapping) |
|
|
28 |
index = np.digitize(mask.ravel(), _mapping, right=True) |
|
|
29 |
return _key[index].reshape(mask.shape) |
|
|
30 |
|
|
|
31 |
mask = Image.fromarray( |
|
|
32 |
_class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key)) |
|
|
33 |
filename = img_id['file_name'] |
|
|
34 |
mask.save(osp.join(out_dir, filename.replace('jpg', 'png'))) |
|
|
35 |
return osp.splitext(osp.basename(filename))[0] |
|
|
36 |
|
|
|
37 |
|
|
|
38 |
def parse_args(): |
|
|
39 |
parser = argparse.ArgumentParser( |
|
|
40 |
description='Convert PASCAL VOC annotations to mmsegmentation format') |
|
|
41 |
parser.add_argument('devkit_path', help='pascal voc devkit path') |
|
|
42 |
parser.add_argument('json_path', help='annoation json filepath') |
|
|
43 |
parser.add_argument('-o', '--out_dir', help='output path') |
|
|
44 |
args = parser.parse_args() |
|
|
45 |
return args |
|
|
46 |
|
|
|
47 |
|
|
|
48 |
def main(): |
|
|
49 |
args = parse_args() |
|
|
50 |
devkit_path = args.devkit_path |
|
|
51 |
if args.out_dir is None: |
|
|
52 |
out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext') |
|
|
53 |
else: |
|
|
54 |
out_dir = args.out_dir |
|
|
55 |
json_path = args.json_path |
|
|
56 |
mmcv.mkdir_or_exist(out_dir) |
|
|
57 |
img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages') |
|
|
58 |
|
|
|
59 |
train_detail = Detail(json_path, img_dir, 'train') |
|
|
60 |
train_ids = train_detail.getImgs() |
|
|
61 |
|
|
|
62 |
val_detail = Detail(json_path, img_dir, 'val') |
|
|
63 |
val_ids = val_detail.getImgs() |
|
|
64 |
|
|
|
65 |
mmcv.mkdir_or_exist( |
|
|
66 |
osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext')) |
|
|
67 |
|
|
|
68 |
train_list = mmcv.track_progress( |
|
|
69 |
partial(generate_labels, detail=train_detail, out_dir=out_dir), |
|
|
70 |
train_ids) |
|
|
71 |
with open( |
|
|
72 |
osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', |
|
|
73 |
'train.txt'), 'w') as f: |
|
|
74 |
f.writelines(line + '\n' for line in sorted(train_list)) |
|
|
75 |
|
|
|
76 |
val_list = mmcv.track_progress( |
|
|
77 |
partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids) |
|
|
78 |
with open( |
|
|
79 |
osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', |
|
|
80 |
'val.txt'), 'w') as f: |
|
|
81 |
f.writelines(line + '\n' for line in sorted(val_list)) |
|
|
82 |
|
|
|
83 |
print('Done!') |
|
|
84 |
|
|
|
85 |
|
|
|
86 |
if __name__ == '__main__': |
|
|
87 |
main() |