minigpt4/common/dist_utils.py
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub

def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # `force=True` lets any rank print regardless of master status.
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

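# Usage note (illustrative, not part of the original file): once
# `setup_for_distributed(False)` has run on a worker rank, ordinary prints are
# suppressed, but a message can still be forced through the patched print:
#
#     print("loss diverged on this rank", force=True)
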
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0

def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}, world {}): {}".format(
            args.rank, args.world_size, args.dist_url
        ),
        flush=True,
    )
    # use DeepSpeed (ZeRO optimizer) for distributed initialization
    if args.use_zero_optimizer:
        print("Using ZeRO optimizer distributed mode.")
        import deepspeed

        deepspeed.init_distributed(
            dist_backend=args.dist_backend,
            init_method=args.dist_url,
            rank=args.rank,
            timeout=datetime.timedelta(days=365),  # allow auto-downloading and de-compressing
            # config=args.deepspeed_config,
        )
    # use PyTorch distributed initialization
    else:
        print("Using PyTorch distributed mode.")
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
            timeout=datetime.timedelta(
                days=365
            ),  # allow auto-downloading and de-compressing
        )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

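# Hypothetical usage sketch (not part of the original module): `init_distributed_mode`
# expects an argparse-style namespace that already carries `dist_url` and
# `use_zero_optimizer`; under `torchrun`, RANK / WORLD_SIZE / LOCAL_RANK are read
# from the environment. The `env://` init method below is an assumption about how
# the launcher is configured:
#
#     import argparse
#     args = argparse.Namespace(dist_url="env://", use_zero_optimizer=False)
#     init_distributed_mode(args)  # sets args.rank, args.world_size, args.gpu, args.distributed
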
def get_dist_info():
    if torch.__version__ < "1.0":
        # torch.distributed exposed `_initialized` before version 1.0
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size

def main_process(func):
    # Decorator: run `func` only on the main process (rank 0); other ranks return None.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper

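# Hypothetical usage sketch (not in the original file): `main_process` can wrap any
# function that should execute only on rank 0, e.g. checkpoint saving or logging:
#
#     @main_process
#     def save_checkpoint(state, path):  # hypothetical helper, shown for illustration
#         torch.save(state, path)
#
# On non-zero ranks the wrapper returns None without calling the function.
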
def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()
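

# Hypothetical usage sketch (the URL below is a placeholder, not a real checkpoint):
# every rank may call `download_cached_file`; rank 0 downloads into timm's hub cache,
# the remaining ranks wait at the barrier, and all ranks then resolve the same path.
#
#     weights_url = "https://example.com/checkpoints/model.pth"  # placeholder URL
#     cached_path = download_cached_file(weights_url, check_hash=True, progress=True)
#     state_dict = torch.load(cached_path, map_location="cpu")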