Diff of /yolov5/utils/general.py [000000] .. [f26a44]

Switch to unified view

a b/yolov5/utils/general.py
1
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
"""
3
General utils
4
"""
5
6
import contextlib
7
import glob
8
import logging
9
import math
10
import os
11
import platform
12
import random
13
import re
14
import shutil
15
import signal
16
import time
17
import urllib
18
from itertools import repeat
19
from multiprocessing.pool import ThreadPool
20
from pathlib import Path
21
from subprocess import check_output
22
from zipfile import ZipFile
23
24
import cv2
25
import numpy as np
26
import pandas as pd
27
import pkg_resources as pkg
28
import torch
29
import torchvision
30
import yaml
31
32
from utils.downloads import gsutil_getsize
33
from utils.metrics import box_iou, fitness
34
35
# Settings
36
torch.set_printoptions(linewidth=320, precision=5, profile='long')
37
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
38
pd.options.display.max_columns = 10
39
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
40
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads
41
42
FILE = Path(__file__).resolve()
43
ROOT = FILE.parents[1]  # YOLOv5 root directory
44
45
46
def set_logging(name=None, verbose=True):
47
    # Sets level and returns logger
48
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
49
    logging.basicConfig(format="%(message)s", level=logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING)
50
    return logging.getLogger(name)
51
52
53
LOGGER = set_logging(__name__)  # define globally (used in train.py, val.py, detect.py, etc.)
54
55
56
class Profile(contextlib.ContextDecorator):
57
    # Usage: @Profile() decorator or 'with Profile():' context manager
58
    def __enter__(self):
59
        self.start = time.time()
60
61
    def __exit__(self, type, value, traceback):
62
        print(f'Profile results: {time.time() - self.start:.5f}s')
63
64
65
class Timeout(contextlib.ContextDecorator):
66
    # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
67
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
68
        self.seconds = int(seconds)
69
        self.timeout_message = timeout_msg
70
        self.suppress = bool(suppress_timeout_errors)
71
72
    def _timeout_handler(self, signum, frame):
73
        raise TimeoutError(self.timeout_message)
74
75
    def __enter__(self):
76
        signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
77
        signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised
78
79
    def __exit__(self, exc_type, exc_val, exc_tb):
80
        signal.alarm(0)  # Cancel SIGALRM if it's scheduled
81
        if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
82
            return True
83
84
85
class WorkingDirectory(contextlib.ContextDecorator):
86
    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
87
    def __init__(self, new_dir):
88
        self.dir = new_dir  # new dir
89
        self.cwd = Path.cwd().resolve()  # current dir
90
91
    def __enter__(self):
92
        os.chdir(self.dir)
93
94
    def __exit__(self, exc_type, exc_val, exc_tb):
95
        os.chdir(self.cwd)
96
97
98
def try_except(func):
99
    # try-except function. Usage: @try_except decorator
100
    def handler(*args, **kwargs):
101
        try:
102
            func(*args, **kwargs)
103
        except Exception as e:
104
            print(e)
105
106
    return handler
107
108
109
def methods(instance):
110
    # Get class/instance methods
111
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
112
113
114
def print_args(name, opt):
115
    # Print argparser arguments
