# utils/torch_utils.py
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
PyTorch utils
"""

import math
import os
import platform
import subprocess
import time
import warnings
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

from utils.general import LOGGER, check_version, colorstr, file_date, git_describe

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None

# Suppress PyTorch warnings
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
warnings.filterwarnings('ignore', category=UserWarning)


def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
    # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
    def decorate(fn):
        return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)

    return decorate
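
# Usage sketch (illustrative, not called by this module): the decorator disables
# gradient tracking for the wrapped function on any supported torch version.
#
#     @smart_inference_mode()
#     def predict(model, im):
#         return model(im)  # runs under inference_mode() (torch>=1.9) or no_grad()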


def smartCrossEntropyLoss(label_smoothing=0.0):
    # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
    if check_version(torch.__version__, '1.10.0'):
        return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    if label_smoothing > 0:
        LOGGER.warning(f'WARNING ⚠️ label smoothing {label_smoothing} requires torch>=1.10.0')
    return nn.CrossEntropyLoss()
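
# Usage sketch (illustrative): degrades gracefully to an unsmoothed loss on torch<1.10.
#
#     criterion = smartCrossEntropyLoss(label_smoothing=0.1)
#     loss = criterion(logits, targets)  # logits (N, C), targets (N,)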


def smart_DDP(model):
    # Model DDP creation with checks
    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
    if check_version(torch.__version__, '1.11.0'):
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
    else:
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
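
# Usage sketch (illustrative; assumes the process group is already initialized
# by a launcher such as torchrun, which sets LOCAL_RANK/RANK):
#
#     model = model.to(LOCAL_RANK)
#     if RANK != -1:
#         model = smart_DDP(model)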


def reshape_classifier_output(model, n=1000):
    # Update a TorchVision classification model to class count 'n' if required
    from models.common import Classify
    name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1]  # last module
    if isinstance(m, Classify):  # YOLOv5 Classify() head
        if m.linear.out_features != n:
            m.linear = nn.Linear(m.linear.in_features, n)
    elif isinstance(m, nn.Linear):  # ResNet, EfficientNet
        if m.out_features != n:
            setattr(model, name, nn.Linear(m.in_features, n))
    elif isinstance(m, nn.Sequential):
        types = [type(x) for x in m]
        if nn.Linear in types:
            i = types.index(nn.Linear)  # nn.Linear index
            if m[i].out_features != n:
                m[i] = nn.Linear(m[i].in_features, n)
        elif nn.Conv2d in types:
            i = types.index(nn.Conv2d)  # nn.Conv2d index
            if m[i].out_channels != n:
                m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
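
# Usage sketch (illustrative; assumes torchvision is installed):
#
#     import torchvision
#     model = torchvision.models.resnet18(weights=None)  # 1000-class head
#     reshape_classifier_output(model, n=10)  # final nn.Linear now outputs 10 classes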


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    # Context manager to make all processes in distributed training wait for each local_master to do something
    if local_rank not in [-1, 0]:
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])
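
# Usage sketch (illustrative): rank 0 executes the body first (e.g. a one-time
# dataset download); other ranks block at the first barrier until it finishes.
#
#     with torch_distributed_zero_first(LOCAL_RANK):
#         dataset = prepare_dataset()  # hypothetical helper, runs once per node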


def device_count():
    # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows
    assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
    try:
        cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""'  # Windows
        return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
    except Exception:
        return 0


def select_device(device='', batch_size=0, newline=True):
    # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
    s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
    device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MiB
        arg = 'cuda:0'
    elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available():  # prefer MPS if available
        s += 'MPS\n'
        arg = 'mps'
    else:  # revert to CPU
        s += 'CPU\n'
        arg = 'cpu'

    if not newline:
        s = s.rstrip()
    LOGGER.info(s)
    return torch.device(arg)
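
# Usage sketch (illustrative):
#
#     device = select_device('cpu')  # force CPU
#     device = select_device('0')    # first CUDA GPU (asserts availability)
#     device = select_device('0,1', batch_size=16)  # multi-GPU check; batch must divide evenly
#
# For CUDA requests this always returns torch.device('cuda:0'); multi-GPU
# placement is handled downstream by DP/DDP.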


def time_sync():
    # PyTorch-accurate time: wait for pending (asynchronous) CUDA kernels before reading the clock
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
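
# Usage sketch (illustrative): bracket GPU work with time_sync() so queued
# kernels are included in the measurement.
#
#     t0 = time_sync()
#     y = model(x)
#     dt = time_sync() - t0  # seconds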


def profile(input, ops, n=10, device=None):
    """ YOLOv5 speed/memory/FLOPs profiler
    Usage:
        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # profile over 100 iterations
    """
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
          f"{'input':>24s}{'output':>24s}")

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, 'to') else m  # device
            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(m, inputs=(x, ), verbose=False)[0] / 1E9 * 2  # GFLOPs
            except Exception:
                flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()
                    y = m(x)
                    t[1] = time_sync()
                    try:
                        _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()
                    except Exception:  # no backward method
                        t[2] = float('nan')
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                print(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results


def is_parallel(model):
    # Returns True if model is of type DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def de_parallel(model):
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    return model.module if is_parallel(model) else model


def initialize_weights(model):
    # Initialize module hyperparameters: BatchNorm2d eps/momentum and in-place activations
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True


def find_modules(model, mclass=nn.Conv2d):
    # Finds layer indices matching module class 'mclass'
    return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]


def sparsity(model):
    # Return global model sparsity
    a, b = 0, 0
    for p in model.parameters():
        a += p.numel()
        b += (p == 0).sum()
    return b / a


def prune(model, amount=0.3):
    # Prune model to requested global sparsity
    import torch.nn.utils.prune as prune
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
            prune.remove(m, 'weight')  # make permanent
    LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')
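
# Usage sketch (illustrative): zero 30% of each Conv2d weight tensor by L1
# magnitude, then inspect the resulting fraction of zeroed parameters.
#
#     prune(model, amount=0.3)
#     print(f'{sparsity(model):.1%} of parameters are zero')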


def fuse_conv_and_bn(conv, bn):
    # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
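
# Verification sketch (illustrative): in eval mode the fused layer should match
# conv followed by bn to within floating-point tolerance.
#
#     conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
#     bn = nn.BatchNorm2d(16).eval()
#     fused = fuse_conv_and_bn(conv, bn)
#     x = torch.randn(1, 8, 32, 32)
#     assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)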


def model_info(model, verbose=False, imgsz=640):
    # Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPs
        p = next(model.parameters())
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        flops = thop.profile(deepcopy(model), inputs=(im, ), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs'  # 640x640 GFLOPs
    except Exception:
        fs = ''

    name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
    LOGGER.info(f'{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')


def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    s = (int(h * ratio), int(w * ratio))  # new size
    img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
    if not same_shape:  # pad/crop img
        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
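
# Usage sketch (illustrative): downscale a batch to 50% and pad up to the next
# gs-multiple (pad value 0.447 ≈ ImageNet mean).
#
#     im = torch.rand(16, 3, 256, 416)
#     out = scale_img(im, ratio=0.5)  # -> (16, 3, 128, 224) with gs=32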


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
    # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
    for v in model.modules():
        for p_name, p in v.named_parameters(recurse=0):
            if p_name == 'bias':  # bias (no decay)
                g[2].append(p)
            elif p_name == 'weight' and isinstance(v, bn):  # weight (no decay)
                g[1].append(p)
            else:
                g[0].append(p)  # weight (with decay)

    if name == 'Adam':
        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
    elif name == 'AdamW':
        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == 'RMSProp':
        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == 'SGD':
        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(f'Optimizer {name} not implemented.')

    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
    return optimizer
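
# Usage sketch (illustrative): SGD with weight decay applied only to non-norm
# weights, matching the YOLOv5 hyperparameter defaults.
#
#     optimizer = smart_optimizer(model, name='SGD', lr=0.01, momentum=0.937, decay=5e-4)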


def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
    # YOLOv5 torch.hub.load() wrapper with smart error/issue handling
    if check_version(torch.__version__, '1.9.1'):
        kwargs['skip_validation'] = True  # validation causes GitHub API rate limit errors
    if check_version(torch.__version__, '1.12.0'):
        kwargs['trust_repo'] = True  # argument required starting in torch 1.12
    try:
        return torch.hub.load(repo, model, **kwargs)
    except Exception:
        return torch.hub.load(repo, model, force_reload=True, **kwargs)
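
# Usage sketch (illustrative):
#
#     model = smart_hub_load('ultralytics/yolov5', 'yolov5s', pretrained=True)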


def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
    # Resume training from a partially trained checkpoint
    best_fitness = 0.0
    start_epoch = ckpt['epoch'] + 1
    if ckpt['optimizer'] is not None:
        optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
        best_fitness = ckpt['best_fitness']
    if ema and ckpt.get('ema'):
        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
        ema.updates = ckpt['updates']
    if resume:
        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
                                f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
        LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
    if epochs < start_epoch:
        LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
        epochs += ckpt['epoch']  # finetune additional epochs
    return best_fitness, start_epoch, epochs
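
# Usage sketch (illustrative; assumes a checkpoint dict with the keys YOLOv5
# train.py saves: 'epoch', 'optimizer', 'best_fitness', 'ema', 'updates'):
#
#     ckpt = torch.load('last.pt', map_location='cpu')
#     best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights='last.pt', epochs=300)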


class EarlyStopping:
    # YOLOv5 simple early stopper
    def __init__(self, patience=30):
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving to stop
        self.possible_stop = False  # possible stop may occur next epoch

    def __call__(self, epoch, fitness):
        if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
        stop = delta >= self.patience  # stop training if patience exceeded
        if stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
        return stop
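
# Usage sketch (illustrative): call once per epoch with the current fitness
# (e.g. mAP) and break the training loop when it returns True.
#
#     stopper = EarlyStopping(patience=30)
#     for epoch in range(epochs):
#         fitness = validate(model)  # hypothetical fitness metric
#         if stopper(epoch, fitness):
#             break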


class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        self.updates += 1
        d = self.decay(self.updates)

        msd = de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()
        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)
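
# Usage sketch (illustrative): maintain an EMA copy of the model and update it
# after every optimizer step; evaluate/checkpoint ema.ema rather than the raw model.
#
#     ema = ModelEMA(model)
#     for im, targets in loader:  # hypothetical training loop
#         loss = compute_loss(model(im), targets)  # hypothetical loss
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         ema.update(model)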