[190ca4]: / utils / general.py

Download this file

1257 lines (1009 with data), 51.3 kB

   1
   2
   3
   4
   5
   6
   7
   8
   9
  10
  11
  12
  13
  14
  15
  16
  17
  18
  19
  20
  21
  22
  23
  24
  25
  26
  27
  28
  29
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
General utils
"""
import contextlib
import glob
import inspect
import logging
import logging.config
import math
import os
import platform
import random
import re
import signal
import subprocess
import sys
import time
import urllib
from copy import deepcopy
from datetime import datetime
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from tarfile import is_tarfile
from typing import Optional
from zipfile import ZipFile, is_zipfile
import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml
import torch.nn as nn
from torchvision.ops import roi_align
# Import 'ultralytics' package or install if missing
try:
import ultralytics
assert hasattr(ultralytics, '__version__') # verify package is not directory
except (ImportError, AssertionError):
os.system('pip install -U ultralytics')
import ultralytics
from ultralytics.utils.checks import check_requirements
from utils import TryExcept, emojis
from utils.downloads import curl_download, gsutil_getsize
from utils.metrics import box_iou, fitness
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
RANK = int(os.getenv('RANK', -1))  # value of the RANK environment variable, -1 when it is unset

# Settings
# os.cpu_count() may return None when the CPU count cannot be determined, so fall back to 1
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))  # number of YOLOv5 multiprocessing threads
DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets'))  # global datasets directory
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm bar format
FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf

torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
# platform.system() reports macOS as 'Darwin' (capitalised); a lowercase comparison never matches
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'Darwin' else str(NUM_THREADS)  # OpenMP (PyTorch and SciPy)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress verbose TF compiler warnings in Colab
def is_ascii(s=''):
    """Return True when ``str(s)`` is made up of ASCII characters only.

    Any input (list, tuple, None, ...) is converted with ``str()`` before the check.
    """
    text = str(s)
    ascii_only = text.encode().decode('ascii', 'ignore')  # bytes that are not ASCII are dropped here
    return len(ascii_only) == len(text)
def is_chinese(s='人工智能'):
    """Return True when ``str(s)`` contains at least one character in the range U+4E00..U+9FFF."""
    found = re.search('[\u4e00-\u9fff]', str(s))
    return found is not None
def is_colab():
    """Return True when the 'google.colab' module is present in ``sys.modules``."""
    loaded = sys.modules
    return 'google.colab' in loaded
def is_jupyter():
    """
    Report whether an IPython shell is active in the current process.

    Returns:
        bool: True if ``IPython.get_ipython()`` returns a shell object; False if it returns None,
        or if importing or calling it raises any Exception.
    """
    try:
        from IPython import get_ipython

        return get_ipython() is not None
    except Exception:  # IPython missing or unusable: treat as "not a notebook"
        return False
def is_kaggle():
    """Return True when PWD is '/kaggle/working' and KAGGLE_URL_BASE is 'https://www.kaggle.com'."""
    env = os.environ
    in_working_dir = env.get('PWD') == '/kaggle/working'
    return in_working_dir and env.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
def is_docker() -> bool:
    """Return True if '/.dockerenv' exists or any line of '/proc/self/cgroup' mentions 'docker'."""
    if not Path('/.dockerenv').exists():
        try:  # fall back to scanning the control-group listing
            with open('/proc/self/cgroup') as cgroup:
                for entry in cgroup:
                    if 'docker' in entry:
                        return True
            return False
        except OSError:  # file missing or unreadable
            return False
    return True
def is_writeable(dir, test=False):
    """Return True if directory ``dir`` is writeable.

    Args:
        dir: directory to check.
        test: when False, rely on ``os.access(dir, os.W_OK)``; when True, actually try to
            create a file inside ``dir``.

    Returns:
        bool: True when the check succeeds, False otherwise (including a missing directory).
    """
    if not test:
        return os.access(dir, os.W_OK)  # possible issues on Windows
    import tempfile

    try:
        # A uniquely named temporary file is used so an existing 'tmp.txt' in the directory is
        # never truncated or deleted by the probe; the file is removed when the context exits.
        with tempfile.TemporaryFile(dir=dir):
            pass
        return True
    except OSError:
        return False
LOGGING_NAME = 'yolov5'


def set_logging(name=LOGGING_NAME, verbose=True):
    """Configure a stream logger called ``name`` through ``logging.config.dictConfig``.

    The level is INFO only when ``verbose`` is truthy and the RANK environment variable is
    unset, -1 or 0; otherwise it is ERROR. The same ``name`` is used as formatter, handler and
    logger key. The logger does not propagate to its parent.
    """
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
    is_main = rank in {-1, 0}
    level = logging.INFO if verbose and is_main else logging.ERROR

    formatters = {name: {'format': '%(message)s'}}
    handlers = {name: {'class': 'logging.StreamHandler', 'formatter': name, 'level': level}}
    loggers = {name: {'level': level, 'handlers': [name], 'propagate': False}}
    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': formatters,
        'handlers': handlers,
        'loggers': loggers})
set_logging(LOGGING_NAME)  # run before defining LOGGER
LOGGER = logging.getLogger(LOGGING_NAME)  # define globally (used in train.py, val.py, detect.py, etc.)
if platform.system() == 'Windows':
    for fn in LOGGER.info, LOGGER.warning:
        # Bind fn as a default argument. A plain closure looks `fn` up when the lambda is called,
        # i.e. after the loop has finished, so both wrappers would end up calling LOGGER.warning.
        setattr(LOGGER, fn.__name__, lambda x, fn=fn: fn(emojis(x)))  # emoji safe logging
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    """Return the path of the user configuration directory, creating it if required.

    Args:
        dir: name of the sub-directory placed inside the OS-specific config directory.
        env_var: name of an environment variable; when it is set and non-empty its value is
            used as the full path and ``dir`` is ignored.

    Returns:
        pathlib.Path: the configuration directory, which exists when the function returns.
    """
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # use environment variable
    else:
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
        path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    # parents=True: a path taken from the environment variable may have missing parent directories
    path.mkdir(parents=True, exist_ok=True)  # make if required
    return path
CONFIG_DIR = user_config_dir()  # Ultralytics settings dir, resolved (and created if required) at import time
class Profile(contextlib.ContextDecorator):
    """YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager.

    Each pass through the context stores its duration in ``self.dt`` and adds it to ``self.t``.
    """

    def __init__(self, t=0.0, device: torch.device = None):
        """Store the starting total ``t`` and ``device``; flag CUDA when str(device) starts with 'cuda'."""
        self.t = t
        self.device = device
        self.cuda = bool(device and str(device)[:4] == 'cuda')

    def __enter__(self):
        """Record the start time and return this instance."""
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):
        """Compute the elapsed time since __enter__ and accumulate it."""
        elapsed = self.time() - self.start
        self.dt = elapsed  # delta-time
        self.t += elapsed  # accumulate dt

    def time(self):
        """Return time.time(), calling torch.cuda.synchronize(self.device) first for CUDA devices."""
        if self.cuda:
            torch.cuda.synchronize(self.device)
        return time.time()
class Timeout(contextlib.ContextDecorator):
    """YOLOv5 Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager.

    Uses SIGALRM, so it does nothing on Windows.
    """

    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        """Store the countdown (truncated to int), the error message and the suppress flag."""
        self.seconds = int(seconds)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)

    def _timeout_handler(self, signum, frame):
        """Signal handler: raise TimeoutError carrying the configured message."""
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        """Install the SIGALRM handler and start the countdown (skipped on Windows)."""
        if platform.system() == 'Windows':  # not supported on Windows
            return
        signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
        signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Cancel the pending alarm; return True to swallow a TimeoutError when suppression is on."""
        if platform.system() == 'Windows':
            return None
        signal.alarm(0)  # Cancel SIGALRM if it's scheduled
        if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
            return True
        return None