116
    LOGGER.info(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
117
118
119
def init_seeds(seed=0):
120
    # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
121
    # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
122
    import torch.backends.cudnn as cudnn
123
    random.seed(seed)
124
    np.random.seed(seed)
125
    torch.manual_seed(seed)
126
    cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
127
128
129
def intersect_dicts(da, db, exclude=()):
130
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
131
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
132
133
134
def get_latest_run(search_dir='.'):
135
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
136
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
137
    return max(last_list, key=os.path.getctime) if last_list else ''
138
139
140
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
141
    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
142
    env = os.getenv(env_var)
143
    if env:
144
        path = Path(env)  # use environment variable
145
    else:
146
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
147
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
148
        path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
149
    path.mkdir(exist_ok=True)  # make if required
150
    return path
151
152
153
def is_writeable(dir, test=False):
154
    # Return True if directory has write permissions, test opening a file with write permissions if test=True
155
    if test:  # method 1
156
        file = Path(dir) / 'tmp.txt'
157
        try:
158
            with open(file, 'w'):  # open file with write permissions
159
                pass
160
            file.unlink()  # remove file
161
            return True
162
        except OSError:
163
            return False
164
    else:  # method 2
165
        return os.access(dir, os.R_OK)  # possible issues on Windows
166
167
168
def is_docker():
169
    # Is environment a Docker container?
170
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()
171
172
173
def is_colab():
174
    # Is environment a Google Colab instance?
175
    try:
176
        import google.colab
177
        return True
178
    except ImportError:
179
        return False
180
181
182
def is_pip():
183
    # Is file in a pip package?
184
    return 'site-packages' in Path(__file__).resolve().parts
185
186
187
def is_ascii(s=''):
188
    # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
189
    s = str(s)  # convert list, tuple, None, etc. to str
190
    return len(s.encode().decode('ascii', 'ignore')) == len(s)
191
192
193
def is_chinese(s='人工智能'):
194
    # Is string composed of any Chinese characters?
195
    return re.search('[\u4e00-\u9fff]', s)
196
197
198
def emojis(str=''):
199
    # Return platform-dependent emoji-safe version of string
200
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
201
202
203
def file_size(path):
204
    # Return file/dir size (MB)
205
    path = Path(path)
206
    if path.is_file():
207
        return path.stat().st_size / 1E6
208
    elif path.is_dir():
209
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
210
    else:
211
        return 0.0
212
213
214
def check_online():
215
    # Check internet connectivity
216
    import socket
217
    try:
218
        socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
219
        return True
220
    except OSError:
221
        return False
222
223
224
@try_except
225
@WorkingDirectory(ROOT)
226
def check_git_status():
227
    # Recommend 'git pull' if code is out of date
228
    msg = ', for updates see https://github.com/ultralytics/yolov5'
229
    print(colorstr('github: '), end='')
230
    assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
231
    assert not is_docker(), 'skipping check (Docker image)' + msg
232
    assert check_online(), 'skipping check (offline)' + msg
233
234
    cmd = 'git fetch && git config --get remote.origin.url'
235
    url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git')  # git fetch
236
    branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
237
    n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
238
    if n > 0:
239
        s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
240
    else:
241
        s = f'up to date with {url} ✅'
242
    print(emojis(s))  # emoji-safe
243
244
245
def check_python(minimum='3.6.2'):
246
    # Check current python version vs. required python version
247
    check_version(platform.python_version(), minimum, name='Python ', hard=True)
248
249
250
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False):
251
    # Check version vs. required version
252
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
253
    result = (current == minimum) if pinned else (current >= minimum)  # bool
254
    if hard:  # assert min requirements met
255
        assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
256
    else:
257
        return result
258
259
260
@try_except
261
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
262
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
263
    prefix = colorstr('red', 'bold', 'requirements:')
264
    check_python()  # check python version
265
    if isinstance(requirements, (str, Path)):  # requirements.txt file
266
        file = Path(requirements)
267
        assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
268
        with file.open() as f:
269
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
270
    else:  # list or tuple of packages
271
        requirements = [x for x in requirements if x not in exclude]
272
273
    n = 0  # number of packages updates
274
    for r in requirements:
275
        try:
276
            pkg.require(r)
277
        except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
278
            s = f"{prefix} {r} not found and is required by YOLOv5"
279
            if install:
280
                print(f"{s}, attempting auto-update...")
281
                try:
282
                    assert check_online(), f"'pip install {r}' skipped (offline)"
283
                    print(check_output(f"pip install '{r}'", shell=True).decode())
284
                    n += 1
285
                except Exception as e:
286
                    print(f'{prefix} {e}')
287
            else:
288
                print(f'{s}. Please install and rerun your command.')
