Diff of /ViTPose/tools/train.py [000000] .. [c1b1c5]

b/ViTPose/tools/train.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import get_git_hash

from mmpose import __version__
from mmpose.apis import init_random_seed, train_model
from mmpose.datasets import build_dataset
from mmpose.models import build_posenet
from mmpose.utils import collect_env, get_root_logger, setup_multi_processes


def parse_args():
    parser = argparse.ArgumentParser(description='Train a pose model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='(Deprecated, please use --gpu-id) number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--autoscale-lr',
        action='store_true',
        help='automatically scale lr with the number of gpus')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args

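# Illustrative usage (not part of the original file; the config path is a
# placeholder, and only flags defined in parse_args() above are used):
#
#   python tools/train.py <path/to/pose_config>.py \
#       --work-dir work_dirs/example --gpu-id 0 --seed 0 --deterministic
#
# Distributed runs are typically launched through torchrun or
# torch.distributed.launch with `--launcher pytorch`.
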
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set multi-process settings
    setup_multi_processes(cfg)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
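    # e.g. (illustrative, not in the original file): running with a config
    # named `my_config.py` and no --work-dir writes logs and checkpoints to
    # ./work_dirs/my_config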
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpus is not None:
        cfg.gpu_ids = range(1)
        warnings.warn('`--gpus` is deprecated because we only support '
                      'single GPU mode in non-distributed training. '
                      'Use `gpus=1` now.')
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed training. Use the first GPU '
                      'in `gpu_ids` now.')
    if args.gpus is None and args.gpu_ids is None:
        cfg.gpu_ids = [args.gpu_id]

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
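        # e.g. (illustrative) with len(cfg.gpu_ids) == 4, a base lr of 5e-4
        # becomes 5e-4 * 4 / 8 = 2.5e-4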

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        if len(cfg.gpu_ids) > 1:
            warnings.warn(
                f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '
                f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential errors in '
                'non-distributed training.')
            cfg.gpu_ids = cfg.gpu_ids[0:1]
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # reset gpu_ids for distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
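        # note (illustrative, not in the original file): with `--launcher
        # pytorch` the processes are typically spawned by torchrun or
        # torch.distributed.launch, which provide the rank/world-size
        # environment; the --local_rank fallback in parse_args() covers the
        # older launch interface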

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta['seed'] = seed
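    # note (not in the original file): per its mmpose implementation,
    # init_random_seed() draws a seed on rank 0 and broadcasts it, so all
    # processes in a distributed run share the same value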

    model = build_posenet(cfg.model)
    datasets = [build_dataset(cfg.data.train)]

    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
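        # e.g. (illustrative) a config with `workflow = [('train', 1), ('val', 1)]`
        # reaches this branch; the val split is rebuilt with the train pipeline
        # so the runner can iterate it like a training dataset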

    if cfg.checkpoint_config is not None:
        # save mmpose version and config file content in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmpose_version=__version__ + get_git_hash(digits=7),
            config=cfg.pretty_text,
        )
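        # note (not in the original file): mmcv's checkpoint hook writes this
        # dict into checkpoint['meta'], so the mmpose version and full config
        # can be read back from any saved checkpoint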
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()