Diff of /utils/dataloaders.py [000000] .. [190ca4]


b/utils/dataloaders.py
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Dataloaders and dataset utils
"""

import contextlib
import glob
import hashlib
import json
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse

import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torchvision
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm

from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
                                 letterbox, mixup, random_perspective)
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, check_dataset, check_requirements,
                           check_yaml, clean_str, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy,
                           xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
from utils.torch_utils import torch_distributed_zero_first

# Parameters
HELP_URL = 'See https://docs.ultralytics.com/yolov5/tutorials/train_custom_data'
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # include image suffixes
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv'  # include video suffixes
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(paths):
    # Returns a single hash value of a list of paths (files or dirs)
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.sha256(str(size).encode())  # hash sizes
    h.update(''.join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    with contextlib.suppress(Exception):
        rotation = dict(img._getexif().items())[orientation]
        if rotation in [6, 8]:  # rotation 270 or 90
            s = (s[1], s[0])
    return s


def exif_transpose(image):
    """
    Transpose a PIL image accordingly if it has an EXIF Orientation tag.
    Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()

    :param image: The image to transpose.
    :return: An image.
    """
    exif = image.getexif()
    orientation = exif.get(0x0112, 1)  # default 1
    if orientation > 1:
        method = {
            2: Image.FLIP_LEFT_RIGHT,
            3: Image.ROTATE_180,
            4: Image.FLIP_TOP_BOTTOM,
            5: Image.TRANSPOSE,
            6: Image.ROTATE_270,
            7: Image.TRANSVERSE,
            8: Image.ROTATE_90}.get(orientation)
        if method is not None:
            image = image.transpose(method)
            del exif[0x0112]
            image.info['exif'] = exif.tobytes()
    return image


def seed_worker(worker_id):
    # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Inherit from DistributedSampler and override iterator
# https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
class SmartDistributedSampler(distributed.DistributedSampler):

    def __iter__(self):
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        # determine the eventual size (n) of self.indices (DDP indices)
        n = int((len(self.dataset) - self.rank - 1) / self.num_replicas) + 1  # num_replicas == WORLD_SIZE
        idx = torch.randperm(n, generator=g)
        if not self.shuffle:
            idx = idx.sort()[0]

        idx = idx.tolist()
        if self.drop_last:
            idx = idx[:self.num_samples]
        else:
            padding_size = self.num_samples - len(idx)
            if padding_size <= len(idx):
                idx += idx[:padding_size]
            else:
                idx += (idx * math.ceil(padding_size / len(idx)))[:padding_size]

        return iter(idx)
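
# Illustrative walk-through (assumes default DistributedSampler attributes): with
# len(dataset)=10, num_replicas=4 and rank=3, n = int((10 - 3 - 1) / 4) + 1 = 2 local
# indices, while num_samples = ceil(10 / 4) = 3, so one index is repeated as padding and
# every rank yields the same number of samples per epoch.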


def create_dataloader(path,
                      imgsz,
                      batch_size,
                      stride,
                      single_cls=False,
                      hyp=None,
                      augment=False,
                      cache=False,
                      pad=0.0,
                      rect=False,
                      rank=-1,
                      workers=8,
                      image_weights=False,
                      quad=False,
                      prefix='',
                      shuffle=False,
                      seed=0):
    if rect and shuffle:
        LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = LoadImagesAndLabels(
            path,
            imgsz,
            batch_size,
            augment=augment,  # augmentation
            hyp=hyp,  # hyperparameters
            rect=rect,  # rectangular batches
            cache_images=cache,
            single_cls=single_cls,
            stride=int(stride),
            pad=pad,
            image_weights=image_weights,
            prefix=prefix,
            rank=rank)

    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + seed + RANK)
    return loader(dataset,
                  batch_size=batch_size,
                  shuffle=shuffle and sampler is None,
                  num_workers=nw,
                  sampler=sampler,
                  pin_memory=PIN_MEMORY,
                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
                  worker_init_fn=seed_worker,
                  generator=generator), dataset
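
# Usage sketch (illustrative; 'data/images' and hyp are assumptions, not part of this file):
#   train_loader, dataset = create_dataloader('data/images', imgsz=640, batch_size=16, stride=32,
#                                             hyp=hyp, augment=True, shuffle=True)
#   for imgs, targets, paths, shapes in train_loader:
#       pass  # imgs: BCHW tensor; targets: (n, 13) rows of (batch index + 12 label columns)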


class InfiniteDataLoader(dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler:
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class LoadScreenshots:
    # YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
    def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
        # source = [screen_number left top width height] (pixels)
        check_requirements('mss')
        import mss

        source, *params = source.split()
        self.screen, left, top, width, height = 0, None, None, None, None  # default to full screen 0
        if len(params) == 1:
            self.screen = int(params[0])
        elif len(params) == 4:
            left, top, width, height = (int(x) for x in params)
        elif len(params) == 5:
            self.screen, left, top, width, height = (int(x) for x in params)
        self.img_size = img_size
        self.stride = stride
        self.transforms = transforms
        self.auto = auto
        self.mode = 'stream'
        self.frame = 0
        self.sct = mss.mss()

        # Parse monitor shape
        monitor = self.sct.monitors[self.screen]
        self.top = monitor['top'] if top is None else (monitor['top'] + top)
        self.left = monitor['left'] if left is None else (monitor['left'] + left)
        self.width = width or monitor['width']
        self.height = height or monitor['height']
        self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}

    def __iter__(self):
        return self

    def __next__(self):
        # mss screen capture: get raw pixels from the screen as np array
        im0 = np.array(self.sct.grab(self.monitor))[:, :, :3]  # [:, :, :3] BGRA to BGR
        s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous
        self.frame += 1
        return str(self.screen), im, im0, None, s  # screen, img, original img, im0s, s
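
# Source-string sketch (illustrative): 'screen 0 100 100 512 256' selects monitor 0 and a
# 512x256 capture region offset (100, 100) from its top-left corner; plain 'screen' or
# 'screen 1' grabs a full monitor.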


class LoadImages:
    # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
            path = Path(path).read_text().rsplit()
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if '*' in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f'{p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            for _ in range(self.vid_stride):
                self.cap.grab()
            ret_val, im0 = self.cap.retrieve()
            while not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                ret_val, im0 = self.cap.read()

            self.frame += 1
            self.orig_img = im0.copy()  # keep an unresized copy so the extra return value below also exists for videos
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            assert im0 is not None, f'Image Not Found {path}'
            self.orig_img = im0.copy()  # unresized copy, returned alongside the resized image
            pil_img = Image.fromarray(im0)
            image = pil_img.resize((self.img_size, self.img_size), Image.LANCZOS)  # PIL expects a (w, h) size tuple
            im0 = np.array(image)
            s = f'image {self.count}/{self.nf} {path}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, im, im0, self.cap, s, self.orig_img
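
    # Usage sketch (illustrative; 'data/samples' is an assumption): iterating yields the
    # CHW/RGB model input, the LANCZOS-resized im0, the cv2 capture (videos only), a log
    # string and the untouched original frame:
    #   for path, im, im0, cap, s, orig in LoadImages('data/samples', img_size=640):
    #       pass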

    def _new_video(self, path):
        # Create a new video capture object
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

    def _cv2_rotate(self, im):
        # Rotate a cv2 video manually
        if self.orientation == 0:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        elif self.orientation == 180:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.orientation == 90:
            return cv2.rotate(im, cv2.ROTATE_180)
        return im

    def __len__(self):
        return self.nf  # number of files


class LoadStreams:
    # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP streams`
    def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride
        self.vid_stride = vid_stride  # video frame-rate stride
        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f'{i + 1}/{n}: {s}... '
            if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'):  # if source is YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
                check_requirements(('pafy', 'youtube_dl==2020.12.2'))
                import pafy
                s = pafy.new(s).getbest(preftype='mp4').url  # YouTube URL
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0:
                assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
                assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'{st}Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
            LOGGER.info(f'{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)')
            self.threads[i].start()
        LOGGER.info('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        self.auto = auto and self.rect
        self.transforms = transforms  # optional
        if not self.rect:
            LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')

    def update(self, i, cap, stream):
        # Read stream `i` frames in daemon thread
        n, f = 0, self.frames[i]  # frame number, frame array
        while cap.isOpened() and n < f:
            n += 1
            cap.grab()  # .read() = .grab() followed by .retrieve()
            if n % self.vid_stride == 0:
                success, im = cap.retrieve()
                if success:
                    self.imgs[i] = im
                else:
                    LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
                    self.imgs[i] = np.zeros_like(self.imgs[i])
                    cap.open(stream)  # re-open stream if signal was lost
            time.sleep(0.0)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        im0 = self.imgs.copy()
        if self.transforms:
            im = np.stack([self.transforms(x) for x in im0])  # transforms
        else:
            im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0])  # resize
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
            im = np.ascontiguousarray(im)  # contiguous

        return self.sources, im, im0, None, ''

    def __len__(self):
        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years
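
# Usage sketch (illustrative; the RTSP URL is an assumption): each iteration returns one
# batched BCHW array covering every open stream:
#   for sources, im, im0, _, _ in LoadStreams('rtsp://example.com/media.mp4', img_size=640):
#       pass  # im: (n_streams, 3, h, w)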


def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
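
# Illustrative mapping: the last '/images/' path component is swapped for '/labels/' and the
# suffix for '.txt', e.g. img2label_paths(['data/images/train/0001.jpg']) -> ['data/labels/train/0001.txt']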


class LoadImagesAndLabels(Dataset):
    # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
    cache_version = 0.6  # dataset labels *.cache version
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(self,
                 path,
                 img_size=640,
                 batch_size=16,
                 augment=False,
                 hyp=None,
                 rect=False,
                 image_weights=False,
                 cache_images=False,
                 single_cls=False,
                 stride=32,
                 pad=0.0,
                 min_items=0,
                 prefix='',
                 rank=-1,
                 seed=0):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        self.albumentations = Albumentations(size=img_size) if augment else None

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t]  # to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # to global path (pathlib)
                else:
                    raise FileNotFoundError(f'{prefix}{p} does not exist')
            self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert self.im_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e

        # Check cache
        self.label_files = img2label_paths(self.im_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        try:
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # matches current version
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
        except Exception:
            cache, exists = self.cache_labels(cache_path, prefix), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'

        # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
        labels, shapes, self.segments = zip(*cache.values())
        nl = len(np.concatenate(labels, 0))  # number of labels
        assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
        self.labels = list(labels)
        self.shapes = np.array(shapes)
        self.im_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update

        # Filter images
        if min_items:
            include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
            LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset')
            self.im_files = [self.im_files[i] for i in include]
            self.label_files = [self.label_files[i] for i in include]
            self.labels = [self.labels[i] for i in include]
            self.segments = [self.segments[i] for i in include]
            self.shapes = self.shapes[include]  # wh

        # Create indices
        n = len(self.shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = np.arange(n)
        if rank > -1:  # DDP indices (see: SmartDistributedSampler)
            # force each rank (i.e. GPU process) to sample the same subset of data on every epoch
            self.indices = self.indices[np.random.RandomState(seed=seed).permutation(n) % WORLD_SIZE == RANK]

        # Update labels
        include_class = []  # filter labels to include only these classes (optional)
        self.segments = list(self.segments)
        include_class_array = np.array(include_class).reshape(1, -1)
        for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
            if include_class:
                j = (label[:, 0:1] == include_class_array).any(1)
                self.labels[i] = label[j]
                if segment:
                    self.segments[i] = [segment[idx] for idx, elem in enumerate(j) if elem]
            if single_cls:  # single-class training, merge all classes into 0
                self.labels[i][:, 0] = 0

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.im_files = [self.im_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.segments = [self.segments[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into RAM/disk for faster training
        if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
            cache_images = False
        self.ims = [None] * n
        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
        if cache_images:
            b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
            self.im_hw0, self.im_hw = [None] * n, [None] * n
            fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
            results = ThreadPool(NUM_THREADS).imap(lambda i: (i, fcn(i)), self.indices)
            pbar = tqdm(results, total=len(self.indices), bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if cache_images == 'disk':
                    b += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes * WORLD_SIZE
                pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
            pbar.close()

    def check_cache_ram(self, safety_margin=0.1, prefix=''):
        # Check image caching requirements vs available memory
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
        n = min(self.n, 30)  # extrapolate from 30 random images
        for _ in range(n):
            im = cv2.imread(random.choice(self.im_files))  # sample image
            ratio = self.img_size / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
            b += im.nbytes * ratio ** 2
        mem_required = b * self.n / n  # GB required to cache dataset into RAM
        mem = psutil.virtual_memory()
        cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
        if not cache:
            LOGGER.info(f'{prefix}{mem_required / gb:.1f}GB RAM required, '
                        f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
                        f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
        return cache
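
    # Illustrative estimate: a 1080x1920x3 uint8 image is ~6.2 MB; at img_size=640 the ratio
    # is 640/1920 = 1/3, so it caches at ~0.69 MB, and a 10,000-image dataset extrapolates to
    # roughly 6.9 GB of RAM before the safety margin.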

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f'{prefix}Scanning {path.parent / path.stem}...'
        with Pool(NUM_THREADS) as pool:
            # verify_image_label keeps its blood=True default (box-only labels); note that imap's
            # third positional argument is chunksize, so a flag passed there never reaches the function
            pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
                        desc=desc,
                        total=len(self.im_files),
                        bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x[im_file] = [lb, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'

        pbar.close()
        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        try:
            np.save(path, x)  # save cache for next time
            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
            LOGGER.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}')  # not writeable
        return x

    def __len__(self):
        return len(self.im_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = self.load_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < hyp['mixup']:
                img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices)))

        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:5] = xywhn2xyxy(labels[:, 1:5], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])  # box columns 1:5 only; extra label columns are left as-is

            if self.augment:
                img, labels = random_perspective(img,
                                                 labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

        nl = len(labels)  # number of labels
        if nl:
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)

        if self.augment:
            # Albumentations
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations

            # HSV color-space
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]

            # Flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]

            # Cutouts
            # labels = cutout(img, labels, p=0.5)
            # nl = len(labels)  # update after cutout

        labels_out = torch.zeros((nl, 13))  # 13 = batch index + 12 label columns (widened from the stock 6)
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.im_files[index], shapes

    def load_image(self, i):
        # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                assert im is not None, f'Image Not Found {f}'
            h0, w0 = im.shape[:2]  # orig hw
            r = self.img_size / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized

    def cache_images_to_disk(self, i):
        # Saves an image as an *.npy file for faster loading
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]))

    def load_mosaic(self, index):
        # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        random.shuffle(indices)
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img4
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            labels4.append(labels)
            segments4.extend(segments)

        # Concat/clip labels
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img4, labels4 = replicate(img4, labels4)  # replicate

        # Augment
        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
        img4, labels4 = random_perspective(img4,
                                           labels4,
                                           segments4,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img4, labels4

    def load_mosaic9(self, index):
        # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
        labels9, segments9 = [], []
        s = self.img_size
        indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
        random.shuffle(indices)
        hp, wp = -1, -1  # height, width previous
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img9
            if i == 0:  # center
                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
                h0, w0 = h, w
                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
            elif i == 1:  # top
                c = s, s - h, s + w, s
            elif i == 2:  # top right
                c = s + wp, s - h, s + wp + w, s
            elif i == 3:  # right
                c = s + w0, s, s + w0 + w, s + h
            elif i == 4:  # bottom right
                c = s + w0, s + hp, s + w0 + w, s + hp + h
            elif i == 5:  # bottom
                c = s + w0 - w, s + h0, s + w0, s + h0 + h
            elif i == 6:  # bottom left
                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
            elif i == 7:  # left
                c = s - w, s + h0 - h, s, s + h0
            elif i == 8:  # top left
                c = s - w, s + h0 - hp - h, s, s + h0 - hp

            padx, pady = c[:2]
            x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
            labels9.append(labels)
            segments9.extend(segments)

            # Image
            img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
            hp, wp = h, w  # height, width previous

        # Offset
        yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
        img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

        # Concat/clip labels
        labels9 = np.concatenate(labels9, 0)
        labels9[:, [1, 3]] -= xc
        labels9[:, [2, 4]] -= yc
        c = np.array([xc, yc])  # centers
        segments9 = [x - c for x in segments9]

        for x in (labels9[:, 1:], *segments9):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img9, labels9 = replicate(img9, labels9)  # replicate

        # Augment
        img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp['copy_paste'])
        img9, labels9 = random_perspective(img9,
                                           labels9,
                                           segments9,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img9, labels9

    @staticmethod
    def collate_fn(batch):
        im, label, path, shapes = zip(*batch)  # transposed
        for i, lb in enumerate(label):
            lb[:, 0] = i  # add target image index for build_targets()
        return torch.stack(im, 0), torch.cat(label, 0), path, shapes
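
    # Shape sketch (illustrative): collating a batch of 16 samples whose labels_out rows are
    # (nl_i, 13) stacks images to (16, 3, h, w) and concatenates labels to (sum(nl_i), 13),
    # with column 0 rewritten to each row's image index within the batch.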

    @staticmethod
    def collate_fn4(batch):
        # Note: this quad collate still assumes 6-column targets (index, cls, xywh); the
        # ho/wo/s offset tensors would need widening for this fork's 13-column labels
        im, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
                                    align_corners=False)[0].type(im[i].type())
                lb = label[i]
            else:
                im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
                lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            im4.append(im1)
            label4.append(lb)

        for i, lb in enumerate(label4):
            lb[:, 0] = i  # add target image index for build_targets()

        return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
959
960
# Ancillary functions --------------------------------------------------------------------------------------------------
961
def flatten_recursive(path=DATASETS_DIR / 'coco128'):
962
    # Flatten a recursive directory by bringing all files to top level
963
    new_path = Path(f'{str(path)}_flat')
964
    if os.path.exists(new_path):
965
        shutil.rmtree(new_path)  # delete output folder
966
    os.makedirs(new_path)  # make new output folder
967
    for file in tqdm(glob.glob(f'{str(Path(path))}/**/*.*', recursive=True)):
968
        shutil.copyfile(file, new_path / Path(file).name)


def extract_boxes(path=DATASETS_DIR / 'coco128'):  # from utils.dataloaders import *; extract_boxes()
    # Convert detection dataset into classification dataset, with one directory per class
    path = Path(path)  # images dir
    shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in IMG_FORMATS:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file) as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'


def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    Usage: from utils.dataloaders import *; autosplit()
    Arguments
        path:            Path to images directory
        weights:         Train, val, test weights (list, tuple)
        annotated_only:  Only use images with an annotated txt file
    """
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    random.seed(0)  # for reproducibility
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    for x in txt:
        if (path.parent / x).exists():
            (path.parent / x).unlink()  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[i], 'a') as f:
                f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n')  # add image to txt file


def verify_image_label(args, blood=True):
    # Verify one image-label pair; blood=True treats label rows as plain columns and skips segment conversion
    im_file, lb_file, prefix = args
    nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', []  # number (missing, found, empty, corrupt), message, segments
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
        if im.format.lower() in ('jpg', 'jpeg'):
            with open(im_file, 'rb') as f:
                f.seek(-2, 2)
                if f.read() != b'\xff\xd9':  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if not blood and any(len(x) > 6 for x in lb):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                assert lb.shape[1] == 12, f'labels require 12 columns, {lb.shape[1]} columns detected'  # widened label format
                assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
                assert (lb[:, 1:5] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:5][lb[:, 1:5] > 1]}'
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
            else:
                ne = 1  # label empty
                lb = np.zeros((0, 12), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, 12), dtype=np.float32)
        return im_file, lb, shape, segments, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
        return [None, None, None, None, nm, nf, ne, nc, msg]
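
# Label-format note (assumption, inferred from the asserts above): each *.txt row carries 12
# whitespace-separated values, column 0 = class and columns 1:5 = normalized xywh box, with
# the remaining 7 columns being this fork's extra per-object fields (all values must be >= 0).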


class HUBDatasetStats():
    """ Class for generating HUB dataset JSON and `-hub` dataset directory

    Arguments
        path:           Path to data.yaml or data.zip (with data.yaml inside data.zip)
        autodownload:   Attempt to download dataset if not found locally

    Usage
        from utils.dataloaders import HUBDatasetStats
        stats = HUBDatasetStats('coco128.yaml', autodownload=True)  # usage 1
        stats = HUBDatasetStats('path/to/coco128.zip')  # usage 2
        stats.get_json(save=False)
        stats.process_images()
    """

    def __init__(self, path='coco128.yaml', autodownload=False):
        # Initialize class
        zipped, data_dir, yaml_path = self._unzip(Path(path))
        try:
            with open(check_yaml(yaml_path), errors='ignore') as f:
                data = yaml.safe_load(f)  # data dict
                if zipped:
                    data['path'] = data_dir
        except Exception as e:
            raise Exception('error/HUB/dataset_stats/yaml_load') from e

        check_dataset(data, autodownload)  # download dataset if missing
        self.hub_dir = Path(data['path'] + '-hub')
        self.im_dir = self.hub_dir / 'images'
        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
        self.stats = {'nc': data['nc'], 'names': list(data['names'].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _find_yaml(dir):
        # Return data.yaml file
        files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml'))  # try root level first and then recursive
        assert files, f'No *.yaml file found in {dir}'
        if len(files) > 1:
            files = [f for f in files if f.stem == dir.stem]  # prefer *.yaml files that match dir name
            assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
        assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
        return files[0]

    def _unzip(self, path):
        # Unzip data.zip
        if not str(path).endswith('.zip'):  # path is data.yaml
            return False, None, path
        assert Path(path).is_file(), f'Error unzipping {path}, file not found'
        unzip_file(path, path=path.parent)
        dir = path.with_suffix('')  # dataset directory == zip name
        assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
        return True, str(dir), self._find_yaml(dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f, max_dim=1920):
        # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
        f_new = self.im_dir / Path(f).name  # dataset-hub image filename
        try:  # use PIL
            im = Image.open(f)
            r = max_dim / max(im.height, im.width)  # ratio
            if r < 1.0:  # image too large
                im = im.resize((int(im.width * r), int(im.height * r)))
            im.save(f_new, 'JPEG', quality=50, optimize=True)  # save
        except Exception as e:  # use OpenCV
            LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
            im = cv2.imread(f)
            im_height, im_width = im.shape[:2]
            r = max_dim / max(im_height, im_width)  # ratio
            if r < 1.0:  # image too large
                im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
            cv2.imwrite(str(f_new), im)

    def get_json(self, save=False, verbose=False):
        # Return dataset JSON for Ultralytics HUB
        def _round(labels):
            # Update labels to integer class and 4 decimal place floats
            return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]

        for split in 'train', 'val', 'test':
            if self.data.get(split) is None:
                self.stats[split] = None  # i.e. no test set
                continue
            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
            x = np.array([
                np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
                for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')])  # shape(128x80)
            self.stats[split] = {
                'instance_stats': {
                    'total': int(x.sum()),
                    'per_class': x.sum(0).tolist()},
                'image_stats': {
                    'total': dataset.n,
                    'unlabelled': int(np.all(x == 0, 1).sum()),
                    'per_class': (x > 0).sum(0).tolist()},
                'labels': [{
                    str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}

        # Save, print and return
        if save:
            stats_path = self.hub_dir / 'stats.json'
            print(f'Saving {stats_path.resolve()}...')
            with open(stats_path, 'w') as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            print(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self):
        # Compress images for Ultralytics HUB
        for split in 'train', 'val', 'test':
            if self.data.get(split) is None:
                continue
            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
            desc = f'{split} images'
            for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
                pass
        print(f'Done. All images saved to {self.im_dir}')
        return self.im_dir


# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLOv5 Classification Dataset.
    Arguments
        root:  Dataset path
        transform:  torchvision transforms, used by default
        album_transform: Albumentations transforms, used if installed
    """

    def __init__(self, root, augment, imgsz, cache=False):
        super().__init__(root=root)
        self.torch_transforms = classify_transforms(imgsz)
        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
        self.cache_ram = cache is True or cache == 'ram'
        self.cache_disk = cache == 'disk'
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im

    def __getitem__(self, i):
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram and im is None:
            im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
        else:
            sample = self.torch_transforms(im)
        return sample, j


def create_classification_dataloader(path,
                                     imgsz=224,
                                     batch_size=16,
                                     augment=True,
                                     cache=False,
                                     rank=-1,
                                     workers=8,
                                     shuffle=True):
    # Returns Dataloader object to be used with YOLOv5 Classifier
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + RANK)
    return InfiniteDataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=shuffle and sampler is None,
                              num_workers=nw,
                              sampler=sampler,
                              pin_memory=PIN_MEMORY,
                              worker_init_fn=seed_worker,
                              generator=generator)  # or DataLoader(persistent_workers=True)
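
# Usage sketch (illustrative; 'data/cifar10/train' is an assumption): returns an
# InfiniteDataLoader yielding (images, class_index) batches for the YOLOv5 classifier:
#   loader = create_classification_dataloader('data/cifar10/train', imgsz=224, batch_size=64)
#   for imgs, labels in loader:
#       pass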