289
290
    if n:  # if packages updated
291
        source = file.resolve() if 'file' in locals() else requirements
292
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
293
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
294
        print(emojis(s))
295
296
297
def check_img_size(imgsz, s=32, floor=0):
298
    # Verify image size is a multiple of stride s in each dimension
299
    if isinstance(imgsz, int):  # integer i.e. img_size=640
300
        new_size = max(make_divisible(imgsz, int(s)), floor)
301
    else:  # list i.e. img_size=[640, 480]
302
        new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
303
    if new_size != imgsz:
304
        print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
305
    return new_size
306
307
308
def check_imshow():
309
    # Check if environment supports image displays
310
    try:
311
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
312
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
313
        cv2.imshow('test', np.zeros((1, 1, 3)))
314
        cv2.waitKey(1)
315
        cv2.destroyAllWindows()
316
        cv2.waitKey(1)
317
        return True
318
    except Exception as e:
319
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
320
        return False
321
322
323
def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
324
    # Check file(s) for acceptable suffix
325
    if file and suffix:
326
        if isinstance(suffix, str):
327
            suffix = [suffix]
328
        for f in file if isinstance(file, (list, tuple)) else [file]:
329
            s = Path(f).suffix.lower()  # file suffix
330
            if len(s):
331
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
332
333
334
def check_yaml(file, suffix=('.yaml', '.yml')):
335
    # Search/download YAML file (if necessary) and return path, checking suffix
336
    return check_file(file, suffix)
337
338
339
def check_file(file, suffix=''):
340
    # Search/download file (if necessary) and return path
341
    check_suffix(file, suffix)  # optional
342
    file = str(file)  # convert to str()
343
    if Path(file).is_file() or file == '':  # exists
344
        return file
345
    elif file.startswith(('http:/', 'https:/')):  # download
346
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
347
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
348
        if Path(file).is_file():
349
            print(f'Found {url} locally at {file}')  # file already exists
350
        else:
351
            print(f'Downloading {url} to {file}...')
352
            torch.hub.download_url_to_file(url, file)
353
            assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
354
        return file
355
    else:  # search
356
        files = []
357
        for d in 'data', 'models', 'utils':  # search directories
358
            files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
359
        assert len(files), f'File not found: {file}'  # assert file was found
360
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
361
        return files[0]  # return file
362
363
364
def check_dataset(data, autodownload=True):
365
    # Download and/or unzip dataset if not found locally
366
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
367
368
    # Download (optional)
369
    extract_dir = ''
370
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
371
        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
372
        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
373
        extract_dir, autodownload = data.parent, False
374
375
    # Read yaml (optional)
376
    if isinstance(data, (str, Path)):
377
        with open(data, errors='ignore') as f:
378
            data = yaml.safe_load(f)  # dictionary
379
380
    # Parse yaml
381
    path = extract_dir or Path(data.get('path') or '')  # optional 'path' default to '.'
382
    for k in 'train', 'val', 'test':
383
        if data.get(k):  # prepend path
384
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
385
386
    assert 'nc' in data, "Dataset 'nc' key missing."
387
    if 'names' not in data:
388
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
389
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
390
    if val:
391
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
392
        if not all(x.exists() for x in val):
393
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
394
            if s and autodownload:  # download script
395
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
396
                if s.startswith('http') and s.endswith('.zip'):  # URL
397
                    f = Path(s).name  # filename
398
                    print(f'Downloading {s} to {f}...')
399
                    torch.hub.download_url_to_file(s, f)
400
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
401
                    ZipFile(f).extractall(path=root)  # unzip
402
                    Path(f).unlink()  # remove zip
403
                    r = None  # success
404
                elif s.startswith('bash '):  # bash script
405
                    print(f'Running {s} ...')
406
                    r = os.system(s)
407
                else:  # python script
408
                    r = exec(s, {'yaml': data})  # return None
409
                print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
410
            else:
411
                raise Exception('Dataset not found.')
