a b/.dev/generate_benchmark_train_script.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import argparse
3
import os.path as osp
4
5
# Configs that must be trained with 8 GPUs. Every config not listed here is
# launched with the default of 4 GPUs (see ``create_train_bash_info``).
# Entries are compared verbatim against the stripped lines of the benchmark
# list, so they must match those paths exactly.
config_8gpu_list = [
    'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py',  # noqa
    'configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py',
    'configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py',
]
11
12
13
def parse_args(argv=None):
    """Parse the command-line options of the script generator.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed options with ``txt_path`` (config list
        produced by benchmark_filter), ``port`` (first distributed port) and
        ``out`` (path of the generated shell script).
    """
    parser = argparse.ArgumentParser(
        description='Convert benchmark model list (txt) to a train script')
    parser.add_argument(
        'txt_path', type=str, help='txt path output by benchmark_filter')
    parser.add_argument('--port', type=int, default=24727, help='dist port')
    parser.add_argument(
        '--out',
        type=str,
        default='.dev/benchmark_train.sh',
        help='path to save model benchmark script')
    return parser.parse_args(argv)
27
28
29
def create_train_bash_info(commands, config, script_name, partition, port):
    """Append the shell lines that launch one training job to ``commands``.

    For a non-blank ``config`` this appends an ``echo`` of the config path and
    the launch command, each followed by a separate ``'\\n'`` entry. Configs
    listed in ``config_8gpu_list`` are launched with 8 GPUs, all others with
    4. Blank entries append nothing.

    Args:
        commands (list[str]): Accumulator of script fragments; mutated in
            place.
        config (str): Config path as read from the benchmark list. Surrounding
            whitespace (e.g. the trailing newline) is stripped.
        script_name (str): Path of the launcher script to invoke.
        partition (str): Partition argument inserted verbatim into the
            command (e.g. a shell variable reference such as
            ``'$PARTITION'``).
        port (int): Value passed as ``dist_params.port``.
    """
    cfg = config.strip()
    # A blank line in the config list would otherwise produce a launch
    # command with an empty config path and an empty work dir.
    if not cfg:
        return

    # Print the config name (backgrounded like the launch command itself).
    commands.append(f'echo \'{cfg}\' &')
    commands.append('\n')

    _, model_name = osp.split(osp.dirname(cfg))
    config_name, _ = osp.splitext(osp.basename(cfg))

    # Default setting is 4 GPUs; only configs in config_8gpu_list use 8.
    gpus = 8 if cfg in config_8gpu_list else 4
    command_info = (f'GPUS={gpus}  GPUS_PER_NODE={gpus}  '
                    f'CPUS_PER_TASK=2 {script_name} '
                    f'{partition} '
                    f'{config_name} '
                    f'{cfg} '
                    '--cfg-options '
                    'checkpoint_config.max_keep_ckpts=1 '
                    f'dist_params.port={port} '
                    f'--work-dir work_dirs/{model_name}/{config_name} '
                    # Discard the job's stdout and run it in the background.
                    '>/dev/null &')

    commands.append(command_info)
    commands.append('\n')
58
59
60
def main():
    """Generate the benchmark training script from a list of config paths.

    Reads one config path per line from ``args.txt_path`` and builds one
    launch command per non-blank line. Each command gets its own distributed
    port, counting up from ``args.port``. The script is written to
    ``args.out``, or printed to stdout when ``args.out`` is empty. The
    generated script reads the partition from its first positional argument
    (``$1``).

    Raises:
        ValueError: If ``args.out`` is given but does not end with ``.sh``.
    """
    args = parse_args()
    # Use an explicit raise instead of ``assert`` so the check is not
    # stripped when Python runs with ``-O``.
    if args.out and not args.out.endswith('.sh'):
        raise ValueError(
            f'Expected out file path suffix is .sh, but got {args.out!r}')

    root_name = './tools'
    script_name = osp.join(root_name, 'slurm_train.sh')
    port = args.port

    # The generated script takes the partition as its first argument.
    commands = ['PARTITION=$1', '\n', '\n']

    with open(args.txt_path, 'r') as f:
        model_cfgs = f.readlines()

    for cfg in model_cfgs:
        # Skip blank lines so they neither emit a broken command nor
        # consume a port number.
        if not cfg.strip():
            continue
        create_train_bash_info(commands, cfg, script_name, '$PARTITION',
                               port)
        # All jobs are launched in the background, so each one gets a
        # distinct port.
        port += 1

    command_str = ''.join(commands)
    if args.out:
        with open(args.out, 'w') as f:
            f.write(command_str)
    else:
        # Without an output path, emit the script instead of discarding it.
        print(command_str, end='')
88
89
90
if __name__ == '__main__':
    # Generate the script only when this file is executed directly,
    # never as a side effect of importing it.
    main()