Diff of /Cluster-ViT/util/misc.py [000000] .. [15fc01]

"""
2
Misc functions, including distributed helpers.
3
4
Mostly copy-paste from torchvision references.
5
"""
6
import os
7
import subprocess
8
import time
9
from collections import defaultdict, deque
10
import datetime
11
import pickle
12
from typing import Optional, List
13
14
import torch
15
import torch.distributed as dist
16
from torch import Tensor
17
18
# needed due to empty tensor bug in pytorch and torchvision 0.5
19
import torchvision
20
if float(torchvision.__version__[:3]) < 0.7:
21
    from torchvision.ops import _new_empty_tensor
22
    from torchvision.ops.misc import _output_size
23
24
25
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


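# Illustrative usage of SmoothedValue (a minimal sketch; the printed values
# follow from the default fmt "{median:.4f} ({global_avg:.4f})"):
#
#   meter = SmoothedValue(window_size=20)
#   for v in (0.9, 0.7, 0.5):
#       meter.update(v)
#   print(meter)  # -> "0.7000 (0.7000)": window median, then global average
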
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


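# Illustrative use of all_gather under torch.distributed (a sketch; assumes the
# process group is already initialized and each rank contributes a dict):
#
#   stats = {"rank": get_rank(), "n_correct": 42}   # hypothetical per-rank payload
#   gathered = all_gather(stats)                    # list with world_size entries
#   total = sum(s["n_correct"] for s in gathered)
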
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


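# Illustrative use of reduce_dict (a sketch; the values must be scalar tensors
# on the current CUDA device for torch.stack and dist.all_reduce to work):
#
#   loss_dict = {"loss_ce": loss_ce, "loss_bbox": loss_bbox}  # hypothetical losses
#   loss_dict_reduced = reduce_dict(loss_dict)  # same keys, averaged over ranks
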
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0

        # iterate directly over the iterable; wrapping it in enumerate() here
        # would yield (index, item) tuples to the caller instead of the items
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


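# Illustrative training-loop usage of MetricLogger (a sketch; `data_loader`,
# `model` and `criterion` are hypothetical):
#
#   metric_logger = MetricLogger(delimiter="  ")
#   for samples, targets in metric_logger.log_every(data_loader, 10, header='Epoch: [0]'):
#       loss = criterion(model(samples), targets)
#       metric_logger.update(loss=loss.item())
#   metric_logger.synchronize_between_processes()
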
def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


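# Illustrative use of get_sha (a sketch): log the exact code state at startup
# so runs can be traced back to a commit.
#
#   print("git:\n  {}\n".format(get_sha()))
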
def collate_fn(batch):
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


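# Illustrative DataLoader wiring for collate_fn (a sketch; `dataset` is
# hypothetical and must yield (CHW image tensor, target) pairs):
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
#   samples, targets = next(iter(loader))   # samples is a NestedTensor
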
class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None  # redundant at runtime; helps TorchScript narrow the Optional
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)


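# Illustrative padding behaviour (a sketch): two CHW images of different sizes
# are batched by zero-padding to the per-axis maximum; mask is True on padding:
#
#   imgs = [torch.rand(3, 4, 6), torch.rand(3, 5, 2)]
#   nt = nested_tensor_from_tensor_list(imgs)
#   nt.tensors.shape   # torch.Size([2, 3, 5, 6])
#   nt.mask.shape      # torch.Size([2, 5, 6]); True marks padded pixels
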
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


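# Illustrative behaviour (a sketch): after setup_for_distributed(False), plain
# print() calls become no-ops on non-master ranks but can be forced through:
#
#   setup_for_distributed(is_master=(get_rank() == 0))
#   print("hello")              # printed on rank 0 only
#   print("debug", force=True)  # printed on every rank
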
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


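# Illustrative checkpointing with save_on_master (a sketch; `model` and
# `output_path` are hypothetical): only rank 0 touches the filesystem.
#
#   save_on_master({"model": model.state_dict()}, output_path)
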
def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # NOTE: this branch assumes args.world_size was already provided
        # elsewhere (e.g. on the command line)
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


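# Illustrative call (a sketch; `args` is a hypothetical argparse.Namespace).
# When launched under torchrun, RANK / WORLD_SIZE / LOCAL_RANK are set in the
# environment, so the first branch runs and each process pins its own GPU:
#
#   import argparse
#   args = argparse.Namespace(dist_url='env://', world_size=1, distributed=False)
#   init_distributed_mode(args)
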
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the top-k accuracy (precision@k) for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape rather than view, in case the slice is non-contiguous
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


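# Worked example for accuracy (a sketch): 2 samples, 3 classes.
#
#   output = torch.tensor([[0.1, 0.8, 0.1],    # predicts class 1
#                          [0.7, 0.2, 0.1]])   # predicts class 0
#   target = torch.tensor([1, 2])
#   accuracy(output, target, topk=(1, 2))      # -> [tensor(50.), tensor(50.)]
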
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    function can go away.
    """
    if _TORCHVISION_MAJOR_MINOR < (0, 7):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)


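# Illustrative empty-batch case (a sketch): on torchvision < 0.7 the wrapper
# computes the output shape by hand instead of raising on an empty batch:
#
#   x = torch.zeros(0, 3, 8, 8)
#   y = interpolate(x, scale_factor=2.0)
#   y.shape   # torch.Size([0, 3, 16, 16])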