class WorkingDirectory(contextlib.ContextDecorator):
    """Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager.

    Changes into ``new_dir`` on entry and back to the directory that was current when the
    object was constructed on exit.
    """

    def __init__(self, new_dir):
        """Remember the target directory and the resolved current directory."""
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir

    def __enter__(self):
        """Switch to the target directory."""
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Switch back to the directory saved in __init__."""
        os.chdir(self.cwd)
def methods(instance):
    """Return the names from ``dir(instance)`` that are callable and do not start with '__'."""
    found = []
    for attr in dir(instance):
        if callable(getattr(instance, attr)) and not attr.startswith('__'):
            found.append(attr)
    return found
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    """Log the calling function's arguments as 'k=v' pairs via LOGGER.info.

    Args:
        args: optional dict to print; when None, the caller's own arguments are read from its frame.
        show_file: prefix the output with the caller's file (relative to ROOT when possible).
        show_func: prefix the output with the caller's function name.
    """
    caller = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(caller)
    if args is None:  # get args automatically
        arg_names, _, _, frame_locals = inspect.getargvalues(caller)
        args = {k: v for k, v in frame_locals.items() if k in arg_names}
    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix('')
    except ValueError:  # caller's file is not located under ROOT
        file = Path(file).stem
    prefix = ''
    if show_file:
        prefix += f'{file}: '
    if show_func:
        prefix += f'{func}: '
    LOGGER.info(colorstr(prefix) + ', '.join(f'{k}={v}' for k, v in args.items()))
def init_seeds(seed=0, deterministic=False):
    """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.

    Seeds ``random``, NumPy and PyTorch (CPU and CUDA). When ``deterministic`` is truthy and
    the installed torch passes ``check_version(torch.__version__, '1.12.0')``, deterministic
    algorithms are enabled and the related environment variables are set.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # for Multi-GPU, exception safe
    )
    for seeder in seeders:
        seeder(seed)
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if not deterministic or not check_version(torch.__version__, '1.12.0'):
        return
    # https://github.com/ultralytics/yolov5/pull/8213
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    os.environ['PYTHONHASHSEED'] = str(seed)
def intersect_dicts(da, db, exclude=()):
    """Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values.

    An entry of ``da`` is kept when its key is also in ``db``, no string in ``exclude`` is a
    substring of the key, and the two values have equal ``.shape``.
    """
    kept = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(pattern in key for pattern in exclude):
            continue
        if value.shape == db[key].shape:
            kept[key] = value
    return kept
def get_default_args(func):
    """Return a dict mapping each parameter of ``func`` that has a default to that default value."""
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is not inspect.Parameter.empty:
            defaults[name] = param.default
    return defaults
def get_latest_run(search_dir='.'):
    """Return path to most recent 'last*.pt' under ``search_dir`` (i.e. to --resume from).

    Files are searched recursively and ranked with ``os.path.getctime``. Returns '' when
    nothing matches.
    """
    candidates = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    if not candidates:
        return ''
    return max(candidates, key=os.path.getctime)
def file_age(path=__file__):
    """Return the whole number of days between now and the file's last modification time."""
    now = datetime.now()
    modified = datetime.fromtimestamp(Path(path).stat().st_mtime)
    age = now - modified  # timedelta
    return age.days  # + age.seconds / 86400  # fractional days
def file_date(path=__file__):
    """Return human-readable file modification date, i.e. '2021-3-26' (no zero padding)."""
    mtime = Path(path).stat().st_mtime
    stamp = datetime.fromtimestamp(mtime)
    return '-'.join(str(part) for part in (stamp.year, stamp.month, stamp.day))
def file_size(path):
    """Return file/dir size in MiB.

    For a directory the sizes of all files below it are summed recursively. A path that is
    neither a file nor a directory yields 0.0.
    """
    bytes_per_mib = 1 << 20  # bytes to MiB (1024 ** 2)
    target = Path(path)
    if target.is_file():
        return target.stat().st_size / bytes_per_mib
    if target.is_dir():
        total = sum(entry.stat().st_size for entry in target.glob('**/*') if entry.is_file())
        return total / bytes_per_mib
    return 0.0
def check_online():
    """Check internet connectivity.

    Returns:
        bool: True if a TCP connection to 1.1.1.1:443 can be opened within 5 seconds on either
        of two attempts, False otherwise.
    """
    import socket

    def run_once():
        # Check once
        try:
            # The connection itself is the test; the context manager closes the socket right
            # away so the file descriptor is not leaked.
            with socket.create_connection(('1.1.1.1', 443), 5):  # check host accessibility
                return True
        except OSError:
            return False

    return run_once() or run_once()  # check twice to increase robustness to intermittent connectivity issues
def git_describe(path=ROOT):  # path must be a directory
    """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.

    Args:
        path: directory expected to contain a '.git' directory.

    Returns:
        str: output of 'git describe --tags --long --always' without its last character
        (the trailing newline), or '' when ``path`` is not a git checkout or the command fails.
    """
    try:
        # Explicit check instead of assert, which would be stripped under 'python -O'
        if not (Path(path) / '.git').is_dir():
            return ''
        # Argument list without a shell: paths containing spaces or shell metacharacters are safe
        cmd = ['git', '-C', str(path), 'describe', '--tags', '--long', '--always']
        return check_output(cmd).decode()[:-1]
    except Exception:  # best effort: any failure yields an empty description
        return ''
