Switch to unified view

a b/tools/convert_datasets/pascal_context.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import argparse
3
import os.path as osp
4
from functools import partial
5
6
import mmcv
7
import numpy as np
8
from detail import Detail
9
from PIL import Image
10
11
_mapping = np.sort(
12
    np.array([
13
        0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
14
        158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
15
        440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
16
        85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115
17
    ]))
18
_key = np.array(range(len(_mapping))).astype('uint8')
19
20
21
def generate_labels(img_id, detail, out_dir):
22
23
    def _class_to_index(mask, _mapping, _key):
24
        # assert the values
25
        values = np.unique(mask)
26
        for i in range(len(values)):
27
            assert (values[i] in _mapping)
28
        index = np.digitize(mask.ravel(), _mapping, right=True)
29
        return _key[index].reshape(mask.shape)
30
31
    mask = Image.fromarray(
32
        _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key))
33
    filename = img_id['file_name']
34
    mask.save(osp.join(out_dir, filename.replace('jpg', 'png')))
35
    return osp.splitext(osp.basename(filename))[0]
36
37
38
def parse_args():
39
    parser = argparse.ArgumentParser(
40
        description='Convert PASCAL VOC annotations to mmsegmentation format')
41
    parser.add_argument('devkit_path', help='pascal voc devkit path')
42
    parser.add_argument('json_path', help='annoation json filepath')
43
    parser.add_argument('-o', '--out_dir', help='output path')
44
    args = parser.parse_args()
45
    return args
46
47
48
def main():
49
    args = parse_args()
50
    devkit_path = args.devkit_path
51
    if args.out_dir is None:
52
        out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext')
53
    else:
54
        out_dir = args.out_dir
55
    json_path = args.json_path
56
    mmcv.mkdir_or_exist(out_dir)
57
    img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages')
58
59
    train_detail = Detail(json_path, img_dir, 'train')
60
    train_ids = train_detail.getImgs()
61
62
    val_detail = Detail(json_path, img_dir, 'val')
63
    val_ids = val_detail.getImgs()
64
65
    mmcv.mkdir_or_exist(
66
        osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext'))
67
68
    train_list = mmcv.track_progress(
69
        partial(generate_labels, detail=train_detail, out_dir=out_dir),
70
        train_ids)
71
    with open(
72
            osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
73
                     'train.txt'), 'w') as f:
74
        f.writelines(line + '\n' for line in sorted(train_list))
75
76
    val_list = mmcv.track_progress(
77
        partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids)
78
    with open(
79
            osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
80
                     'val.txt'), 'w') as f:
81
        f.writelines(line + '\n' for line in sorted(val_list))
82
83
    print('Done!')
84
85
86
if __name__ == '__main__':
87
    main()