a b/tools/convert_datasets/cityscapes.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import argparse
3
import os.path as osp
4
5
import mmcv
6
from cityscapesscripts.preparation.json2labelImg import json2labelImg
7
8
9
def convert_json_to_label(json_file):
    """Rasterize one Cityscapes polygon annotation into a trainId label map.

    The output path is derived from the input path by swapping the
    ``_polygons.json`` suffix for ``_labelTrainIds.png``.

    Args:
        json_file (str): Path to a ``*_polygons.json`` annotation file.
    """
    suffix_in, suffix_out = '_polygons.json', '_labelTrainIds.png'
    target = json_file.replace(suffix_in, suffix_out)
    json2labelImg(json_file, target, 'trainIds')
12
13
14
def parse_args(argv=None):
    """Parse command-line options for the Cityscapes conversion script.

    Args:
        argv (list[str] | None): Argument list to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as
            before; passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace: with ``cityscapes_path``, ``gt_dir``,
        ``out_dir`` and ``nproc`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Convert Cityscapes annotations to TrainIds')
    parser.add_argument('cityscapes_path', help='cityscapes data path')
    parser.add_argument(
        '--gt-dir',
        default='gtFine',
        type=str,
        help='annotation sub-directory under cityscapes_path '
        '(default: gtFine)')
    parser.add_argument(
        '-o',
        '--out-dir',
        help='output directory for the split .txt files '
        '(default: cityscapes_path)')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of processes')
    return parser.parse_args(argv)
24
25
26
def _strip_annotation_suffix(poly):
    """Drop the trailing ``_<annotationType>_polygons.json`` from ``poly``.

    e.g. ``aachen/aachen_000000_000019_gtFine_polygons.json`` becomes
    ``aachen/aachen_000000_000019``. Unlike a hard-coded ``_gtFine``
    replace, this works for whatever annotation type ``--gt-dir`` names.
    Assumes ``poly`` ends with ``_polygons.json`` (guaranteed by the
    suffix filter passed to ``mmcv.scandir`` in ``main``).
    """
    stem = poly[:-len('_polygons.json')]
    # The annotation-type token (gtFine / gtCoarse) is the last '_' field.
    return stem.rsplit('_', 1)[0]


def main():
    """Convert all polygon annotations to trainId PNGs and write split lists.

    Label PNGs are written next to their ``*_polygons.json`` sources under
    ``<cityscapes_path>/<gt_dir>``. For each of ``train``/``val``/``test``
    whose directory exists, ``<out_dir>/<split>.txt`` lists the sorted
    image stems (paths relative to the split directory, annotation suffix
    removed), one per line.
    """
    args = parse_args()
    cityscapes_path = args.cityscapes_path
    out_dir = args.out_dir if args.out_dir else cityscapes_path
    mmcv.mkdir_or_exist(out_dir)

    gt_dir = osp.join(cityscapes_path, args.gt_dir)

    poly_files = [
        osp.join(gt_dir, poly)
        for poly in mmcv.scandir(gt_dir, '_polygons.json', recursive=True)
    ]
    if args.nproc > 1:
        # Fan the conversion out over args.nproc workers.
        mmcv.track_parallel_progress(convert_json_to_label, poly_files,
                                     args.nproc)
    else:
        mmcv.track_progress(convert_json_to_label, poly_files)

    split_names = ['train', 'val', 'test']

    for split in split_names:
        split_dir = osp.join(gt_dir, split)
        if not osp.isdir(split_dir):
            # Not every annotation set ships every split; skip instead of
            # failing while scanning a directory that does not exist.
            continue
        # Sort so the list does not depend on filesystem traversal order.
        stems = sorted(
            _strip_annotation_suffix(poly) for poly in mmcv.scandir(
                split_dir, '_polygons.json', recursive=True))
        with open(osp.join(out_dir, f'{split}.txt'), 'w') as fp:
            fp.writelines(stem + '\n' for stem in stems)
53
54
55
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()