@TryExcept()
@WorkingDirectory(ROOT)
def check_git_status(repo='ultralytics/yolov5', branch='master'):
    """YOLOv5 status check, recommend 'git pull' if code is out of date.

    Fetches the remote that points at ``repo`` (adding one named 'ultralytics' if none exists),
    counts the commits the checked-out branch is behind ``remote/branch`` and logs the result.
    The asserts are intentional: they abort the check with a message.
    """

    def _git(cmd, **kwargs):
        # Run a git command line through the shell and return its raw output
        return check_output(cmd, shell=True, **kwargs)

    repo_url = f'https://github.com/{repo}'
    hint = f', for updates see {repo_url}'
    status = colorstr('github: ')  # string
    assert Path('.git').exists(), status + 'skipping check (not a git repository)' + hint
    assert check_online(), status + 'skipping check (offline)' + hint

    tokens = re.split(pattern=r'\s', string=_git('git remote -v').decode())
    hits = [repo in token for token in tokens]
    if not any(hits):
        remote = 'ultralytics'
        _git(f'git remote add {remote} {repo_url}')
    else:
        remote = tokens[hits.index(True) - 1]  # the remote name precedes its URL in 'git remote -v'
    _git(f'git fetch {remote}', timeout=5)  # git fetch
    local_branch = _git('git rev-parse --abbrev-ref HEAD').decode().strip()  # checked out
    behind = int(_git(f'git rev-list {local_branch}..{remote}/{branch} --count'))  # commits behind
    if behind > 0:
        pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
        status += f"⚠️ YOLOv5 is out of date by {behind} commit{'s' * (behind > 1)}. Use '{pull}' or 'git clone {repo_url}' to update."
    else:
        status += f'up to date with {repo_url} ✅'
    LOGGER.info(status)
@WorkingDirectory(ROOT)
def check_git_info(path='.'):
    """YOLOv5 git info check, return {remote, branch, commit}.

    All three values are None when ``path`` is not a git repository; 'branch' alone is None
    when HEAD is detached.
    """
    check_requirements('gitpython')
    import git

    info = {'remote': None, 'branch': None, 'commit': None}
    try:
        repository = git.Repo(path)
        info['remote'] = repository.remotes.origin.url.replace('.git', '')  # i.e. 'https://github.com/ultralytics/yolov5'
        info['commit'] = repository.head.commit.hexsha  # i.e. '3134699c73af83aac2a481435550b968d5792c0d'
        try:
            info['branch'] = repository.active_branch.name  # i.e. 'main'
        except TypeError:  # not on any branch
            info['branch'] = None  # i.e. 'detached HEAD' state
    except git.exc.InvalidGitRepositoryError:  # path is not a git dir
        return {'remote': None, 'branch': None, 'commit': None}
    return info
def check_python(minimum='3.8.0'):
    """Check current python version vs. required python version (hard check via check_version)."""
    running = platform.python_version()
    check_version(running, minimum, name='Python ', hard=True)
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
    """Check version vs. required version.

    Args:
        current: installed version string.
        minimum: required version string.
        name: label used in the warning text.
        pinned: require an exact match instead of ``current >= minimum``.
        hard: assert that the requirement is met.
        verbose: log a warning when the requirement is not met.

    Returns:
        bool: whether the requirement is met.
    """
    current = pkg.parse_version(current)
    minimum = pkg.parse_version(minimum)
    if pinned:
        result = current == minimum
    else:
        result = current >= minimum
    warning = f'WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed'  # string
    if hard:
        assert result, emojis(warning)  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(warning)
    return result
def check_img_size(imgsz, s=32, floor=0):
    """Verify image size is a multiple of stride s in each dimension.

    Each dimension is passed through ``make_divisible(dim, int(s))`` and then raised to at
    least ``floor``. A warning is logged when the result differs from the input. An int input
    returns an int; any other iterable returns a list.
    """

    def _fit(dim):
        # Apply the stride adjustment to one dimension, never going below `floor`
        return max(make_divisible(dim, int(s)), floor)

    if isinstance(imgsz, int):  # integer i.e. img_size=640
        new_size = _fit(imgsz)
    else:  # list i.e. img_size=[640, 480]
        imgsz = list(imgsz)  # convert to list if tuple
        new_size = [_fit(dim) for dim in imgsz]
    if new_size != imgsz:
        LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size
def check_imshow(warn=False):
    """Check if environment supports image displays.

    Returns True when not running under Jupyter or Docker and a tiny test window can be opened
    and closed with OpenCV; otherwise returns False, logging a warning first if ``warn`` is set.
    """
    try:
        assert not (is_jupyter() or is_docker())
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
    except Exception as err:
        if warn:
            LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{err}')
        return False
    return True
def check_suffix(file='yolov5s.pt', suffix=('.pt', ), msg=''):
    """Check file(s) for acceptable suffix.

    Args:
        file: one file name/path, or a list/tuple of them.
        suffix: one acceptable suffix or a collection of them.
        msg: text placed in front of the assertion message.

    Raises:
        AssertionError: a file has a non-empty (lower-cased) suffix that is not in ``suffix``.
    """
    if not (file and suffix):
        return
    if isinstance(suffix, str):
        suffix = [suffix]
    candidates = file if isinstance(file, (list, tuple)) else [file]
    for f in candidates:
        ext = Path(f).suffix.lower()  # file suffix
        if ext:  # names without a suffix are not checked
            assert ext in suffix, f'{msg}{f} acceptable suffix is {suffix}'
def check_yaml(file, suffix=('.yaml', '.yml')):
    """Search/download YAML file (if necessary) and return path, checking suffix (delegates to check_file)."""
    located = check_file(file, suffix)
    return located
def check_file(file, suffix=''):
    """Search/download file (if necessary) and return path.

    Resolution order: an existing file or empty string is returned as is; an http(s) URL is
    downloaded to the current directory unless a file of that name already exists; a
    'clearml://' ID is returned unchanged; anything else is searched for recursively under the
    ROOT 'data', 'models' and 'utils' directories and must match exactly one file.
    """
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if not file or os.path.isfile(file):  # exists
        return file

    if file.startswith(('http:/', 'https:/')):  # download
        source = file  # warning: Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
        if os.path.isfile(file):
            LOGGER.info(f'Found {source} locally at {file}')  # file already exists
            return file
        LOGGER.info(f'Downloading {source} to {file}...')
        torch.hub.download_url_to_file(source, file)
        assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {source}'  # check
        return file

    if file.startswith('clearml://'):  # ClearML Dataset ID
        assert 'clearml' in sys.modules, "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
        return file

    # search
    found = [
        match for d in ('data', 'models', 'utils')  # search directories
        for match in glob.glob(str(ROOT / d / '**' / file), recursive=True)]
    assert len(found), f'File not found: {file}'  # assert file was found
    assert len(found) == 1, f"Multiple files match '{file}', specify exact path: {found}"  # assert unique
    return found[0]  # return file
def check_font(font=FONT, progress=False):
    """Download font to CONFIG_DIR if necessary.

    Nothing is downloaded when ``font`` exists as given or a file of the same name is already
    in CONFIG_DIR.
    """
    font = Path(font)
    file = CONFIG_DIR / font.name
    if font.exists() or file.exists():
        return
    url = f'https://ultralytics.com/assets/{font.name}'
    LOGGER.info(f'Downloading {url} to {file}...')
    torch.hub.download_url_to_file(url, str(file), progress=progress)
