Diff of /utils/general.py [000000] .. [190ca4]

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
General utils
"""

import contextlib
import glob
import inspect
import logging
import logging.config
import math
import os
import platform
import random
import re
import signal
import subprocess
import sys
import time
import urllib
from copy import deepcopy
from datetime import datetime
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from tarfile import is_tarfile
from typing import Optional
from zipfile import ZipFile, is_zipfile

import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torch.nn as nn
import torchvision
import yaml
from torchvision.ops import roi_align

# Import 'ultralytics' package or install it if missing
try:
    import ultralytics

    assert hasattr(ultralytics, '__version__')  # verify package is not directory
except (ImportError, AssertionError):
    os.system('pip install -U ultralytics')
    import ultralytics

from ultralytics.utils.checks import check_requirements

from utils import TryExcept, emojis
from utils.downloads import curl_download, gsutil_getsize
from utils.metrics import box_iou, fitness

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
RANK = int(os.getenv('RANK', -1))

# Settings
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets'))  # global datasets directory
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm bar format
FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf

torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'Darwin' else str(NUM_THREADS)  # OpenMP (PyTorch and SciPy)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress verbose TF compiler warnings in Colab


def is_ascii(s=''):
    # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
    s = str(s)  # convert list, tuple, None, etc. to str
    return len(s.encode().decode('ascii', 'ignore')) == len(s)


def is_chinese(s='人工智能'):
    # Is string composed of any Chinese characters?
    return bool(re.search('[\u4e00-\u9fff]', str(s)))


def is_colab():
    # Is environment a Google Colab instance?
    return 'google.colab' in sys.modules


def is_jupyter():
    """
    Check if the current script is running inside a Jupyter Notebook.
    Verified on Colab, Jupyterlab, Kaggle, Paperspace.

    Returns:
        bool: True if running inside a Jupyter Notebook, False otherwise.
    """
    with contextlib.suppress(Exception):
        from IPython import get_ipython
        return get_ipython() is not None
    return False


def is_kaggle():
    # Is environment a Kaggle Notebook?
    return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'


def is_docker() -> bool:
    """Check if the process runs inside a docker container."""
    if Path('/.dockerenv').exists():
        return True
    try:  # check if docker is in control groups
        with open('/proc/self/cgroup') as file:
            return any('docker' in line for line in file)
    except OSError:
        return False


def is_writeable(dir, test=False):
    # Return True if directory has write permissions, test opening a file with write permissions if test=True
    if not test:
        return os.access(dir, os.W_OK)  # possible issues on Windows
    file = Path(dir) / 'tmp.txt'
    try:
        with open(file, 'w'):  # open file with write permissions
            pass
        file.unlink()  # remove file
        return True
    except OSError:
        return False


LOGGING_NAME = 'yolov5'


def set_logging(name=LOGGING_NAME, verbose=True):
    # sets up logging for the given name
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
    level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            name: {
                'format': '%(message)s'}},
        'handlers': {
            name: {
                'class': 'logging.StreamHandler',
                'formatter': name,
                'level': level, }},
        'loggers': {
            name: {
                'level': level,
                'handlers': [name],
                'propagate': False, }}})


set_logging(LOGGING_NAME)  # run before defining LOGGER
LOGGER = logging.getLogger(LOGGING_NAME)  # define globally (used in train.py, val.py, detect.py, etc.)
if platform.system() == 'Windows':
    for fn in LOGGER.info, LOGGER.warning:
        setattr(LOGGER, fn.__name__, lambda x, fn=fn: fn(emojis(x)))  # emoji safe logging (bind fn per iteration)


def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # use environment variable
    else:
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
        path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    path.mkdir(exist_ok=True)  # make if required
    return path


CONFIG_DIR = user_config_dir()  # Ultralytics settings dir


class Profile(contextlib.ContextDecorator):
    # YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
    def __init__(self, t=0.0, device: torch.device = None):
        self.t = t
        self.device = device
        self.cuda = bool(device and str(device).startswith('cuda'))

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):
        self.dt = self.time() - self.start  # delta-time
        self.t += self.dt  # accumulate dt

    def time(self):
        if self.cuda:
            torch.cuda.synchronize(self.device)
        return time.time()


class Timeout(contextlib.ContextDecorator):
    # YOLOv5 Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)

    def _timeout_handler(self, signum, frame):
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        if platform.system() != 'Windows':  # not supported on Windows
            signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
            signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        if platform.system() != 'Windows':
            signal.alarm(0)  # Cancel SIGALRM if it's scheduled
            if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
                return True


class WorkingDirectory(contextlib.ContextDecorator):
    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
    def __init__(self, new_dir):
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)


def methods(instance):
    # Get class/instance methods
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith('__')]


def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    # Print function arguments (optional args dict)
    x = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix('')
    except ValueError:
        file = Path(file).stem
    s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
    LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))


def init_seeds(seed=0, deterministic=False):
    # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if deterministic and check_version(torch.__version__, '1.12.0'):  # https://github.com/ultralytics/yolov5/pull/8213
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        os.environ['PYTHONHASHSEED'] = str(seed)


def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}


def get_default_args(func):
    # Get func() default arguments
    signature = inspect.signature(func)
    return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def file_age(path=__file__):
    # Return days since last file update
    dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime))  # delta
    return dt.days  # + dt.seconds / 86400  # fractional days


def file_date(path=__file__):
    # Return human-readable file modification date, i.e. '2021-3-26'
    t = datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f'{t.year}-{t.month}-{t.day}'


def file_size(path):
    # Return file/dir size (MB)
    mb = 1 << 20  # bytes to MiB (1024 ** 2)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / mb
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
    else:
        return 0.0


def check_online():
    # Check internet connectivity
    import socket

    def run_once():
        # Check once
        try:
            socket.create_connection(('1.1.1.1', 443), 5)  # check host accessibility
            return True
        except OSError:
            return False

    return run_once() or run_once()  # check twice to increase robustness to intermittent connectivity issues


def git_describe(path=ROOT):  # path must be a directory
    # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
    try:
        assert (Path(path) / '.git').is_dir()
        return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
    except Exception:
        return ''


@TryExcept()
@WorkingDirectory(ROOT)
def check_git_status(repo='ultralytics/yolov5', branch='master'):
    # YOLOv5 status check, recommend 'git pull' if code is out of date
    url = f'https://github.com/{repo}'
    msg = f', for updates see {url}'
    s = colorstr('github: ')  # string
    assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
    assert check_online(), s + 'skipping check (offline)' + msg

    splits = re.split(pattern=r'\s', string=check_output('git remote -v', shell=True).decode())
    matches = [repo in s for s in splits]
    if any(matches):
        remote = splits[matches.index(True) - 1]
    else:
        remote = 'ultralytics'
        check_output(f'git remote add {remote} {url}', shell=True)
    check_output(f'git fetch {remote}', shell=True, timeout=5)  # git fetch
    local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
    n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True))  # commits behind
    if n > 0:
        pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
        s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use '{pull}' or 'git clone {url}' to update."
    else:
        s += f'up to date with {url} ✅'
    LOGGER.info(s)


@WorkingDirectory(ROOT)
def check_git_info(path='.'):
    # YOLOv5 git info check, return {remote, branch, commit}
    check_requirements('gitpython')
    import git
    try:
        repo = git.Repo(path)
        remote = repo.remotes.origin.url.replace('.git', '')  # i.e. 'https://github.com/ultralytics/yolov5'
        commit = repo.head.commit.hexsha  # i.e. '3134699c73af83aac2a481435550b968d5792c0d'
        try:
            branch = repo.active_branch.name  # i.e. 'main'
        except TypeError:  # not on any branch
            branch = None  # i.e. 'detached HEAD' state
        return {'remote': remote, 'branch': branch, 'commit': commit}
    except git.exc.InvalidGitRepositoryError:  # path is not a git dir
        return {'remote': None, 'branch': None, 'commit': None}


def check_python(minimum='3.8.0'):
    # Check current python version vs. required python version
    check_version(platform.python_version(), minimum, name='Python ', hard=True)


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
    # Check version vs. required version
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    s = f'WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed'  # string
    if hard:
        assert result, emojis(s)  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(s)
    return result


def check_img_size(imgsz, s=32, floor=0):
    # Verify image size is a multiple of stride s in each dimension
    if isinstance(imgsz, int):  # integer i.e. img_size=640
        new_size = max(make_divisible(imgsz, int(s)), floor)
    else:  # list i.e. img_size=[640, 480]
        imgsz = list(imgsz)  # convert to list if tuple
        new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
    if new_size != imgsz:
        LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size


def check_imshow(warn=False):
    # Check if environment supports image displays
    try:
        assert not is_jupyter()
        assert not is_docker()
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        if warn:
            LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
        return False


def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    # Check file(s) for acceptable suffix
    if file and suffix:
        if isinstance(suffix, str):
            suffix = [suffix]
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower()  # file suffix
            if len(s):
                assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}'


def check_yaml(file, suffix=('.yaml', '.yml')):
    # Search/download YAML file (if necessary) and return path, checking suffix
    return check_file(file, suffix)


def check_file(file, suffix=''):
    # Search/download file (if necessary) and return path
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if os.path.isfile(file) or not file:  # exists
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
        if os.path.isfile(file):
            LOGGER.info(f'Found {url} locally at {file}')  # file already exists
        else:
            LOGGER.info(f'Downloading {url} to {file}...')
            torch.hub.download_url_to_file(url, file)
            assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    elif file.startswith('clearml://'):  # ClearML Dataset ID
        assert 'clearml' in sys.modules, "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
        return file
    else:  # search
        files = []
        for d in 'data', 'models', 'utils':  # search directories
            files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file


def check_font(font=FONT, progress=False):
    # Download font to CONFIG_DIR if necessary
    font = Path(font)
    file = CONFIG_DIR / font.name
    if not font.exists() and not file.exists():
        url = f'https://ultralytics.com/assets/{font.name}'
        LOGGER.info(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, str(file), progress=progress)


def check_dataset(data, autodownload=True):
    # Download, check and/or unzip dataset if not found locally

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
        download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        data = yaml_load(data)  # dictionary

    # Checks
    for k in 'train', 'val', 'names':
        assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
    if isinstance(data['names'], (list, tuple)):  # old array format
        data['names'] = dict(enumerate(data['names']))  # convert to dict
    assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car'
    data['nc'] = len(data['names'])

    # Resolve paths
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()
        data['path'] = path  # download scripts
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith('../'):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse yaml
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
            if not s or not autodownload:
                raise Exception('Dataset not found ❌')
            t = time.time()
            if s.startswith('http') and s.endswith('.zip'):  # URL
                f = Path(s).name  # filename
                LOGGER.info(f'Downloading {s} to {f}...')
                torch.hub.download_url_to_file(s, f)
                Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True)  # create root
                unzip_file(f, path=DATASETS_DIR)  # unzip
                Path(f).unlink()  # remove zip
                r = None  # success
            elif s.startswith('bash '):  # bash script
                LOGGER.info(f'Running {s} ...')
                r = subprocess.run(s, shell=True)
            else:  # python script
                r = exec(s, {'yaml': data})  # return None
            dt = f'({round(time.time() - t, 1)}s)'
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
            LOGGER.info(f'Dataset download {s}')
    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True)  # download fonts
    return data  # dictionary


def check_amp(model):
    # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
    from models.common import AutoShape, DetectMultiBackend

    def amp_allclose(model, im):
        # All close FP32 vs AMP results
        m = AutoShape(model, verbose=False)  # model
        a = m(im).xywhn[0]  # FP32 inference
        m.amp = True
        b = m(im).xywhn[0]  # AMP inference
        return a.shape == b.shape and torch.allclose(a, b, atol=0.1)  # close to 10% absolute tolerance

    prefix = colorstr('AMP: ')
    device = next(model.parameters()).device  # get model device
    if device.type in ('cpu', 'mps'):
        return False  # AMP only used on CUDA devices
    f = ROOT / 'data' / 'images' / 'bus.jpg'  # image to check
    im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
    try:
        assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
        LOGGER.info(f'{prefix}checks passed ✅')
        return True
    except Exception:
        help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
        LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
        return False


def yaml_load(file='data.yaml'):
    # Single-line safe yaml loading
    with open(file, errors='ignore') as f:
        return yaml.safe_load(f)


def yaml_save(file='data.yaml', data={}):
    # Single-line safe yaml saving
    with open(file, 'w') as f:
        yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)


def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
    # Unzip a *.zip file to path/, excluding files containing strings in exclude list
    if path is None:
        path = Path(file).parent  # default path
    with ZipFile(file) as zipObj:
        for f in zipObj.namelist():  # list all archived filenames in the zip
            if all(x not in f for x in exclude):
                zipObj.extract(f, path=path)


def url2file(url):
    # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
    url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
    return Path(urllib.parse.unquote(url)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth


def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
    # Multithreaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file
        success = True
        if os.path.isfile(url):
            f = Path(url)  # filename
        else:  # does not exist
            f = dir / Path(url).name
            LOGGER.info(f'Downloading {url} to {f}...')
            for i in range(retry + 1):
                if curl:
                    success = curl_download(url, f, silent=(threads > 1))
                else:
                    torch.hub.download_url_to_file(url, f, progress=threads == 1)  # torch download
                    success = f.is_file()
                if success:
                    break
                elif i < retry:
                    LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
                else:
                    LOGGER.warning(f'❌ Failed to download {url}...')

        if unzip and success and (f.suffix == '.gz' or is_zipfile(f) or is_tarfile(f)):
            LOGGER.info(f'Unzipping {f}...')
            if is_zipfile(f):
                unzip_file(f, dir)  # unzip
            elif is_tarfile(f):
                subprocess.run(['tar', 'xf', f, '--directory', f.parent], check=True)  # unzip
            elif f.suffix == '.gz':
                subprocess.run(['tar', 'xfz', f, '--directory', f.parent], check=True)  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multithreaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
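
# Usage sketch: fetch and unzip one dataset archive into DATASETS_DIR (URL illustrative)
# download('https://ultralytics.com/assets/coco128.zip', dir=DATASETS_DIR, unzip=True, delete=True, threads=1)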


def make_divisible(x, divisor):
    # Returns nearest x divisible by divisor
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
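
# Usage sketch: one_cycle() is typically handed to a LambdaLR scheduler, as in train.py ('epochs' illustrative)
# lf = one_cycle(1, 0.01, epochs)  # cosine ramp 1.0 -> 0.01 over 'epochs' steps
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)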


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e.  colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {
        'black': '\033[30m',  # basic colors
        'red': '\033[31m',
        'green': '\033[32m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        'bright_black': '\033[90m',  # bright colors
        'bright_red': '\033[91m',
        'bright_green': '\033[92m',
        'bright_yellow': '\033[93m',
        'bright_blue': '\033[94m',
        'bright_magenta': '\033[95m',
        'bright_cyan': '\033[96m',
        'bright_white': '\033[97m',
        'end': '\033[0m',  # misc
        'bold': '\033[1m',
        'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels)  - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights).float()


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    # Usage: index = random.choices(range(n), weights=image_weights, k=1)  # weighted image sample
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    return (class_weights.reshape(1, nc) * class_counts).sum(1)


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    return [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
        35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
        64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
    y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    if clip:
        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
    y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
    y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = w * x[..., 0] + padw  # top left x
    y[..., 1] = h * x[..., 1] + padh  # top left y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        s = np.concatenate((s, s[0:1, :]), axis=0)
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
    return segments


def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    # Rescale boxes (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    boxes[..., [0, 2]] -= pad[0]  # x padding
    boxes[..., [1, 3]] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
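
# Usage sketch: map detections made on the letterboxed model input back to the original frame (names illustrative)
# det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()  # im: (1, 3, H, W) input, im0: original BGR image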


def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    segments[:, 0] -= pad[0]  # x padding
    segments[:, 1] -= pad[1]  # y padding
    segments /= gain
    clip_segments(segments, img0_shape)
    if normalize:
        segments[:, 0] /= img0_shape[1]  # width
        segments[:, 1] /= img0_shape[0]  # height
    return segments


def clip_boxes(boxes, shape):
    # Clip boxes (xyxy) to image shape (height, width)
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[..., 0].clamp_(0, shape[1])  # x1
        boxes[..., 1].clamp_(0, shape[0])  # y1
        boxes[..., 2].clamp_(0, shape[1])  # x2
        boxes[..., 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2


def clip_segments(segments, shape):
    # Clip segments (xy1,xy2,...) to image shape (height, width)
    if isinstance(segments, torch.Tensor):  # faster individually
        segments[:, 0].clamp_(0, shape[1])  # x
        segments[:, 1].clamp_(0, shape[0])  # y
    else:  # np.array (faster grouped)
        segments[:, 0] = segments[:, 0].clip(0, shape[1])  # x
        segments[:, 1] = segments[:, 1].clip(0, shape[0])  # y


def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
         list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = x[:, 5:mi].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
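
# Usage sketch: typical inference-time call with the default thresholds (variable names illustrative)
# pred = model(im)  # raw predictions, (bs, anchors, 5 + nc), or an (inference_out, loss_out) tuple in validation mode
# det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=300)  # list of (n, 6) [xyxy, conf, cls]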


def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")


def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
    # Log hyperparameter evolution results to evolve.csv and hyp_evolve.yaml, optionally syncing with a GCS bucket
    evolve_csv = save_dir / 'evolve.csv'
    evolve_yaml = save_dir / 'hyp_evolve.yaml'
    keys = tuple(keys) + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
            subprocess.run(['gsutil', 'cp', f'{url}', f'{save_dir}'])  # download evolve.csv if larger than local

    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv, skipinitialspace=True)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :4]))  # index of best generation
        generations = len(data)
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
                f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
                '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)

    # Print to screen
    LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
                ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
                                                                                         for x in vals) + '\n\n')

    if bucket:
        subprocess.run(['gsutil', 'cp', f'{evolve_csv}', f'{evolve_yaml}', f'gs://{bucket}'])  # upload


def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to YOLO outputs
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for a in d:
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')

        # Method 1
        for n in range(2, 9999):
            p = f'{path}{sep}{n}{suffix}'  # increment path
            if not os.path.exists(p):
                break
        path = Path(p)

        # Method 2 (deprecated)
        # dirs = glob.glob(f"{path}{sep}*")  # similar paths
        # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
        # i = [int(m.groups()[0]) for m in matches if m]  # indices
        # n = max(i) + 1 if i else 2  # increment number
        # path = Path(f"{path}{sep}{n}{suffix}")  # increment path

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory

    return path


# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------------
imshow_ = cv2.imshow  # copy to avoid recursion errors


def imread(filename, flags=cv2.IMREAD_COLOR):
    return cv2.imdecode(np.fromfile(filename, np.uint8), flags)


def imwrite(filename, img):
    try:
        cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
        return True
    except Exception:
        return False


def imshow(path, im):
    imshow_(path.encode('unicode_escape').decode(), im)


if Path(inspect.stack()[0].filename).parent.parent.as_posix() in inspect.stack()[-1].filename:
    cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow  # redefine


def get_object_level_feature_maps(feature_map, targets):
    # Crop per-target regions from a feature map; targets are (batch_idx, class, x_center, y_center, width, height), normalized
    feature_map_shape = feature_map.shape[2:]

    # Scale normalized target coordinates to feature-map pixels
    x_center = targets[:, 2] * feature_map_shape[1]
    y_center = targets[:, 3] * feature_map_shape[0]
    width = targets[:, 4] * feature_map_shape[1]
    height = targets[:, 5] * feature_map_shape[0]

    # Calculate pixel coordinates for the bounding boxes
    x_min = torch.clamp((x_center - width / 2).int(), 0, feature_map_shape[1] - 1)
    y_min = torch.clamp((y_center - height / 2).int(), 0, feature_map_shape[0] - 1)
    x_max = torch.clamp((x_center + width / 2).int(), 0, feature_map_shape[1] - 1)
    y_max = torch.clamp((y_center + height / 2).int(), 0, feature_map_shape[0] - 1)

    # Extract regions from the feature_map based on the bounding boxes
    extracted_regions = [feature_map[:, :, y_min[i]:y_max[i] + 1, x_min[i]:x_max[i] + 1] for i in range(targets.shape[0])]

    return extracted_regions


def get_object_level_feature_maps2(feature_map, targets):
    # Variant of get_object_level_feature_maps for 5-column targets (id, x_center, y_center, width, height); the leading id column is unused
    feature_map_shape = feature_map.shape[2:]

    # Scale normalized target coordinates to feature-map pixels
    x_center = targets[:, 1] * feature_map_shape[1]
    y_center = targets[:, 2] * feature_map_shape[0]
    width = targets[:, 3] * feature_map_shape[1]
    height = targets[:, 4] * feature_map_shape[0]

    # Calculate pixel coordinates for the bounding boxes
    x_min = torch.clamp((x_center - width / 2).int(), 0, feature_map_shape[1] - 1)
    y_min = torch.clamp((y_center - height / 2).int(), 0, feature_map_shape[0] - 1)
    x_max = torch.clamp((x_center + width / 2).int(), 0, feature_map_shape[1] - 1)
    y_max = torch.clamp((y_center + height / 2).int(), 0, feature_map_shape[0] - 1)

    # Extract regions from the feature_map based on the bounding boxes
    extracted_regions = [feature_map[:, :, y_min[i]:y_max[i] + 1, x_min[i]:x_max[i] + 1] for i in range(targets.shape[0])]

    return extracted_regions
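
# Usage sketch (hypothetical shapes): crop per-target feature patches from one pyramid level
# feats = torch.randn(2, 256, 80, 80)                      # (batch, channels, H, W)
# targets = torch.tensor([[0., 3., 0.5, 0.5, 0.2, 0.2]])   # (batch_idx, cls, x, y, w, h), normalized
# regions = get_object_level_feature_maps(feats, targets)  # list of (batch, channels, h_i, w_i) crops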


def extract_roi_features(concatenated_features, resize_boxes):
    """
    Extracts regions of interest (ROIs) from concatenated_features based on resize_boxes.

    Args:
        concatenated_features (torch.Tensor): Feature map with shape [batch, channels, height, width].
        resize_boxes (torch.Tensor): Boxes with shape [num_boxes, 5], where each row is [batch, x1, y1, x2, y2]
            in feature-map pixel coordinates.

    Returns:
        list[torch.Tensor]: One tensor per box, with shape [batch, channels, roi_height, roi_width].
    """
    # Initialize a list to store ROI features for each box
    roi_features_list = []

    for box_idx in range(resize_boxes.size(0)):
        # Extract box coordinates (column 0 is the batch index; it is not used here, so each crop spans the whole batch)
        box_x1, box_y1, box_x2, box_y2 = resize_boxes[box_idx, 1:]

        # Convert to integer indices (boxes are already in feature-map pixels, so no rescaling is needed)
        roi_x1, roi_y1, roi_x2, roi_y2 = map(int, [box_x1, box_y1, box_x2, box_y2])

        # Extract ROI from the feature map
        roi_features = concatenated_features[:, :, roi_y1:roi_y2, roi_x1:roi_x2]

        # Append the ROI features to the list
        roi_features_list.append(roi_features)

    return roi_features_list
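
# Alternative sketch: torchvision.ops.roi_align (imported above but otherwise unused) performs the same gather
# with bilinear sampling, per-box batch indexing and a fixed output size, avoiding the integer truncation above
# pooled = roi_align(concatenated_features, resize_boxes.float(), output_size=(7, 7), spatial_scale=1.0)  # (num_boxes, C, 7, 7)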


import matplotlib.patches as patches
import matplotlib.pyplot as plt


def plot_multi_channel_feature_map_with_boxes(feature_map, boxes, channels, title, save_path=None):
    # Show selected channels of feature_map[0] side by side, drawing xyxy boxes on the last panel
    fig, axs = plt.subplots(1, len(channels) + 1, figsize=(12, 4))

    for i, channel in enumerate(channels):
        axs[i].imshow(feature_map[0, channel].cpu().detach().numpy(), cmap='viridis', aspect='auto')
        axs[i].set_title(f'Channel {channel}')

    # Plot bounding boxes on the last axis
    axs[-1].imshow(feature_map[0, channels[-1]].cpu().detach().numpy(), cmap='viridis', aspect='auto')
    axs[-1].set_title('Bounding Boxes')

    boxes = boxes.reshape(-1, 4)  # accept a single (4,) box or an (n, 4) batch
    for box in boxes:
        xmin, ymin, xmax, ymax = box.cpu().detach().numpy()
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=1, edgecolor='r', facecolor='none')
        axs[-1].add_patch(rect)

    fig.suptitle(title)

    # Save the image if save_path is provided
    if save_path:
        plt.savefig(save_path)
        print(f"Image saved at: {save_path}")
    else:
        plt.show()


def xywh_to_xyxy(xywh):
    # Convert a single (x_center, y_center, width, height) box to (x_min, y_min, x_max, y_max)
    x_center, y_center, width, height = xywh
    x_min = x_center - width / 2
    y_min = y_center - height / 2
    x_max = x_center + width / 2
    y_max = y_center + height / 2
    return torch.tensor([x_min, y_min, x_max, y_max])


def get_fixed_xyxy(normalized_xyxy, int_feat):
    # Ensure an integer xyxy box has nonzero area and stays inside feature map int_feat of shape (C, H, W)
    x_min, y_min, x_max, y_max = normalized_xyxy.int()

    if x_min == x_max:  # degenerate width: widen by one pixel, shifting left if at the right edge
        if x_max < int_feat.size(2):
            x_max += 1
        else:
            x_min -= 1

    if y_min == y_max:  # degenerate height: widen by one pixel, shifting up if at the bottom edge
        if y_max < int_feat.size(1):
            y_max += 1
        else:
            y_min -= 1

    return x_min, y_min, x_max, y_max
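
# Usage sketch (hypothetical values): guarantee a croppable, nonzero-area patch on a (C, H, W) feature map int_feat
# box = xywh_to_xyxy(torch.tensor([10.0, 10.0, 0.4, 0.4]))  # near-degenerate box in feature-map pixels
# x1, y1, x2, y2 = get_fixed_xyxy(box, int_feat)
# patch = int_feat[:, y1:y2, x1:x2]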


# Variables ------------------------------------------------------------------------------------------------------------