412
413
    return data  # dictionary
414
415
416
def url2file(url):
417
    # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
418
    url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
419
    file = Path(urllib.parse.unquote(url)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
420
    return file
421
422
423
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
424
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
425
    def download_one(url, dir):
426
        # Download 1 file
427
        f = dir / Path(url).name  # filename
428
        if Path(url).is_file():  # exists in current path
429
            Path(url).rename(f)  # move to dir
430
        elif not f.exists():
431
            print(f'Downloading {url} to {f}...')
432
            if curl:
433
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
434
            else:
435
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
436
        if unzip and f.suffix in ('.zip', '.gz'):
437
            print(f'Unzipping {f}...')
438
            if f.suffix == '.zip':
439
                ZipFile(f).extractall(path=dir)  # unzip
440
            elif f.suffix == '.gz':
441
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
442
            if delete:
443
                f.unlink()  # remove zip
444
445
    dir = Path(dir)
446
    dir.mkdir(parents=True, exist_ok=True)  # make directory
447
    if threads > 1:
448
        pool = ThreadPool(threads)
449
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
450
        pool.close()
451
        pool.join()
452
    else:
453
        for u in [url] if isinstance(url, (str, Path)) else url:
454
            download_one(u, dir)
455
456
457
def make_divisible(x, divisor):
458
    # Returns x evenly divisible by divisor
459
    return math.ceil(x / divisor) * divisor
460
461
462
def clean_str(s):
463
    # Cleans a string by replacing special characters with underscore _
464
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
465
466
467
def one_cycle(y1=0.0, y2=1.0, steps=100):
468
    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
469
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
470
471
472
def colorstr(*input):
473
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e.  colorstr('blue', 'hello world')
474
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
475
    colors = {'black': '\033[30m',  # basic colors
476
              'red': '\033[31m',
477
              'green': '\033[32m',
478
              'yellow': '\033[33m',
479
              'blue': '\033[34m',
480
              'magenta': '\033[35m',
481
              'cyan': '\033[36m',
482
              'white': '\033[37m',
483
              'bright_black': '\033[90m',  # bright colors
484
              'bright_red': '\033[91m',
485
              'bright_green': '\033[92m',
486
              'bright_yellow': '\033[93m',
487
              'bright_blue': '\033[94m',
488
              'bright_magenta': '\033[95m',
489
              'bright_cyan': '\033[96m',
490
              'bright_white': '\033[97m',
491
              'end': '\033[0m',  # misc
492
              'bold': '\033[1m',
493
              'underline': '\033[4m'}
494
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
495
496
497
def labels_to_class_weights(labels, nc=80):
498
    # Get class weights (inverse frequency) from training labels
499
    if labels[0] is None:  # no labels loaded
500
        return torch.Tensor()
501
502
    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
503
    classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
504
    weights = np.bincount(classes, minlength=nc)  # occurrences per class
505
506
    # Prepend gridpoint count (for uCE training)
507
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
508
    # weights = np.hstack([gpi * len(labels)  - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start
509
510
    weights[weights == 0] = 1  # replace empty bins with 1
511
    weights = 1 / weights  # number of targets per class
512
    weights /= weights.sum()  # normalize
513
    return torch.from_numpy(weights)
514
515
516
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
517
    # Produces image weights based on class_weights and image contents
518
    class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
519
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
520
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
521
    return image_weights
522
523
524
def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
525
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
526
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
527
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
528
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
529
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
530
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
531
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
532
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
533
    return x
534
535
536
def xyxy2xywh(x):
537
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
538
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
539
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
540
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
541
    y[:, 2] = x[:, 2] - x[:, 0]  # width
542
    y[:, 3] = x[:, 3] - x[:, 1]  # height
543
    return y
544
545
546
def xywh2xyxy(x):
547
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
548
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
549
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
550
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
551
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
552
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
553
    return y
554
555
556
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
557
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
558
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
559
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
560
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
561
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
562
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
563
    return y
564
565
566
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
567
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
568
    if clip:
569
        clip_coords(x, (h - eps, w - eps))  # warning: inplace clip
570
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
571
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
572
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
573
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
574
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
575
    return y
576
577
578
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
579
    # Convert normalized segments into pixel segments, shape (n,2)
580
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
581
    y[:, 0] = w * x[:, 0] + padw  # top left x
582
    y[:, 1] = h * x[:, 1] + padh  # top left y
583
    return y
584
585
586
def segment2box(segment, width=640, height=640):
587
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
588
    x, y = segment.T  # segment xy
589
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
590
    x, y, = x[inside], y[inside]
591
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy
592
593
594
def segments2boxes(segments):
595
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
596
    boxes = []
597
    for s in segments:
598
        x, y = s.T  # segment xy
599
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
600
    return xyxy2xywh(np.array(boxes))  # cls, xywh
601
602
603
def resample_segments(segments, n=1000):
604
    # Up-sample an (n,2) segment
605
    for i, s in enumerate(segments):
606
        x = np.linspace(0, len(s) - 1, n)
607
        xp = np.arange(len(s))
608
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
609
    return segments
610
611
612
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
613
    # Rescale coords (xyxy) from img1_shape to img0_shape
614
    if ratio_pad is None:  # calculate from img0_shape
615
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
616
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
617
    else:
618
        gain = ratio_pad[0][0]
619
        pad = ratio_pad[1]
620
621
    coords[:, [0, 2]] -= pad[0]  # x padding
622
    coords[:, [1, 3]] -= pad[1]  # y padding
623
    coords[:, :4] /= gain
624
    clip_coords(coords, img0_shape)
625
    return coords
626
627
628
def clip_coords(boxes, shape):
629
    # Clip bounding xyxy bounding boxes to image shape (height, width)
630
    if isinstance(boxes, torch.Tensor):  # faster individually
631
        boxes[:, 0].clamp_(0, shape[1])  # x1
632
        boxes[:, 1].clamp_(0, shape[0])  # y1
633
        boxes[:, 2].clamp_(0, shape[1])  # x2
634
        boxes[:, 3].clamp_(0, shape[0])  # y2
635
    else:  # np.array (faster grouped)
636
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1])  # x1, x2
637
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2
638
639
640
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
641
                        labels=(), max_det=300):