def check_dataset(data, autodownload=True):
    """Download, check and/or unzip dataset if not found locally.

    Args:
        data: dataset dictionary, path to a dataset YAML file, or path to a zip/tar archive
            that contains one.
        autodownload: when True, run the YAML 'download' entry if validation paths are missing.

    Returns:
        dict: the dataset dictionary with 'names' as a dict, 'nc' set to len(names) and the
        'train'/'val'/'test' entries resolved to absolute path strings.

    Raises:
        AssertionError: a required field is missing or a 'names' key is not an int.
        Exception: validation paths are missing and there is no 'download' entry or
            ``autodownload`` is False.
    """
    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
        download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        data = yaml_load(data)  # dictionary

    # Checks
    for k in 'train', 'val', 'names':
        assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
    if isinstance(data['names'], (list, tuple)):  # old array format
        data['names'] = dict(enumerate(data['names']))  # convert to dict
    assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car'
    data['nc'] = len(data['names'])

    # Resolve paths
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()
        data['path'] = path  # download scripts
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith('../'):
                    x = (path / data[k][3:]).resolve()  # retry without the leading '../'
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse yaml
    val, s = data.get('val'), data.get('download')
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
            if not s or not autodownload:
                raise Exception('Dataset not found ❌')
            t = time.time()
            if s.startswith('http') and s.endswith('.zip'):  # URL
                f = Path(s).name  # filename
                LOGGER.info(f'Downloading {s} to {f}...')
                torch.hub.download_url_to_file(s, f)
                Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True)  # create root
                unzip_file(f, path=DATASETS_DIR)  # unzip
                Path(f).unlink()  # remove zip
                r = None  # success
            elif s.startswith('bash '):  # bash script
                LOGGER.info(f'Running {s} ...')
                # subprocess.run() returns a CompletedProcess, which never equals 0 or None;
                # take its exit status so a successful script is reported as success below
                r = subprocess.run(s, shell=True).returncode
            else:  # python script
                # NOTE(security): executes Python source taken from the dataset YAML; only use trusted files
                r = exec(s, {'yaml': data})  # return None
            dt = f'({round(time.time() - t, 1)}s)'
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
            LOGGER.info(f'Dataset download {s}')
    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True)  # download fonts
    return data  # dictionary
def check_amp(model):
    """Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation.

    Returns False straight away for models on 'cpu' or 'mps' devices, and False with a logged
    warning when the FP32/AMP comparison fails or raises.
    """
    from models.common import AutoShape, DetectMultiBackend

    def amp_allclose(model, im):
        # All close FP32 vs AMP results
        shaped = AutoShape(model, verbose=False)  # model
        fp32 = shaped(im).xywhn[0]  # FP32 inference
        shaped.amp = True
        mixed = shaped(im).xywhn[0]  # AMP inference
        if fp32.shape != mixed.shape:
            return False
        return torch.allclose(fp32, mixed, atol=0.1)  # close to 10% absolute tolerance

    prefix = colorstr('AMP: ')
    device = next(model.parameters()).device  # get model device
    if device.type in ('cpu', 'mps'):
        return False  # AMP only used on CUDA devices

    # Pick the image to check: local copy, then URL when online, then a synthetic array
    local_image = ROOT / 'data' / 'images' / 'bus.jpg'
    if local_image.exists():
        im = local_image
    elif check_online():
        im = 'https://ultralytics.com/images/bus.jpg'
    else:
        im = np.ones((640, 640, 3))

    try:
        assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
    except Exception:
        help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
        LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
        return False
    try:
        LOGGER.info(f'{prefix}checks passed ✅')
        return True
    except Exception:  # keep the original behaviour: a failure while logging also disables AMP
        help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
        LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
        return False
def yaml_load(file='data.yaml'):
    """Single-line safe yaml loading: parse ``file`` with yaml.safe_load, ignoring decoding errors."""
    with open(file, errors='ignore') as stream:
        loaded = yaml.safe_load(stream)
    return loaded
def yaml_save(file='data.yaml', data={}):
    """Single-line safe yaml saving: dump ``data`` to ``file`` in insertion order.

    Top-level values that are pathlib.Path objects are written as strings. ``data`` itself is
    not modified.
    """
    with open(file, 'w') as stream:
        serializable = {}
        for key, value in data.items():
            serializable[key] = str(value) if isinstance(value, Path) else value
        yaml.safe_dump(serializable, stream, sort_keys=False)
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
    """Unzip a *.zip file to path/, excluding files containing strings in exclude list.

    ``path`` defaults to the directory that contains ``file``.
    """
    destination = Path(file).parent if path is None else path  # default path
    with ZipFile(file) as archive:
        for member in archive.namelist():  # list all archived filenames in the zip
            if any(token in member for token in exclude):
                continue
            archive.extract(member, path=destination)
def url2file(url):
    """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt."""
    repaired = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
    decoded = urllib.parse.unquote(repaired)  # '%2F' to '/'
    name_with_query = Path(decoded).name
    return name_with_query.split('?')[0]  # split https://url.com/file.txt?auth
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
    """Multithreaded file download and unzip function, used in data.yaml for autodownload.

    Args:
        url: one URL/path (str or Path) or an iterable of them.
        dir: destination directory, created if missing.
        unzip: extract downloaded zip/tar/.gz archives.
        delete: remove an archive after extracting it.
        curl: download with curl_download instead of torch.hub.
        threads: number of worker threads; values above 1 use a ThreadPool.
        retry: number of extra download attempts after the first.
    """

    def download_one(url, dir):
        # Download 1 file
        ok = True
        if os.path.isfile(url):
            target = Path(url)  # filename
        else:  # does not exist
            target = dir / Path(url).name
            LOGGER.info(f'Downloading {url} to {target}...')
            for attempt in range(retry + 1):
                if curl:
                    ok = curl_download(url, target, silent=(threads > 1))
                else:
                    torch.hub.download_url_to_file(url, target, progress=threads == 1)  # torch download
                    ok = target.is_file()
                if ok:
                    break
                if attempt < retry:
                    LOGGER.warning(f'⚠️ Download failure, retrying {attempt + 1}/{retry} {url}...')
                else:
                    LOGGER.warning(f'❌ Failed to download {url}...')

        if not (unzip and ok):
            return
        if not (target.suffix == '.gz' or is_zipfile(target) or is_tarfile(target)):
            return
        LOGGER.info(f'Unzipping {target}...')
        if is_zipfile(target):
            unzip_file(target, dir)  # unzip
        elif is_tarfile(target):
            subprocess.run(['tar', 'xf', target, '--directory', target.parent], check=True)  # unzip
        elif target.suffix == '.gz':
            subprocess.run(['tar', 'xfz', target, '--directory', target.parent], check=True)  # unzip
        if delete:
            target.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda job: download_one(*job), zip(url, repeat(dir)))  # multithreaded
        pool.close()
        pool.join()
    else:
        sources = [url] if isinstance(url, (str, Path)) else url
        for source in sources:
            download_one(source, dir)
