--- a
+++ b/ViTPose/mmcv_custom/checkpoint.py
@@ -0,0 +1,539 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+from torch.nn import functional as F
+
+import mmcv
+from mmcv.fileio import FileClient
+from mmcv.fileio import load as load_file
+from mmcv.parallel import is_module_wrapper
+from mmcv.utils import mkdir_or_exist
+from mmcv.runner import get_dist_info
+
+from scipy import interpolate
+import numpy as np
+import math
+
+# Name of the environment variable that, when set, gives the mmcv cache
+# directory directly (read by ``_get_mmcv_home``).
+ENV_MMCV_HOME = 'MMCV_HOME'
+# Name of the environment variable giving the cache root under which the
+# ``mmcv`` sub-directory is placed when ``MMCV_HOME`` is not set.
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+# Cache root used when ``XDG_CACHE_HOME`` is not set either; the leading
+# ``~`` is expanded by ``_get_mmcv_home``.
+DEFAULT_CACHE_DIR = '~/.cache'
+
+
+def _get_mmcv_home():
+    mmcv_home = os.path.expanduser(
+        os.getenv(
+            ENV_MMCV_HOME,
+            os.path.join(
+                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
+
+    mkdir_or_exist(mmcv_home)
+    return mmcv_home
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+    """Load state_dict to a module.
+
+    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+    Default value for ``strict`` is set to ``False`` and the message for
+    param mismatch will be shown even if strict is False.
+
+    Args:
+        module (Module): Module that receives the state_dict.
+        state_dict (OrderedDict): Weights.
+        strict (bool): whether to strictly enforce that the keys
+            in :attr:`state_dict` match the keys returned by this module's
+            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+        logger (:obj:`logging.Logger`, optional): Logger to log the error
+            message. If not specified, print function will be used.
+    """
+    unexpected_keys = []
+    all_missing_keys = []
+    err_msg = []
+
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    # use _load_from_state_dict to enable checkpoint version control
+    def load(module, prefix=''):
+        # recursively check parallel module in case that the model has a
+        # complicated structure, e.g., nn.Module(nn.Module(DDP))
+        if is_module_wrapper(module):
+            module = module.module
+        local_metadata = {} if metadata is None else metadata.get(
+            prefix[:-1], {})
+        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+                                     all_missing_keys, unexpected_keys,
+                                     err_msg)
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + '.')
+
+    load(module)
+    load = None  # break load->load reference cycle
+
+    # ignore "num_batches_tracked" of BN layers
+    missing_keys = [
+        key for key in all_missing_keys if 'num_batches_tracked' not in key
+    ]
+
+    if unexpected_keys:
+        err_msg.append('unexpected key in source '
+                       f'state_dict: {", ".join(unexpected_keys)}\n')
+    if missing_keys:
+        err_msg.append(
+            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+    rank, _ = get_dist_info()
+    if len(err_msg) > 0 and rank == 0:
+        err_msg.insert(
+            0, 'The model and loaded state dict do not match exactly\n')
+        err_msg = '\n'.join(err_msg)
+        if strict:
+            raise RuntimeError(err_msg)
+        elif logger is not None:
+            logger.warning(err_msg)
+        else:
+            print(err_msg)
+
+
+def load_url_dist(url, model_dir=None, map_location="cpu"):
+    """In distributed setting, this function only download checkpoint at local
+    rank 0."""
+    rank, world_size = get_dist_info()
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+    if rank == 0:
+        checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)
+    if world_size > 1:
+        torch.distributed.barrier()
+        if rank > 0:
+            checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)
+    return checkpoint
+
+
+def load_pavimodel_dist(model_path, map_location=None):
+    """In distributed setting, this function only download checkpoint at local
+    rank 0."""
+    try:
+        from pavi import modelcloud
+    except ImportError:
+        raise ImportError(
+            'Please install pavi to load checkpoint from modelcloud.')
+    rank, world_size = get_dist_info()
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+    if rank == 0:
+        model = modelcloud.get(model_path)
+        with TemporaryDirectory() as tmp_dir:
+            downloaded_file = osp.join(tmp_dir, model.name)
+            model.download(downloaded_file)
+            checkpoint = torch.load(downloaded_file, map_location=map_location)
+    if world_size > 1:
+        torch.distributed.barrier()
+        if rank > 0:
+            model = modelcloud.get(model_path)
+            with TemporaryDirectory() as tmp_dir:
+                downloaded_file = osp.join(tmp_dir, model.name)
+                model.download(downloaded_file)
+                checkpoint = torch.load(
+                    downloaded_file, map_location=map_location)
+    return checkpoint
+
+
+def load_fileclient_dist(filename, backend, map_location):
+    """In distributed setting, this function only download checkpoint at local
+    rank 0."""
+    rank, world_size = get_dist_info()
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+    allowed_backends = ['ceph']
+    if backend not in allowed_backends:
+        raise ValueError(f'Load from Backend {backend} is not supported.')
+    if rank == 0:
+        fileclient = FileClient(backend=backend)
+        buffer = io.BytesIO(fileclient.get(filename))
+        checkpoint = torch.load(buffer, map_location=map_location)
+    if world_size > 1:
+        torch.distributed.barrier()
+        if rank > 0:
+            fileclient = FileClient(backend=backend)
+            buffer = io.BytesIO(fileclient.get(filename))
+            checkpoint = torch.load(buffer, map_location=map_location)
+    return checkpoint
+
+
+def get_torchvision_models():
+    model_urls = dict()
+    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+        if ispkg:
+            continue
+        _zoo = import_module(f'torchvision.models.{name}')
+        if hasattr(_zoo, 'model_urls'):
+            _urls = getattr(_zoo, 'model_urls')
+            model_urls.update(_urls)
+    return model_urls
+
+
+def get_external_models():
+    mmcv_home = _get_mmcv_home()
+    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
+    default_urls = load_file(default_json_path)
+    assert isinstance(default_urls, dict)
+    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
+    if osp.exists(external_json_path):
+        external_urls = load_file(external_json_path)
+        assert isinstance(external_urls, dict)
+        default_urls.update(external_urls)
+
+    return default_urls
+
+
+def get_mmcls_models():
+    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+    mmcls_urls = load_file(mmcls_json_path)
+
+    return mmcls_urls
+
+
+def get_deprecated_model_names():
+    deprecate_json_path = osp.join(mmcv.__path__[0],
+                                   'model_zoo/deprecated.json')
+    deprecate_urls = load_file(deprecate_json_path)
+    assert isinstance(deprecate_urls, dict)
+
+    return deprecate_urls
+
+
+def _process_mmcls_checkpoint(checkpoint):
+    state_dict = checkpoint['state_dict']
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if k.startswith('backbone.'):
+            new_state_dict[k[9:]] = v
+    new_checkpoint = dict(state_dict=new_state_dict)
+
+    return new_checkpoint
+
+
+def _load_checkpoint(filename, map_location=None):
+    """Load checkpoint from somewhere (modelzoo, file, url).
+
+    Args:
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+    Returns:
+        dict | OrderedDict: The loaded checkpoint. It can be either an
+            OrderedDict storing model models or a dict containing other
+            information, which depends on the checkpoint.
+    """
+    if filename.startswith('modelzoo://'):
+        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+                      'use "torchvision://" instead')
+        model_urls = get_torchvision_models()
+        model_name = filename[11:]
+        checkpoint = load_url_dist(model_urls[model_name])
+    elif filename.startswith('torchvision://'):
+        model_urls = get_torchvision_models()
+        model_name = filename[14:]
+        checkpoint = load_url_dist(model_urls[model_name])
+    elif filename.startswith('open-mmlab://'):
+        model_urls = get_external_models()
+        model_name = filename[13:]
+        deprecated_urls = get_deprecated_model_names()
+        if model_name in deprecated_urls:
+            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
+                          f'of open-mmlab://{deprecated_urls[model_name]}')
+            model_name = deprecated_urls[model_name]
+        model_url = model_urls[model_name]
+        # check if is url
+        if model_url.startswith(('http://', 'https://')):
+            checkpoint = load_url_dist(model_url)
+        else:
+            filename = osp.join(_get_mmcv_home(), model_url)
+            if not osp.isfile(filename):
+                raise IOError(f'{filename} is not a checkpoint file')
+            checkpoint = torch.load(filename, map_location=map_location)
+    elif filename.startswith('mmcls://'):
+        model_urls = get_mmcls_models()
+        model_name = filename[8:]
+        checkpoint = load_url_dist(model_urls[model_name])
+        checkpoint = _process_mmcls_checkpoint(checkpoint)
+    elif filename.startswith(('http://', 'https://')):
+        checkpoint = load_url_dist(filename)
+    elif filename.startswith('pavi://'):
+        model_path = filename[7:]
+        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
+    elif filename.startswith('s3://'):
+        checkpoint = load_fileclient_dist(
+            filename, backend='ceph', map_location=map_location)
+    else:
+        if not osp.isfile(filename):
+            raise IOError(f'{filename} is not a checkpoint file')
+        checkpoint = torch.load(filename, map_location=map_location)
+    return checkpoint
+
+
+def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
+                     start_warmup_value=0, warmup_steps=-1):
+    warmup_schedule = np.array([])
+    warmup_iters = warmup_epochs * niter_per_ep
+    if warmup_steps > 0:
+        warmup_iters = warmup_steps
+    print("Set warmup steps = %d" % warmup_iters)
+    if warmup_epochs > 0:
+        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+    iters = np.arange(epochs * niter_per_ep - warmup_iters)
+    schedule = np.array(
+        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
+
+    schedule = np.concatenate((warmup_schedule, schedule))
+
+    assert len(schedule) == epochs * niter_per_ep
+    return schedule
+
+
+def load_checkpoint(model,
+                    filename,
+                    map_location='cpu',
+                    strict=False,
+                    logger=None,
+                    patch_padding='pad',
+                    ):
+    """Load checkpoint from a file or URI.
+
+    Args:
+        model (Module): Module to load checkpoint.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str): Same as :func:`torch.load`.
+        strict (bool): Whether to allow different params for the model and
+            checkpoint.
+        logger (:mod:`logging.Logger` or None): The logger for error message.
+        patch_padding (str): 'pad' or 'bilinear' or 'bicubic', used for interpolate patch embed from 14x14 to 16x16
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    checkpoint = _load_checkpoint(filename, map_location)
+    # OrderedDict is a subclass of dict
+    if not isinstance(checkpoint, dict):
+        raise RuntimeError(
+            f'No state_dict found in checkpoint file {filename}')
+    # get state_dict from checkpoint
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    elif 'module' in checkpoint:
+        state_dict = checkpoint['module']
+    else:
+        state_dict = checkpoint
+    # strip prefix of state_dict
+    if list(state_dict.keys())[0].startswith('module.'):
+        state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+    # for MoBY, load model of online branch
+    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
+        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
+
+    rank, _ = get_dist_info()
+
+    if 'patch_embed.proj.weight' in state_dict:
+        proj_weight = state_dict['patch_embed.proj.weight']
+        orig_size = proj_weight.shape[2:]
+        current_size = model.patch_embed.proj.weight.shape[2:]
+        padding_size = current_size[0] - orig_size[0]
+        padding_l = padding_size // 2
+        padding_r = padding_size - padding_l
+        if orig_size != current_size:
+            if 'pad' in patch_padding:
+                proj_weight = torch.nn.functional.pad(proj_weight, (padding_l, padding_r, padding_l, padding_r))
+            elif 'bilinear' in patch_padding:
+                proj_weight = torch.nn.functional.interpolate(proj_weight, size=current_size, mode='bilinear', align_corners=False)
+            elif 'bicubic' in patch_padding:
+                proj_weight = torch.nn.functional.interpolate(proj_weight, size=current_size, mode='bicubic', align_corners=False)
+            state_dict['patch_embed.proj.weight'] = proj_weight
+
+    if 'pos_embed' in state_dict:
+        pos_embed_checkpoint = state_dict['pos_embed']
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        H, W = model.patch_embed.patch_shape
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        if rank == 0:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, H, W))
+        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+        # only the position tokens are interpolated
+        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+        pos_tokens = torch.nn.functional.interpolate(
+            pos_tokens, size=(H, W), mode='bicubic', align_corners=False)
+        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+        state_dict['pos_embed'] = new_pos_embed
+
+    # load state_dict
+    load_state_dict(model, state_dict, strict, logger)
+    return checkpoint
+
+
+def weights_to_cpu(state_dict):
+    """Copy a model state_dict to cpu.
+
+    Args:
+        state_dict (OrderedDict): Model models on GPU.
+
+    Returns:
+        OrderedDict: Model models on GPU.
+    """
+    state_dict_cpu = OrderedDict()
+    for key, val in state_dict.items():
+        state_dict_cpu[key] = val.cpu()
+    return state_dict_cpu
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+    """Saves module state to `destination` dictionary.
+
+    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (dict): A dict where state will be stored.
+        prefix (str): The prefix for parameters and buffers used in this
+            module.
+    """
+    for name, param in module._parameters.items():
+        if param is not None:
+            destination[prefix + name] = param if keep_vars else param.detach()
+    for name, buf in module._buffers.items():
+        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+        if buf is not None:
+            destination[prefix + name] = buf if keep_vars else buf.detach()
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+    """Returns a dictionary containing a whole state of the module.
+
+    Both parameters and persistent buffers (e.g. running averages) are
+    included. Keys are corresponding parameter and buffer names.
+
+    This method is modified from :meth:`torch.nn.Module.state_dict` to
+    recursively check parallel module in case that the model has a complicated
+    structure, e.g., nn.Module(nn.Module(DDP)).
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (OrderedDict): Returned dict for the state of the
+            module.
+        prefix (str): Prefix of the key.
+        keep_vars (bool): Whether to keep the variable property of the
+            parameters. Default: False.
+
+    Returns:
+        dict: A dictionary containing a whole state of the module.
+    """
+    # recursively check parallel module in case that the model has a
+    # complicated structure, e.g., nn.Module(nn.Module(DDP))
+    if is_module_wrapper(module):
+        module = module.module
+
+    # below is the same as torch.nn.Module.state_dict()
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(
+        version=module._version)
+    _save_to_state_dict(module, destination, prefix, keep_vars)
+    for name, child in module._modules.items():
+        if child is not None:
+            get_state_dict(
+                child, destination, prefix + name + '.', keep_vars=keep_vars)
+    for hook in module._state_dict_hooks.values():
+        hook_result = hook(module, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+    return destination
+
+
+def save_checkpoint(model, filename, optimizer=None, meta=None):
+    """Save checkpoint to file.
+
+    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+    ``optimizer``. By default ``meta`` will contain version and time info.
+
+    Args:
+        model (Module): Module whose params are to be saved.
+        filename (str): Checkpoint filename.
+        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+        meta (dict, optional): Metadata to be saved in checkpoint.
+    """
+    if meta is None:
+        meta = {}
+    elif not isinstance(meta, dict):
+        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+    if is_module_wrapper(model):
+        model = model.module
+
+    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+        # save class name to the meta
+        meta.update(CLASSES=model.CLASSES)
+
+    checkpoint = {
+        'meta': meta,
+        'state_dict': weights_to_cpu(get_state_dict(model))
+    }
+    # save optimizer state dict in the checkpoint
+    if isinstance(optimizer, Optimizer):
+        checkpoint['optimizer'] = optimizer.state_dict()
+    elif isinstance(optimizer, dict):
+        checkpoint['optimizer'] = {}
+        for name, optim in optimizer.items():
+            checkpoint['optimizer'][name] = optim.state_dict()
+
+    if filename.startswith('pavi://'):
+        try:
+            from pavi import modelcloud
+            from pavi.exception import NodeNotFoundError
+        except ImportError:
+            raise ImportError(
+                'Please install pavi to load checkpoint from modelcloud.')
+        model_path = filename[7:]
+        root = modelcloud.Folder()
+        model_dir, model_name = osp.split(model_path)
+        try:
+            model = modelcloud.get(model_dir)
+        except NodeNotFoundError:
+            model = root.create_training_model(model_dir)
+        with TemporaryDirectory() as tmp_dir:
+            checkpoint_file = osp.join(tmp_dir, model_name)
+            with open(checkpoint_file, 'wb') as f:
+                torch.save(checkpoint, f)
+                f.flush()
+            model.create_file(checkpoint_file, name=model_name)
+    else:
+        mmcv.mkdir_or_exist(osp.dirname(filename))
+        # immediately flush buffer
+        with open(filename, 'wb') as f:
+            torch.save(checkpoint, f)
+            f.flush()