"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import logging
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist

from model.lavis.common import dist_utils

class SmoothedValue(object):
    """Maintain a sliding window of scalar values plus running totals.

    Exposes smoothed statistics (median/avg over the window) as well as
    the average over every value ever recorded (``global_avg``).
    """

    def __init__(self, window_size=20, fmt=None):
        # Default display: windowed median followed by the global average.
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record *value* (weighted by *n*) into the window and the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        # Pack count/total into one tensor so a single all_reduce suffices.
        packed = torch.tensor(
            [self.count, self.total], dtype=torch.float64, device="cuda"
        )
        dist.barrier()
        dist.all_reduce(packed)
        count, total = packed.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        # Median over the current window only.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        # Mean over the current window only.
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        # Average over the full history, not just the window.
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        # Most recently recorded value.
        return self.deque[-1]

    def __str__(self):
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)

class MetricLogger(object):
    """Collect named SmoothedValue meters and pretty-print training progress."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names are created lazily on first update.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed scalar values (float, int, or 0-dim tensor) into meters by name."""
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            assert isinstance(value, (float, int))
            self.meters[name].update(value)

    def __getattr__(self, attr):
        # Invoked only when regular attribute lookup fails; this lets
        # callers access meters as attributes, e.g. ``logger.loss``.
        meters = self.meters
        if attr in meters:
            return meters[attr]
        own = self.__dict__
        if attr in own:
            return own[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        # One "name: <meter>" entry per meter, joined by the delimiter.
        return self.delimiter.join(
            "{}: {}".format(name, meter) for name, meter in self.meters.items()
        )

    def global_avg(self):
        """Return all meters' global averages as one delimited string."""
        parts = [
            "{}: {:.4f}".format(name, meter.global_avg)
            for name, meter in self.meters.items()
        ]
        return self.delimiter.join(parts)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from *iterable*, printing progress every *print_freq* steps.

        Requires a sized iterable (``len`` is used for ETA and formatting).
        """
        if not header:
            header = ""
        total = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # Pad the step counter to the width of the final step index.
        space_fmt = ":" + str(len(str(total))) + "d"
        pieces = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            pieces.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(pieces)
        MB = 1024.0 * 1024.0
        for step, obj in enumerate(iterable):
            # Time spent waiting on the data source vs. the full iteration.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if step % print_freq == 0 or step == total - 1:
                eta_seconds = iter_time.global_avg * (total - step)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fmt_kwargs = dict(
                    eta=eta_string,
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if use_cuda:
                    fmt_kwargs["memory"] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(step, total, **fmt_kwargs))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / total
            )
        )

class AttrDict(dict):
    """Dict whose keys are also readable/writable as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Aliasing __dict__ to the mapping itself makes d.key ≡ d["key"].
        self.__dict__ = self

def setup_logger():
    """Configure root logging: INFO on the main process, WARN on workers.

    Keeps non-main ranks quiet in distributed runs while still surfacing
    warnings from every process.
    """
    rank_level = logging.INFO if dist_utils.is_main_process() else logging.WARN
    logging.basicConfig(
        handlers=[logging.StreamHandler()],
        format="%(asctime)s [%(levelname)s] %(message)s",
        level=rank_level,
    )