Switch to side-by-side view

--- a
+++ b/model/lavis/common/dist_utils.py
@@ -0,0 +1,138 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+import timm.models.hub as timm_hub
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    import builtins as __builtin__
+
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def init_distributed_mode(args):
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print("Not using distributed mode")
+        args.distributed = False
+        return
+
+    args.dist_url = args.run_cfg.dist_url
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = "nccl"
+    print(
+        "| distributed init (rank {}, world {}): {}".format(
+            args.rank, args.world_size, args.dist_url
+        ),
+        flush=True,
+    )
+    torch.distributed.init_process_group(
+        backend=args.dist_backend,
+        init_method=args.dist_url,
+        world_size=args.world_size,
+        rank=args.rank,
+        timeout=datetime.timedelta(
+            days=365
+        ),  # allow auto-downloading and de-compressing
+    )
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+def get_dist_info():
+    if torch.__version__ < "1.0":
+        initialized = dist._initialized
+    else:
+        initialized = dist.is_initialized()
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:  # non-distributed training
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def main_process(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank == 0:
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def download_cached_file(url, check_hash=True, progress=False):
+    """
+    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
+    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
+    """
+
+    def get_cached_file_path():
+        # a hack to sync the file path across processes
+        parts = torch.hub.urlparse(url)
+        filename = os.path.basename(parts.path)
+        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+        return cached_file
+
+    if is_main_process():
+        timm_hub.download_cached_file(url, check_hash, progress)
+
+    if is_dist_avail_and_initialized():
+        dist.barrier()
+
+    return get_cached_file_path()