642
    """Runs Non-Maximum Suppression (NMS) on inference results
643
644
    Returns:
645
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
646
    """
647
648
    nc = prediction.shape[2] - 5  # number of classes
649
    xc = prediction[..., 4] > conf_thres  # candidates
650
651
    # Checks
652
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
653
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
654
655
    # Settings
656
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
657
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
658
    time_limit = 10.0  # seconds to quit after
659
    redundant = True  # require redundant detections
660
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
661
    merge = False  # use merge-NMS
662
663
    t = time.time()
664
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
665
    for xi, x in enumerate(prediction):  # image index, image inference
666
        # Apply constraints
667
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
668
        x = x[xc[xi]]  # confidence
669
670
        # Cat apriori labels if autolabelling
671
        if labels and len(labels[xi]):
672
            l = labels[xi]
673
            v = torch.zeros((len(l), nc + 5), device=x.device)
674
            v[:, :4] = l[:, 1:5]  # box
675
            v[:, 4] = 1.0  # conf
676
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
677
            x = torch.cat((x, v), 0)
678
679
        # If none remain process next image
680
        if not x.shape[0]:
681
            continue
682
683
        # Compute conf
684
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
685
686
        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
687
        box = xywh2xyxy(x[:, :4])
688
689
        # Detections matrix nx6 (xyxy, conf, cls)
690
        if multi_label:
691
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
692
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
693
        else:  # best class only
694
            conf, j = x[:, 5:].max(1, keepdim=True)
695
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
696
697
        # Filter by class
698
        if classes is not None:
699
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
700
701
        # Apply finite constraint
702
        # if not torch.isfinite(x).all():
