--- a
+++ b/ViTPose/tools/train.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import copy
+import os
+import os.path as osp
+import time
+import warnings
+
+import mmcv
+import torch
+from mmcv import Config, DictAction
+from mmcv.runner import get_dist_info, init_dist, set_random_seed
+from mmcv.utils import get_git_hash
+
+from mmpose import __version__
+from mmpose.apis import init_random_seed, train_model
+from mmpose.datasets import build_dataset
+from mmpose.models import build_posenet
+from mmpose.utils import collect_env, get_root_logger, setup_multi_processes
+
+
+def parse_args():
+    """Parse CLI arguments and export ``LOCAL_RANK`` if it is not yet set."""
+    parser = argparse.ArgumentParser(description='Train a pose model')
+    parser.add_argument('config', help='train config file path')
+    parser.add_argument('--work-dir', help='the dir to save logs and models')
+    parser.add_argument(
+        '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--no-validate',
+        action='store_true',
+        help='whether not to evaluate the checkpoint during training')
+    group_gpus = parser.add_mutually_exclusive_group()
+    group_gpus.add_argument(
+        '--gpus',
+        type=int,
+        help='(Deprecated, please use --gpu-id) number of gpus to use '
+        '(only applicable to non-distributed training)')
+    group_gpus.add_argument(
+        '--gpu-ids',
+        type=int,
+        nargs='+',
+        help='(Deprecated, please use --gpu-id) ids of gpus to use '
+        '(only applicable to non-distributed training)')
+    group_gpus.add_argument(
+        '--gpu-id',
+        type=int,
+        default=0,
+        help='id of gpu to use '
+        '(only applicable to non-distributed training)')
+    parser.add_argument('--seed', type=int, default=None, help='random seed')
+    parser.add_argument(
+        '--deterministic',
+        action='store_true',
+        help='whether to set deterministic options for CUDNN backend.')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        default={},
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. For example, '
+        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    # torch>=2.0 launchers pass --local-rank; older ones pass --local_rank
+    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
+    parser.add_argument(
+        '--autoscale-lr',
+        action='store_true',
+        help='automatically scale lr with the number of gpus')
+    args = parser.parse_args()
+    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))
+    return args
+
+
+def main():
+    """Train a pose model from the config file and CLI options.
+
+    Sets up config, distributed env, logger and seed, then runs training.
+    """
+    args = parse_args()
+    cfg = Config.fromfile(args.config)
+
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    # set multi-process settings
+    setup_multi_processes(cfg)
+
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
+    if args.resume_from is not None:
+        cfg.resume_from = args.resume_from
+    if args.gpus is not None:
+        cfg.gpu_ids = range(1)
+        warnings.warn('`--gpus` is deprecated because we only support '
+                      'single GPU mode in non-distributed training. '
+                      'Use `gpus=1` now.')
+    if args.gpu_ids is not None:
+        cfg.gpu_ids = args.gpu_ids[0:1]
+        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
+                      'Because we only support single GPU mode in '
+                      'non-distributed training. Use the first GPU '
+                      'in `gpu_ids` now.')
+    if args.gpus is None and args.gpu_ids is None:
+        cfg.gpu_ids = [args.gpu_id]
+
+    # init distributed env first, since logger depends on the dist info.
+    if args.launcher == 'none':
+        distributed = False
+        if len(cfg.gpu_ids) > 1:
+            warnings.warn(
+                f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '
+                f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in '
+                'non-distribute training time.')
+            cfg.gpu_ids = cfg.gpu_ids[0:1]
+    else:
+        distributed = True
+        init_dist(args.launcher, **cfg.dist_params)
+        # re-set gpu_ids with distributed training mode
+        _, world_size = get_dist_info()
+        cfg.gpu_ids = range(world_size)
+
+    if args.autoscale_lr:
+        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677);
+        # runs after dist init so len(cfg.gpu_ids) equals the world size
+        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
+
+    # create work_dir
+    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    # init the logger before other steps
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
+    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
+
+    # meta dict records env info and the seed; it is handed to train_model
+    meta = dict()
+    env_info_dict = collect_env()
+    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
+    dash_line = '-' * 60 + '\n'
+    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
+                dash_line)
+    meta['env_info'] = env_info
+
+    # log some basic info
+    logger.info(f'Distributed training: {distributed}')
+    logger.info(f'Config:\n{cfg.pretty_text}')
+
+    # set random seeds
+    seed = init_random_seed(args.seed)
+    logger.info(f'Set random seed to {seed}, '
+                f'deterministic: {args.deterministic}')
+    set_random_seed(seed, deterministic=args.deterministic)
+    cfg.seed = seed
+    meta['seed'] = seed
+
+    model = build_posenet(cfg.model)
+    datasets = [build_dataset(cfg.data.train)]
+
+    if len(cfg.workflow) == 2:
+        val_dataset = copy.deepcopy(cfg.data.val)
+        val_dataset.pipeline = cfg.data.train.pipeline
+        datasets.append(build_dataset(val_dataset))
+
+    if cfg.checkpoint_config is not None:
+        # save mmpose version and config file content in checkpoints as meta
+        cfg.checkpoint_config.meta = dict(
+            mmpose_version=__version__ + get_git_hash(digits=7),
+            config=cfg.pretty_text,
+        )
+    train_model(
+        model,
+        datasets,
+        cfg,
+        distributed=distributed,
+        validate=(not args.no_validate),
+        timestamp=timestamp,
+        meta=meta)
+
+
+if __name__ == '__main__':  # run training only when executed as a script
+    main()