Switch to unified view

a b/.dev/generate_benchmark_evaluation_script.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import argparse
3
import os.path as osp
4
5
from mmcv import Config
6
7
8
def parse_args():
9
    parser = argparse.ArgumentParser(
10
        description='Convert benchmark test model list to script')
11
    parser.add_argument('config', help='test config file path')
12
    parser.add_argument('--port', type=int, default=28171, help='dist port')
13
    parser.add_argument(
14
        '--work-dir',
15
        default='work_dirs/benchmark_evaluation',
16
        help='the dir to save metric')
17
    parser.add_argument(
18
        '--out',
19
        type=str,
20
        default='.dev/benchmark_evaluation.sh',
21
        help='path to save model benchmark script')
22
23
    args = parser.parse_args()
24
    return args
25
26
27
def process_model_info(model_info, work_dir):
28
    config = model_info['config'].strip()
29
    fname, _ = osp.splitext(osp.basename(config))
30
    job_name = fname
31
    checkpoint = model_info['checkpoint'].strip()
32
    work_dir = osp.join(work_dir, fname)
33
    if not isinstance(model_info['eval'], list):
34
        evals = [model_info['eval']]
35
    else:
36
        evals = model_info['eval']
37
    eval = ' '.join(evals)
38
    return dict(
39
        config=config,
40
        job_name=job_name,
41
        checkpoint=checkpoint,
42
        work_dir=work_dir,
43
        eval=eval)
44
45
46
def create_test_bash_info(commands, model_test_dict, port, script_name,
47
                          partition):
48
    config = model_test_dict['config']
49
    job_name = model_test_dict['job_name']
50
    checkpoint = model_test_dict['checkpoint']
51
    work_dir = model_test_dict['work_dir']
52
    eval = model_test_dict['eval']
53
54
    echo_info = f'\necho \'{config}\' &'
55
    commands.append(echo_info)
56
    commands.append('\n')
57
58
    command_info = f'GPUS=4  GPUS_PER_NODE=4  ' \
59
                   f'CPUS_PER_TASK=2 {script_name} '
60
61
    command_info += f'{partition} '
62
    command_info += f'{job_name} '
63
    command_info += f'{config} '
64
    command_info += f'$CHECKPOINT_DIR/{checkpoint} '
65
66
    command_info += f'--eval {eval} '
67
    command_info += f'--work-dir {work_dir} '
68
    command_info += f'--cfg-options dist_params.port={port} '
69
    command_info += '&'
70
71
    commands.append(command_info)
72
73
74
def main():
75
    args = parse_args()
76
    if args.out:
77
        out_suffix = args.out.split('.')[-1]
78
        assert args.out.endswith('.sh'), \
79
            f'Expected out file path suffix is .sh, but get .{out_suffix}'
80
81
    commands = []
82
    partition_name = 'PARTITION=$1'
83
    commands.append(partition_name)
84
    commands.append('\n')
85
86
    checkpoint_root = 'CHECKPOINT_DIR=$2'
87
    commands.append(checkpoint_root)
88
    commands.append('\n')
89
90
    script_name = osp.join('tools', 'slurm_test.sh')
91
    port = args.port
92
    work_dir = args.work_dir
93
94
    cfg = Config.fromfile(args.config)
95
96
    for model_key in cfg:
97
        model_infos = cfg[model_key]
98
        if not isinstance(model_infos, list):
99
            model_infos = [model_infos]
100
        for model_info in model_infos:
101
            print('processing: ', model_info['config'])
102
            model_test_dict = process_model_info(model_info, work_dir)
103
            create_test_bash_info(commands, model_test_dict, port, script_name,
104
                                  '$PARTITION')
105
            port += 1
106
107
    command_str = ''.join(commands)
108
    if args.out:
109
        with open(args.out, 'w') as f:
110
            f.write(command_str + '\n')
111
112
113
if __name__ == '__main__':
114
    main()