def make_divisible(x, divisor):
    """Return ``x`` rounded up to a multiple of ``divisor``.

    A tensor ``divisor`` is reduced to the int of its maximum element first.
    """
    step = int(divisor.max()) if isinstance(divisor, torch.Tensor) else divisor  # to int
    multiples = math.ceil(x / step)
    return multiples * step
def clean_str(s):
    """Cleans a string by replacing special characters with underscore _."""
    special = re.compile('[|@#!¡·$€%&()=?¿^*;:,¨´><+]')
    return special.sub('_', s)
def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Return a function for a sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.

    The returned callable maps x=0 to y1 and x=steps to y2.
    """

    def ramp(x):
        return ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

    return ramp
def colorstr(*input):
    """Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world').

    The last positional argument is the text; any preceding ones are color/style names. With a
    single argument the text is rendered blue and bold.
    """
    if len(input) > 1:
        *args, string = input  # color arguments, string
    else:
        args, string = ['blue', 'bold'], input[0]
    colors = {
        'black': '\033[30m',  # basic colors
        'red': '\033[31m',
        'green': '\033[32m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        'bright_black': '\033[90m',  # bright colors
        'bright_red': '\033[91m',
        'bright_green': '\033[92m',
        'bright_yellow': '\033[93m',
        'bright_blue': '\033[94m',
        'bright_magenta': '\033[95m',
        'bright_cyan': '\033[96m',
        'bright_white': '\033[97m',
        'end': '\033[0m',  # misc
        'bold': '\033[1m',
        'underline': '\033[4m'}
    codes = ''.join(colors[style] for style in args)
    return codes + f'{string}' + colors['end']
def labels_to_class_weights(labels, nc=80):
    """Get class weights (inverse frequency) from training labels.

    ``labels`` is a sequence of arrays whose column 0 holds the class index. Returns a float
    tensor that sums to 1, or an empty tensor when ``labels[0]`` is None.
    """
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    stacked = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    class_ids = stacked[:, 0].astype(int)  # labels = [class xywh]
    counts = np.bincount(class_ids, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    counts = np.where(counts == 0, 1, counts)  # replace empty bins with 1
    inverse = 1 / counts  # inverse frequency per class
    normalized = inverse / inverse.sum()  # normalize
    return torch.from_numpy(normalized).float()
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    """Produces image weights based on class_weights and image contents.

    Each image weight is the sum over classes of (class weight x number of labels of that class
    in the image); column 0 of every label array holds the class index.
    Usage: index = random.choices(range(n), weights=image_weights, k=1)  # weighted image sample
    """
    per_image_counts = [np.bincount(image_labels[:, 0].astype(int), minlength=nc) for image_labels in labels]
    class_counts = np.array(per_image_counts)
    weighted = class_weights.reshape(1, nc) * class_counts
    return weighted.sum(1)
def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    """Return the 80 class ids of the 91-index scheme, ordered by 80-index position.

    https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    """
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    skipped = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}  # ids in 1..90 that are absent from the 80-class list
    return [class_id for class_id in range(1, 91) if class_id not in skipped]
def xyxy2xywh(x):
    """Convert boxes from [x1, y1, x2, y2] (top-left, bottom-right) to [x_center, y_center, width, height].

    Accepts a torch tensor or a numpy array whose last axis starts with the four box values.
    The input is not modified; a converted copy is returned.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top, right, bottom = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    out[..., 0] = (left + right) / 2  # x center
    out[..., 1] = (top + bottom) / 2  # y center
    out[..., 2] = right - left  # width
    out[..., 3] = bottom - top  # height
    return out
