--- a
+++ b/ViTPose/mmpose/apis/train.py
@@ -0,0 +1,200 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
+                         get_dist_info)
+from mmcv.utils import digit_version
+
+from mmpose.core import DistEvalHook, EvalHook, build_optimizers
+from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper
+from mmpose.datasets import build_dataloader, build_dataset
+from mmpose.utils import get_root_logger
+
+try:
+    from mmcv.runner import Fp16OptimizerHook
+except ImportError:
+    warnings.warn(
+        'Fp16OptimizerHook from mmpose will be deprecated from '
+        'v0.15.0. Please install mmcv>=1.1.4', DeprecationWarning)
+    from mmpose.core import Fp16OptimizerHook
+
+
+def init_random_seed(seed=None, device='cuda'):
+    """Initialize random seed.
+
+    If the seed is not set, the seed will be automatically randomized,
+    and then broadcast to all processes to prevent some potential bugs.
+
+    Args:
+        seed (int, Optional): The seed. Default to None.
+        device (str): The device where the seed will be put on.
+            Default to 'cuda'.
+
+    Returns:
+        int: Seed to be used.
+    """
+    if seed is not None:
+        return seed
+
+    # Make sure all ranks share the same random seed to prevent
+    # some potential bugs. Please refer to
+    # https://github.com/open-mmlab/mmdetection/issues/6339
+    rank, world_size = get_dist_info()
+    seed = np.random.randint(2**31)
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()
+
+
+def train_model(model,
+                dataset,
+                cfg,
+                distributed=False,
+                validate=False,
+                timestamp=None,
+                meta=None):
+    """Train model entry function.
+
+    Args:
+        model (nn.Module): The model to be trained.
+        dataset (Dataset): Train dataset.
+        cfg (dict): The config dict for training.
+        distributed (bool): Whether to use distributed training.
+            Default: False.
+        validate (bool): Whether to do evaluation. Default: False.
+        timestamp (str | None): Local time for runner. Default: None.
+        meta (dict | None): Meta dict to record some important information.
+            Default: None
+    """
+    logger = get_root_logger(cfg.log_level)
+
+    # prepare data loaders
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+    # step 1: give default values and override (if exist) from cfg.data
+    loader_cfg = {
+        **dict(
+            seed=cfg.get('seed'),
+            drop_last=False,
+            dist=distributed,
+            num_gpus=len(cfg.gpu_ids)),
+        **({} if torch.__version__ != 'parrots' else dict(
+               prefetch_num=2,
+               pin_memory=False,
+           )),
+        **dict((k, cfg.data[k]) for k in [
+                   'samples_per_gpu',
+                   'workers_per_gpu',
+                   'shuffle',
+                   'seed',
+                   'drop_last',
+                   'prefetch_num',
+                   'pin_memory',
+                   'persistent_workers',
+               ] if k in cfg.data)
+    }
+
+    # step 2: cfg.data.train_dataloader has highest priority
+    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
+
+    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
+
+    # determine whether use adversarial training precess or not
+    use_adverserial_train = cfg.get('use_adversarial_train', False)
+
+    # put model on gpus
+    if distributed:
+        find_unused_parameters = cfg.get('find_unused_parameters', False)
+        # Sets the `find_unused_parameters` parameter in
+        # torch.nn.parallel.DistributedDataParallel
+
+        if use_adverserial_train:
+            # Use DistributedDataParallelWrapper for adversarial training
+            model = DistributedDataParallelWrapper(
+                model,
+                device_ids=[torch.cuda.current_device()],
+                broadcast_buffers=False,
+                find_unused_parameters=find_unused_parameters)
+        else:
+            model = MMDistributedDataParallel(
+                model.cuda(),
+                device_ids=[torch.cuda.current_device()],
+                broadcast_buffers=False,
+                find_unused_parameters=find_unused_parameters)
+    else:
+        if digit_version(mmcv.__version__) >= digit_version(
+                '1.4.4') or torch.cuda.is_available():
+            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
+        else:
+            warnings.warn(
+                'We recommend to use MMCV >= 1.4.4 for CPU training. '
+                'See https://github.com/open-mmlab/mmpose/pull/1157 for '
+                'details.')
+
+    # build runner
+    optimizer = build_optimizers(model, cfg.optimizer)
+
+    runner = EpochBasedRunner(
+        model,
+        optimizer=optimizer,
+        work_dir=cfg.work_dir,
+        logger=logger,
+        meta=meta)
+    # an ugly workaround to make .log and .log.json filenames the same
+    runner.timestamp = timestamp
+
+    if use_adverserial_train:
+        # The optimizer step process is included in the train_step function
+        # of the model, so the runner should NOT include optimizer hook.
+        optimizer_config = None
+    else:
+        # fp16 setting
+        fp16_cfg = cfg.get('fp16', None)
+        if fp16_cfg is not None:
+            optimizer_config = Fp16OptimizerHook(
+                **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
+        elif distributed and 'type' not in cfg.optimizer_config:
+            optimizer_config = OptimizerHook(**cfg.optimizer_config)
+        else:
+            optimizer_config = cfg.optimizer_config
+
+    # register hooks
+    runner.register_training_hooks(cfg.lr_config, optimizer_config,
+                                   cfg.checkpoint_config, cfg.log_config,
+                                   cfg.get('momentum_config', None))
+    if distributed:
+        runner.register_hook(DistSamplerSeedHook())
+
+    # register eval hooks
+    if validate:
+        eval_cfg = cfg.get('evaluation', {})
+        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+        dataloader_setting = dict(
+            samples_per_gpu=1,
+            workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
+            # cfg.gpus will be ignored if distributed
+            num_gpus=len(cfg.gpu_ids),
+            dist=distributed,
+            drop_last=False,
+            shuffle=False)
+        dataloader_setting = dict(dataloader_setting,
+                                  **cfg.data.get('val_dataloader', {}))
+        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
+        eval_hook = DistEvalHook if distributed else EvalHook
+        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)