yolov5/utils/torch_utils.py

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
PyTorch utils
"""

import datetime
import math
import os
import platform
import subprocess
import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from utils.general import LOGGER

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for the local master to do something.
    """
    if local_rank not in [-1, 0]:
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])
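

# Usage sketch: let rank 0 prepare a dataset cache while the other DDP ranks wait
# (LOCAL_RANK and create_dataloader() are placeholders, not defined in this file):
#     with torch_distributed_zero_first(LOCAL_RANK):
#         dataset = create_dataloader(...)  # rank 0 builds the cache first; ranks 1..n then reuse it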


def date_modified(path=__file__):
    # return human-readable file modification date, e.g. '2021-3-26'
    t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f'{t.year}-{t.month}-{t.day}'


def git_describe(path=Path(__file__).parent):  # path must be a directory
    # return human-readable git description, e.g. v5.0-5-g3e25f1e, see https://git-scm.com/docs/git-describe
    s = f'git -C {path} describe --tags --long --always'
    try:
        return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
    except subprocess.CalledProcessError:
        return ''  # not a git repository


def select_device(device='', batch_size=None, newline=True):
    # device = 'cpu' or '0' or '0,1,2,3'
    s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} '  # string
    device = str(device).strip().lower().replace('cuda:', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        devices = device.split(',') if device else ['0']  # i.e. ['0', '1', '6', '7']
        n = len(devices)  # device count
        if n > 1 and batch_size:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2:.0f}MiB)\n"  # bytes to MiB
    else:
        s += 'CPU\n'

    if not newline:
        s = s.rstrip()
    LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
    return torch.device('cuda:0' if cuda else 'cpu')
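

# Usage sketch (device strings as parsed above; 'model' is a placeholder nn.Module):
#     device = select_device('0')                    # first CUDA GPU
#     device = select_device('0,1', batch_size=16)   # two GPUs; asserts batch divisibility
#     device = select_device('cpu')                  # force CPU
#     model = model.to(device)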


def time_sync():
    # pytorch-accurate time (waits for pending CUDA kernels before reading the clock)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
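

# Usage sketch: time a forward pass accurately on GPU ('model' and 'x' are placeholders):
#     t0 = time_sync()
#     y = model(x)
#     print(f'forward: {(time_sync() - t0) * 1000:.1f} ms')  # plain time.time() would under-report on CUDA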


def profile(input, ops, n=10, device=None):
    # YOLOv5 speed/memory/FLOPs profiler
    #
    # Usage:
    #     input = torch.randn(16, 3, 640, 640)
    #     m1 = lambda x: x * torch.sigmoid(x)
    #     m2 = nn.SiLU()
    #     profile(input, [m1, m2], n=100)  # profile over 100 iterations

    results = []
    device = device or select_device()
    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
          f"{'input':>24s}{'output':>24s}")

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, 'to') else m  # device
            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # GFLOPs
            except Exception:
                flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()
                    y = m(x)
                    t[1] = time_sync()
                    try:
                        (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()
                    except Exception:  # no backward method
                        t[2] = float('nan')
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
                s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                print(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results


def is_parallel(model):
    # Returns True if model is of type DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def de_parallel(model):
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    return model.module if is_parallel(model) else model
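

# Usage sketch: strip the DDP/DP wrapper before saving a checkpoint so it loads without DDP
# ('model' and 'best.pt' are placeholders):
#     ckpt = {'model': deepcopy(de_parallel(model)).half()}
#     torch.save(ckpt, 'best.pt')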


def initialize_weights(model):
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True


def find_modules(model, mclass=nn.Conv2d):
    # Finds layer indices matching module class 'mclass'
    return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]


def sparsity(model):
    # Return global model sparsity (fraction of zero-valued parameters)
    a, b = 0, 0
    for p in model.parameters():
        a += p.numel()
        b += (p == 0).sum()
    return b / a


def prune(model, amount=0.3):
    # Prune model to requested global sparsity
    import torch.nn.utils.prune as prune
    print('Pruning model... ', end='')
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
            prune.remove(m, 'weight')  # make permanent
    print(' %.3g global sparsity' % sparsity(model))
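

# Usage sketch: sparsify conv weights, then re-check accuracy ('model' and validate() are placeholders):
#     prune(model, amount=0.3)                   # zero out 30% of each conv's weights (L1, unstructured)
#     print(f'sparsity: {sparsity(model):.3g}')
#     validate(model)                            # unstructured zeros alone do not speed up dense GPU kernels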


def fuse_conv_and_bn(conv, bn):
    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,  # carry over dilation so dilated convs fuse correctly
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
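

# Usage sketch: fuse for inference and verify the outputs match (all names are placeholders):
#     conv = nn.Conv2d(3, 16, 3, bias=False).eval()
#     bn = nn.BatchNorm2d(16).eval()
#     fused = fuse_conv_and_bn(conv, bn)
#     x = torch.randn(1, 3, 32, 32)
#     assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)  # same output, one layer fewer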


def model_info(model, verbose=False, img_size=640):
    # Model information. img_size may be int or list, e.g. img_size=640 or img_size=[640, 320]
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPs
        from thop import profile
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
        img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
        fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPs
    except Exception:
        fs = ''  # thop missing or model incompatible

    LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
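

# Usage sketch ('model' is a placeholder nn.Module with optional .stride/.yaml attributes):
#     model_info(model, img_size=640)  # logs 'Model Summary: <layers> layers, <params> parameters, ...'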


def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    # scales img(bs,3,y,x) by ratio constrained to gs-multiple
    if ratio == 1.0:
        return img
    else:
        h, w = img.shape[2:]
        s = (int(h * ratio), int(w * ratio))  # new size
        img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
        if not same_shape:  # pad/crop img
            h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
        return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
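

# Usage sketch: downscale a batch for multi-scale inference (shapes shown are illustrative):
#     x = torch.zeros(16, 3, 256, 416)
#     y = scale_img(x, ratio=0.5)  # -> (16, 3, 128, 224): resized to 128x208, padded up to gs=32 multiples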


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


class EarlyStopping:
    # YOLOv5 simple early stopper
    def __init__(self, patience=30):
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving before stopping
        self.possible_stop = False  # possible stop may occur next epoch

    def __call__(self, epoch, fitness):
        if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
        stop = delta >= self.patience  # stop training if patience exceeded
        if stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
        return stop
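

# Usage sketch in a training loop (train_epoch() and validate() are placeholders):
#     stopper = EarlyStopping(patience=30)
#     for epoch in range(epochs):
#         train_epoch(model)
#         fitness = validate(model)  # e.g. mAP; higher is better
#         if stopper(epoch, fitness):
#             break  # no improvement for `patience` epochs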


class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        # if next(model.parameters()).device.type != 'cpu':
        #     self.ema.half()  # FP16 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)
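

# Usage sketch in a training loop ('model', 'loader', compute_loss() and validate() are placeholders):
#     ema = ModelEMA(model)
#     for x, y in loader:
#         loss = compute_loss(model(x), y)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         ema.update(model)   # EMA tracks weights after every optimizer step
#     ema.update_attr(model)  # copy non-tensor attributes before eval/save
#     validate(ema.ema)       # evaluate the smoothed weights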