def xywh2xyxy(x):
    """Convert boxes from [x_center, y_center, width, height] to [x1, y1, x2, y2] (top-left, bottom-right).

    Accepts a torch tensor or a numpy array whose last axis starts with the four box values.
    The input is not modified; a converted copy is returned.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    cx, cy = x[..., 0], x[..., 1]
    half_w, half_h = x[..., 2] / 2, x[..., 3] / 2
    out[..., 0] = cx - half_w  # top left x
    out[..., 1] = cy - half_h  # top left y
    out[..., 2] = cx + half_w  # bottom right x
    out[..., 3] = cy + half_h  # bottom right y
    return out
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """Convert normalized [x_center, y_center, width, height] boxes to pixel [x1, y1, x2, y2] boxes.

    x values are scaled by ``w`` and shifted by ``padw``; y values are scaled by ``h`` and shifted by ``padh``.
    The input (torch tensor or numpy array) is not modified; a converted copy is returned.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    cx, cy = x[..., 0], x[..., 1]
    half_w, half_h = x[..., 2] / 2, x[..., 3] / 2
    out[..., 0] = w * (cx - half_w) + padw  # top left x
    out[..., 1] = h * (cy - half_h) + padh  # top left y
    out[..., 2] = w * (cx + half_w) + padw  # bottom right x
    out[..., 3] = h * (cy + half_h) + padh  # bottom right y
    return out
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    """Convert pixel [x1, y1, x2, y2] boxes to normalized [x_center, y_center, width, height] boxes.

    When ``clip`` is true the input boxes are first clipped IN PLACE to ``(h - eps, w - eps)``.
    Otherwise the input is left untouched. A converted copy is returned in both cases.
    """
    if clip:
        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top, right, bottom = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    out[..., 0] = ((left + right) / 2) / w  # x center
    out[..., 1] = ((top + bottom) / 2) / h  # y center
    out[..., 2] = (right - left) / w  # width
    out[..., 3] = (bottom - top) / h  # height
    return out
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    """Convert normalized segment points to pixel points.

    Column 0 (x) is scaled by ``w`` and shifted by ``padw``; column 1 (y) is scaled by ``h`` and shifted
    by ``padh``. The input (torch tensor or numpy array) is not modified; a converted copy is returned.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    norm_x, norm_y = x[..., 0], x[..., 1]
    out[..., 0] = w * norm_x + padw  # pixel x
    out[..., 1] = h * norm_y + padh  # pixel y
    return out
def segment2box(segment, width=640, height=640):
    """Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).

    Args:
        segment: Array of shape (n, 2) holding the segment's x, y points.
        width: Maximum x value (inclusive) for a point to count as inside the image.
        height: Maximum y value (inclusive) for a point to count as inside the image.

    Returns:
        Array ``[x_min, y_min, x_max, y_max]`` over the points inside the image, or ``np.zeros((1, 4))``
        when no point is inside (this fallback shape is kept as-is for existing callers).
    """
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    # Test the number of surviving points, not their truthiness: points lying on x == 0 are still valid.
    if x.size == 0:
        return np.zeros((1, 4))
    return np.array([x.min(), y.min(), x.max(), y.max()])  # xyxy
def segments2boxes(segments):
    """Convert segment labels to box labels, i.e. (xy1, xy2, ...) per segment to one xywh box per segment.

    Each segment is an array of x, y points; its bounding box is the min/max of those points, converted
    from xyxy to xywh with ``xyxy2xywh``.
    """
    extents = [[xs.min(), ys.min(), xs.max(), ys.max()] for xs, ys in (seg.T for seg in segments)]  # xyxy
    return xyxy2xywh(np.array(extents))  # xywh
def resample_segments(segments, n=1000):
    """Resample every (m, 2) segment in ``segments`` to ``n`` points by linear interpolation.

    Each segment is closed by appending its first point, then sampled at ``n`` evenly spaced positions
    along the point index. The list is updated in place and also returned.
    """
    for idx, seg in enumerate(segments):
        closed = np.concatenate((seg, seg[0:1, :]), axis=0)  # append first point to close the outline
        num_points = len(closed)
        sample_at = np.linspace(0, num_points - 1, n)
        knots = np.arange(num_points)
        xs = np.interp(sample_at, knots, closed[:, 0])
        ys = np.interp(sample_at, knots, closed[:, 1])
        segments[idx] = np.concatenate([xs, ys]).reshape(2, -1).T  # segment xy
    return segments
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    """Rescale boxes (xyxy) from img1_shape to img0_shape, in place.

    When ``ratio_pad`` is given, the gain is ``ratio_pad[0][0]`` and the (x, y) padding is ``ratio_pad[1]``.
    Otherwise both are derived from the two shapes (indexed as height, width). The padding is subtracted,
    the first four columns are divided by the gain, and the boxes are clipped to ``img0_shape``.
    ``boxes`` is modified in place and returned.
    """
    if ratio_pad is not None:
        gain, pad = ratio_pad[0][0], ratio_pad[1]
    else:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
        pad = pad_x, pad_y  # wh padding

    boxes[..., [0, 2]] -= pad[0]  # x padding
    boxes[..., [1, 3]] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
    """Rescale segment points (x, y columns) from img1_shape to img0_shape, in place.

    When ``ratio_pad`` is given, the gain is ``ratio_pad[0][0]`` and the (x, y) padding is ``ratio_pad[1]``.
    Otherwise both are derived from the two shapes (indexed as height, width). After removing the padding
    and dividing by the gain, points are clipped to ``img0_shape``. With ``normalize`` they are then
    divided by the img0 width and height. ``segments`` is modified in place and returned.
    """
    if ratio_pad is not None:
        gain, pad = ratio_pad[0][0], ratio_pad[1]
    else:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
        pad = pad_x, pad_y  # wh padding

    segments[:, 0] -= pad[0]  # x padding
    segments[:, 1] -= pad[1]  # y padding
    segments /= gain
    clip_segments(segments, img0_shape)
    if normalize:
        segments[:, 0] /= img0_shape[1]  # width
        segments[:, 1] /= img0_shape[0]  # height
    return segments
def clip_boxes(boxes, shape):
    """Clip boxes (xyxy) in place to an image shape given as (height, width, ...).

    x columns (0 and 2) are limited to [0, shape[1]] and y columns (1 and 3) to [0, shape[0]].
    Nothing is returned; ``boxes`` (torch tensor or numpy array) is modified directly.
    """
    max_y, max_x = shape[0], shape[1]
    if not isinstance(boxes, torch.Tensor):  # np.array (faster grouped)
        clipped_x = boxes[..., [0, 2]].clip(0, max_x)  # x1, x2
        clipped_y = boxes[..., [1, 3]].clip(0, max_y)  # y1, y2
        boxes[..., [0, 2]] = clipped_x
        boxes[..., [1, 3]] = clipped_y
        return
    # torch tensor: clamp each column view individually (faster)
    for column, upper in ((0, max_x), (1, max_y), (2, max_x), (3, max_y)):
        boxes[..., column].clamp_(0, upper)
def clip_segments(segments, shape):
    """Clip segment points (x, y columns) in place to an image shape given as (height, width, ...).

    Column 0 (x) is limited to [0, shape[1]] and column 1 (y) to [0, shape[0]].
    Nothing is returned; ``segments`` (torch tensor or numpy array) is modified directly.
    """
    limits = (shape[1], shape[0])  # upper bound for x, then y
    if isinstance(segments, torch.Tensor):  # clamp each column view individually
        for column, upper in enumerate(limits):
            segments[:, column].clamp_(0, upper)
    else:  # np.array
        for column, upper in enumerate(limits):
            segments[:, column] = segments[:, column].clip(0, upper)
def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Args:
        prediction: Tensor indexed as (batch, candidate, 5 + nc + nm) with columns
            [x, y, w, h, obj_conf, class scores..., mask values...]; a list/tuple is reduced to its first element.
        conf_thres: Threshold applied to the objectness column and again to the combined confidence.
        iou_thres: IoU threshold passed to ``torchvision.ops.nms``.
        classes: Optional collection of class indices to keep.
        agnostic: If true, boxes are not offset per class, so NMS runs across all classes together.
        multi_label: Keep every class above the threshold per box instead of only the best class
            (disabled automatically when there is a single class).
        labels: Optional per-image apriori labels (rows of [cls, x, y, w, h]) appended as confidence-1.0 candidates.
        max_det: Maximum number of detections kept per image after NMS.
        nm: Number of trailing mask columns carried through unchanged in position.

    Returns:
         list of detections, one (n, 6 + nm) tensor per image [xyxy, conf, cls, masks...]
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device  # remembered so MPS results can be moved back after CPU processing
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    # The same empty tensor object is repeated bs times; entries are replaced below, never mutated.
    # NOTE(review): on MPS, prediction.device is already 'cpu' here, so placeholders of images without
    # detections stay on CPU while non-empty results are moved back to `device` - confirm callers tolerate this.
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        # Note: the slice 5: also covers the mask columns when nm > 0, so they are scaled by obj_conf too.
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T  # i = box row, j = class index
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = x[:, 5:mi].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        # Shifting each box by class_index * max_wh keeps boxes of different classes from overlapping,
        # so a single nms() call suppresses only within a class (no shift when agnostic).
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        # merge is hard-coded to False above, so this branch is currently never taken.
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded; remaining images keep their empty placeholder

    return output
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    """Strip optimizer from 'f' to finalize training, optionally save as 's'.

    Loads the checkpoint on CPU, replaces 'model' with 'ema' when an EMA entry is present, clears the
    training-only entries, sets 'epoch' to -1, converts the model to FP16 with gradients disabled, and
    saves to ``s`` (or back over ``f`` when ``s`` is empty). The resulting file size is logged.
    """
    ckpt = torch.load(f, map_location=torch.device('cpu'))
    if ckpt.get('ema'):
        ckpt['model'] = ckpt['ema']  # replace model with ema
    ckpt.update(dict.fromkeys(('optimizer', 'best_fitness', 'ema', 'updates')))  # reset these keys to None
    ckpt['epoch'] = -1

    model = ckpt['model']
    model.half()  # to FP16
    for param in model.parameters():
        param.requires_grad = False

    target = s or f
    torch.save(ckpt, target)
    mb = os.path.getsize(target) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
    """Record one hyperparameter-evolution generation: append to evolve.csv, rewrite hyp_evolve.yaml, and log it.

    Args:
        keys: Names of the result metrics; hyperparameter names from ``hyp`` are appended to them.
        results: Tuple of result values (it is concatenated with a tuple of the ``hyp`` values).
        hyp: Mapping of hyperparameter name to value.
        save_dir: Directory path object (must support ``/``) holding evolve.csv and hyp_evolve.yaml.
        bucket: Bucket name used to build ``gs://{bucket}`` URLs; a falsy value disables download/upload.
        prefix: Text prepended to each logged line.
    """
    evolve_csv = save_dir / 'evolve.csv'
    evolve_yaml = save_dir / 'hyp_evolve.yaml'
    keys = tuple(keys) + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
            subprocess.run(['gsutil', 'cp', f'{url}', f'{save_dir}'])  # download evolve.csv if larger than local

    # Log to evolve.csv
    # The header row is only written when the file does not exist yet; every call appends one row of values.
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv, skipinitialspace=True)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        # NOTE(review): fitness() receives only the first 4 columns - presumably the leading result
        # metrics; confirm against the fitness() definition and the caller's key order.
        i = np.argmax(fitness(data.values[:, :4]))  # row index of the best generation
        generations = len(data)
        # NOTE(review): the literal 7 splits result columns (written as a comment header) from
        # hyperparameter columns (dumped as YAML); it presumably equals the number of result keys the
        # caller passes - verify, since a different count would misplace the split.
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
                f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
                '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)

    # Print to screen
    LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
                ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
                                                                                         for x in vals) + '\n\n')

    if bucket:
        subprocess.run(['gsutil', 'cp', f'{evolve_csv}', f'{evolve_yaml}', f'gs://{bucket}'])  # upload
def apply_classifier(x, model, img, im0):
    """Apply a second stage classifier to YOLO outputs and keep only detections whose class it confirms.

    Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()

    For every image with detections, each box is squared, padded, rescaled from the ``img`` size to the
    ``im0`` size, cut out of ``im0``, resized to 224x224 and classified by ``model``. Entries of ``x``
    are replaced by the rows whose detected class equals the classifier's prediction; ``x`` is returned.
    """
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, det in enumerate(x):  # per image
        if det is None or not len(det):
            continue
        det = det.clone()

        # Reshape and pad cutouts
        b = xyxy2xywh(det[:, :4])  # boxes
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
        b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
        det[:, :4] = xywh2xyxy(b).long()

        # Rescale boxes from img_size to im0 size
        scale_boxes(img.shape[2:], det[:, :4], im0[i].shape)

        # Classes
        pred_cls1 = det[:, 5].long()
        source = im0[i]
        crops = []
        for row in det:
            cutout = source[int(row[1]):int(row[3]), int(row[0]):int(row[2])]
            crop = cv2.resize(cutout, (224, 224))  # BGR
            crop = crop[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW (3x224x224)
            crop = np.ascontiguousarray(crop, dtype=np.float32)  # uint8 to float32
            crop /= 255  # 0 - 255 to 0.0 - 1.0
            crops.append(crop)

        pred_cls2 = model(torch.Tensor(crops).to(det.device)).argmax(1)  # classifier prediction
        x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections
    return x
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.

    If ``path`` exists and ``exist_ok`` is false, the first free candidate numbered 2..9998 is returned
    (for files the number goes before the suffix). If every candidate exists, the last one is returned.
    With ``mkdir`` the resulting path is created as a directory. Returns a ``Path``.
    """
    path = Path(path)  # os-agnostic
    if not exist_ok and path.exists():
        if path.is_file():
            base, suffix = path.with_suffix(''), path.suffix
        else:
            base, suffix = path, ''

        number = 2
        candidate = f'{base}{sep}{number}{suffix}'  # increment path
        while number < 9998 and os.path.exists(candidate):
            number += 1
            candidate = f'{base}{sep}{number}{suffix}'
        path = Path(candidate)

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory

    return path
# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------------
# Keep a reference to the original cv2.imshow: imshow() below calls it, and cv2.imshow itself may be
# replaced by that wrapper further down, which would otherwise recurse.
imshow_ = cv2.imshow  # copy to avoid recursion errors
def imread(filename, flags=cv2.IMREAD_COLOR):
    """Read an image by loading the file's raw bytes with numpy and decoding them with ``cv2.imdecode``.

    The path is handled by ``np.fromfile`` rather than by OpenCV.
    """
    raw_bytes = np.fromfile(filename, np.uint8)
    return cv2.imdecode(raw_bytes, flags)
def imwrite(filename, img):
    """Write an image by encoding it with ``cv2.imencode`` (format chosen from the file suffix) and
    saving the encoded buffer with ``tofile``.

    Returns True on success and False if any exception occurs while encoding or writing.
    """
    try:
        encoded = cv2.imencode(Path(filename).suffix, img)[1]
        encoded.tofile(filename)
    except Exception:
        return False
    return True
def imshow(path, im):
    """Show ``im`` in a window whose name is ``path`` with non-ASCII characters escaped.

    Delegates to the saved original ``cv2.imshow`` (``imshow_``).
    """
    window_name = path.encode('unicode_escape').decode()
    imshow_(window_name, im)
# Replace cv2's imread/imwrite/imshow with the wrappers above, but only when the directory two levels
# above this file's path appears in the filename of the outermost stack frame.
# NOTE(review): presumably this restricts the patch to runs started from inside this project - confirm.
if Path(inspect.stack()[0].filename).parent.parent.as_posix() in inspect.stack()[-1].filename:
    cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow  # redefine
def get_object_level_feature_maps(feature_map, targets):
    """Cut one region per target out of a feature map.

    Args:
        feature_map: Tensor whose dims 2 and 3 are read as height and width.
        targets: 2-D tensor; columns 2-5 are read as x_center, y_center, width, height and are
            multiplied by the feature-map size (assumed normalized; columns 0-1 presumably batch
            and class - verify against caller).

    Returns:
        List with one tensor per target row: ``feature_map[:, :, y_min:y_max + 1, x_min:x_max + 1]``.
        The slice keeps every batch entry and channel.
    """
    fm_h, fm_w = feature_map.shape[2], feature_map.shape[3]

    def to_cell(coord, size):
        # truncate to an integer cell index and keep it inside the feature map
        return torch.clamp(coord.int(), 0, size - 1)

    cx = targets[:, 2] * fm_w
    cy = targets[:, 3] * fm_h
    box_w = targets[:, 4] * fm_w
    box_h = targets[:, 5] * fm_h

    left = to_cell(cx - box_w / 2, fm_w)
    top = to_cell(cy - box_h / 2, fm_h)
    right = to_cell(cx + box_w / 2, fm_w)
    bottom = to_cell(cy + box_h / 2, fm_h)

    regions = []
    for x0, y0, x1, y1 in zip(left, top, right, bottom):
        regions.append(feature_map[:, :, y0:y1 + 1, x0:x1 + 1])
    return regions