703
        #     x = x[torch.isfinite(x).all(1)]
704
705
        # Check shape
706
        n = x.shape[0]  # number of boxes
707
        if not n:  # no boxes
708
            continue
709
        elif n > max_nms:  # excess boxes
710
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
711
712
        # Batched NMS
713
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
714
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
715
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
716
        if i.shape[0] > max_det:  # limit detections
717
            i = i[:max_det]
718
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
719
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
720
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
721
            weights = iou * scores[None]  # box weights
722
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
723
            if redundant:
724
                i = i[iou.sum(1) > 1]  # require redundancy
725
726
        output[xi] = x[i]
727
        if (time.time() - t) > time_limit:
728
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
729
            break  # time limit exceeded
730
731
    return output
732
733
734
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
735
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
736
    x = torch.load(f, map_location=torch.device('cpu'))
737
    if x.get('ema'):
738
        x['model'] = x['ema']  # replace model with ema
739
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
740
        x[k] = None
741
    x['epoch'] = -1
742
    x['model'].half()  # to FP16
743
    for p in x['model'].parameters():
744
        p.requires_grad = False
745
    torch.save(x, s or f)
746
    mb = os.path.getsize(s or f) / 1E6  # filesize
747
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
748
749
750
def print_mutation(results, hyp, save_dir, bucket):
751
    evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
752
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
753
            'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
754
    keys = tuple(x.strip() for x in keys)
755
    vals = results + tuple(hyp.values())
756
    n = len(keys)
757
758
    # Download (optional)
759
    if bucket:
760
        url = f'gs://{bucket}/evolve.csv'
761
        if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
762
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local
763
764
    # Log to evolve.csv
765
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
766
    with open(evolve_csv, 'a') as f:
767
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
768
769
    # Print to screen
770
    print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
771
    print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')
772
773
    # Save yaml
774
    with open(evolve_yaml, 'w') as f:
775
        data = pd.read_csv(evolve_csv)
776
        data = data.rename(columns=lambda x: x.strip())  # strip keys
777
        i = np.argmax(fitness(data.values[:, :7]))  #
778
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
779
                f'# Best generation: {i}\n' +
780
                f'# Last generation: {len(data) - 1}\n' +
781
                '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
782
                '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
783
        yaml.safe_dump(hyp, f, sort_keys=False)
784
785
    if bucket:
786
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload
787
788
789
def apply_classifier(x, model, img, im0):
790
    # Apply a second stage classifier to YOLO outputs
791
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
792
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
793
    for i, d in enumerate(x):  # per image
794
        if d is not None and len(d):
795
            d = d.clone()
796
797
            # Reshape and pad cutouts
798
            b = xyxy2xywh(d[:, :4])  # boxes
799
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
800
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
801
            d[:, :4] = xywh2xyxy(b).long()
802
803
            # Rescale boxes from img_size to im0 size
804
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
805
806
            # Classes
807
            pred_cls1 = d[:, 5].long()
808
            ims = []
809
            for j, a in enumerate(d):  # per item
810
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
811
                im = cv2.resize(cutout, (224, 224))  # BGR
812
                # cv2.imwrite('example%i.jpg' % j, cutout)
813
814
                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
815
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
816
                im /= 255  # 0 - 255 to 0.0 - 1.0
817
                ims.append(im)
818
819
            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
820
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections
821
822
    return x
823
824
825
def increment_path(path, exist_ok=False, sep='', mkdir=False):
826
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
827
    path = Path(path)  # os-agnostic
828
    if path.exists() and not exist_ok:
829
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
830
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
831
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
832
        i = [int(m.groups()[0]) for m in matches if m]  # indices
833
        n = max(i) + 1 if i else 2  # increment number
834
        path = Path(f"{path}{sep}{n}{suffix}")  # increment path
835
    if mkdir:
836
        path.mkdir(parents=True, exist_ok=True)  # make directory
837
    return path
838
839
840
# Variables
841
NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns  # terminal window size for tqdm