def get_object_level_feature_maps2(feature_map, targets):
    """Cut one region per target out of a feature map (variant reading box values from columns 1-4).

    Args:
        feature_map: Tensor whose dims 2 and 3 are read as height and width.
        targets: 2-D tensor; columns 1-4 are read as x_center, y_center, width, height and are
            multiplied by the feature-map size (assumed normalized; column 0 is not used here).

    Returns:
        List with one tensor per target row: ``feature_map[:, :, y_min:y_max + 1, x_min:x_max + 1]``.
        The slice keeps every batch entry and channel.
    """
    fm_h, fm_w = feature_map.shape[2], feature_map.shape[3]

    def to_cell(coord, size):
        # truncate to an integer cell index and keep it inside the feature map
        return torch.clamp(coord.int(), 0, size - 1)

    cx = targets[:, 1] * fm_w
    cy = targets[:, 2] * fm_h
    box_w = targets[:, 3] * fm_w
    box_h = targets[:, 4] * fm_h

    left = to_cell(cx - box_w / 2, fm_w)
    top = to_cell(cy - box_h / 2, fm_h)
    right = to_cell(cx + box_w / 2, fm_w)
    bottom = to_cell(cy + box_h / 2, fm_h)

    regions = []
    for x0, y0, x1, y1 in zip(left, top, right, bottom):
        regions.append(feature_map[:, :, y0:y1 + 1, x0:x1 + 1])
    return regions
def extract_roi_features(concatenated_features, resize_boxes):
    """
    Extracts regions of interest (ROIs) from the concatenated_features based on resize_boxes.

    Args:
        concatenated_features (torch.Tensor): Feature map with shape [batch, channels, height, width].
        resize_boxes (torch.Tensor): Boxes with shape [num_boxes, 5], where each row is [batch, x1, y1, x2, y2].
            The coordinates are truncated to integers and used directly as feature-map indices; the
            batch column is not used.

    Returns:
        list[torch.Tensor]: One entry per box, each ``concatenated_features[:, :, y1:y2, x1:x2]``
        (all batch entries and channels are kept; ROI sizes may differ between boxes).
    """
    roi_features_list = []
    # A single tolist() converts all boxes to Python floats up front. The coordinates are used as given:
    # dividing and re-multiplying by the feature-map size is a no-op that can only add rounding error.
    for _, box_x1, box_y1, box_x2, box_y2 in resize_boxes.tolist():
        roi_x1, roi_y1, roi_x2, roi_y2 = int(box_x1), int(box_y1), int(box_x2), int(box_y2)
        roi_features_list.append(concatenated_features[:, :, roi_y1:roi_y2, roi_x1:roi_x2])
    return roi_features_list
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
def plot_multi_channel_feature_map_with_boxes(feature_map, boxes, channels, title, save_path=None):
    """Plot selected channels of a feature map side by side, plus one panel with boxes drawn on top.

    Args:
        feature_map: Tensor indexed as ``feature_map[0, channel]`` to obtain a 2-D map per channel.
        boxes: Tensor holding one box ``[xmin, ymin, xmax, ymax]`` (1-D) or several boxes, one per row.
        channels: Sequence of channel indices to display; the last one is the background of the box panel.
        title: Figure title.
        save_path: If given, the figure is saved there and closed; otherwise it is shown interactively.
    """
    fig, axs = plt.subplots(1, len(channels) + 1, figsize=(12, 4))

    for ax, channel in zip(axs, channels):
        ax.imshow(feature_map[0, channel].cpu().detach().numpy(), cmap='viridis', aspect='auto')
        ax.set_title(f'Channel {channel}')

    # Plot bounding boxes on the last axis
    box_ax = axs[-1]
    box_ax.imshow(feature_map[0, channels[-1]].cpu().detach().numpy(), cmap='viridis', aspect='auto')
    box_ax.set_title('Bounding Boxes')

    # Treat a single 1-D box as a one-row batch so that every box gets exactly one rectangle.
    box_array = np.atleast_2d(boxes.cpu().detach().numpy())
    for xmin, ymin, xmax, ymax in box_array:
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=1, edgecolor='r', facecolor='none')
        box_ax.add_patch(rect)

    fig.suptitle(title)

    # Save the image if save_path is provided
    if save_path:
        fig.savefig(save_path)
        print(f"Image saved at: {save_path}")
        plt.close(fig)  # release the figure so repeated calls do not accumulate open figures
    else:
        plt.show()
# Convert a single center-format box (x_center, y_center, width, height) to corner format
def xywh_to_xyxy(xywh):
    """Convert one box from (x_center, y_center, width, height) to a tensor [x_min, y_min, x_max, y_max].

    ``xywh`` must unpack into exactly four values. No scaling is applied.
    """
    cx, cy, box_w, box_h = xywh
    half_w, half_h = box_w / 2, box_h / 2
    return torch.tensor([cx - half_w, cy - half_h, cx + half_w, cy + half_h])
def get_fixed_xyxy(normalized_xyxy, int_feat):
    """Truncate a box to integer coordinates and widen zero-width/zero-height extents to one cell.

    Args:
        normalized_xyxy: Tensor of four values [x_min, y_min, x_max, y_max]; converted with ``.int()``.
            (Despite the name the values are used as integer indices - presumably feature-map
            coordinates; verify against caller.)
        int_feat: Tensor whose ``size(2)`` and ``size(1)`` give the x and y limits. It is only queried
            when an extent is degenerate.

    Returns:
        Tuple ``(x_min, y_min, x_max, y_max)`` of 0-dim integer tensors. When min == max on an axis the
        max is raised by one; if the box already sits on the far limit the min is lowered by one
        instead, so the one-cell extent stays inside the feature map. The input tensor is never modified.
    """
    x_min, y_min, x_max, y_max = normalized_xyxy.int()

    # Out-of-place arithmetic on purpose: the unpacked scalars are views, and .int() returns the input
    # itself when it is already int32, so "+=" would silently modify the caller's tensor.
    if x_min == x_max:
        if x_max == int_feat.size(2):  # degenerate at the far x limit: grow towards the inside
            x_min = x_min - 1
        else:
            x_max = x_max + 1
    if y_min == y_max:
        if y_max == int_feat.size(1):  # degenerate at the far y limit: grow towards the inside
            y_min = y_min - 1
        else:
            y_max = y_max + 1
    return x_min, y_min, x_max, y_max
# Variables ------------------------------------------------------------------------------------------------------------