# landmark_extraction/utils/datasets.py
# Dataset utils and dataloaders

import glob
import logging
import math
import os
import pickle
import random
import shutil
import time
from copy import deepcopy
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

# from pycocotools import mask as maskUtils
from torchvision.utils import save_image
from torchvision.ops import roi_pool, roi_align, ps_roi_pool, ps_roi_align

from utils.general import check_requirements, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, \
    resample_segments, clean_str
from utils.torch_utils import torch_distributed_zero_first

# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
logger = logging.getLogger(__name__)

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except Exception:  # image has no (or malformed) EXIF data
        pass

    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP scans the dataset first, so the other processes can use its cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset properties will update during training, else InfiniteDataLoader()
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
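
# Hedged usage sketch (not part of the original file): the dataset path is an
# assumption, and `opt` only needs the `single_cls` attribute read above.
# from argparse import Namespace
# dataloader, dataset = create_dataloader('../coco128/images/train2017', 640, 16, 32,
#                                         Namespace(single_cls=False))
# for imgs, targets, paths, shapes in dataloader:
#     # imgs: uint8 tensor (B, 3, H, W); targets: float tensor (N, 6) with rows of
#     # (batch_index, class, x_center, y_center, w, h), coordinates normalized to 0-1
#     break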

class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)

class LoadImages:  # for inference
    def __init__(self, path, img_size=640, stride=32):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            # print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
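
# Hedged usage sketch (not part of the original file; the directory is an assumption):
# for path, img, img0, cap in LoadImages('inference/images', img_size=640, stride=32):
#     # img: letterboxed CHW RGB np.uint8 array ready for torch.from_numpy(img);
#     # img0: the original BGR frame as read by OpenCV
#     pass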

class LoadWebcam:  # for inference
    def __init__(self, pipe='0', img_size=640, stride=32):
        self.img_size = img_size
        self.stride = stride

        if pipe.isnumeric():
            pipe = int(pipe)  # local camera index (avoid eval() on user input)
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, f'Camera Error {self.pipe}'
        img_path = 'webcam.jpg'
        print(f'webcam {self.count}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0

class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640, stride=32):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print(f'{i + 1}/{n}: {s}... ', end='')
            url = int(s) if s.isnumeric() else s  # local camera index or stream URL (avoid eval() on user input)
            if 'youtube.com/' in str(url) or 'youtu.be/' in str(url):  # if source is YouTube video
                check_requirements(('pafy', 'youtube_dl'))
                import pafy
                url = pafy.new(url).getbest(preftype="mp4").url
            cap = cv2.VideoCapture(url)
            assert cap.isOpened(), f'Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.fps = cap.get(cv2.CAP_PROP_FPS) % 100 or 30  # fall back to 30 FPS if the stream reports 0

            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=(i, cap), daemon=True)
            print(f' success ({w}x{h} at {self.fps:.2f} FPS).')
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                success, im = cap.retrieve()
                self.imgs[index] = im if success else self.imgs[index] * 0
                n = 0
            time.sleep(1 / self.fps)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, BHWC to BCHW
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years

def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/ and /labels/ substrings
    return ['txt'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]
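
# Hedged example (not part of the original file; the path is illustrative):
# img2label_paths(['../coco128/images/train2017/000000000009.jpg'])
# -> ['../coco128/labels/train2017/000000000009.txt']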

class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        # self.albumentations = Albumentations() if augment else None

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p, 'r') as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            # if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
            #     cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

        # Read cache
        cache.pop('hash')  # remove hash
        cache.pop('version')  # remove version
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            if cache_images == 'disk':
                self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
                self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
                self.im_cache_dir.mkdir(parents=True, exist_ok=True)
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                if cache_images == 'disk':
                    if not self.img_npy[i].exists():
                        np.save(self.img_npy[i].as_posix(), x[0])
                    gb += self.img_npy[i].stat().st_size
                else:
                    self.imgs[i], self.img_hw0[i], self.img_hw[i] = x
                    gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for i, (im_file, lb_file) in enumerate(pbar):
            try:
                # verify images
                im = Image.open(im_file)
                im.verify()  # PIL verify
                shape = exif_size(im)  # image size
                segments = []  # instance segments
                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                assert im.format.lower() in img_formats, f'invalid image format {im.format}'

                # verify labels
                if os.path.isfile(lb_file):
                    nf += 1  # label found
                    with open(lb_file, 'r') as f:
                        l = [x.split() for x in f.read().strip().splitlines()]
                        if any([len(x) > 8 for x in l]):  # is segment
                            classes = np.array([x[0] for x in l], dtype=np.float32)
                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                        l = np.array(l, dtype=np.float32)
                    if len(l):
                        assert l.shape[1] == 5, 'labels require 5 columns each'
                        assert (l >= 0).all(), 'negative labels'
                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
                    else:
                        ne += 1  # label empty
                        l = np.zeros((0, 5), dtype=np.float32)
                else:
                    nm += 1  # label missing
                    l = np.zeros((0, 5), dtype=np.float32)
                x[im_file] = [l, shape, segments]
            except Exception as e:
                nc += 1
                print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
        pbar.close()

        if nf == 0:
            print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, i + 1
        x['version'] = 0.1  # cache version
        torch.save(x, path)  # save for next time
        logging.info(f'{prefix}New cache created: {path}')
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic (4-image mosaic 80% of the time, 9-image mosaic otherwise)
            if random.random() < 0.8:
                img, labels = load_mosaic(self, index)
            else:
                img, labels = load_mosaic9(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                if random.random() < 0.8:
                    img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
                else:
                    img2, labels2 = load_mosaic9(self, random.randint(0, len(self.labels) - 1))
                r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        if self.augment:
            # Augment imagespace
            if not mosaic:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # img, labels = self.albumentations(img, labels)

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

            # Paste-in augmentation: copy object crops from random images into this one
            if random.random() < hyp['paste_in']:
                sample_labels, sample_images, sample_masks = [], [], []
                while len(sample_labels) < 30:
                    sample_labels_, sample_images_, sample_masks_ = \
                        load_samples(self, random.randint(0, len(self.labels) - 1))
                    sample_labels += sample_labels_
                    sample_images += sample_images_
                    sample_masks += sample_masks_
                    if len(sample_labels) == 0:
                        break
                labels = pastein(img, labels, sample_labels, sample_images, sample_masks)

        nL = len(labels)  # number of labels
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

        if self.augment:
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
                    0].type(img[i].type())
                l = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            img4.append(im)
            label4.append(l)

        for i, l in enumerate(label4):
            l[:, 0] = i  # add target image index for build_targets()

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4

# Ancillary functions --------------------------------------------------------------------------------------------
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized

def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    # HSV color-space augmentation (in-place on a BGR uint8 image)
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
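
# Hedged usage sketch (not part of the original file): augments a synthetic frame
# in place; the gains mirror typical hyp values but are assumptions here.
# im = np.full((480, 640, 3), 128, dtype=np.uint8)
# augment_hsv(im, hgain=0.015, sgain=0.7, vgain=0.4)  # im is modified in place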

def hist_equalize(img, clahe=True, bgr=False):
    # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
    yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB

def load_mosaic(self, index):
    # loads images in a 4-mosaic

    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    # img4, labels4, segments4 = remove_background(img4, labels4, segments4)
    # sample_segments(img4, labels4, segments4, probability=self.hyp['copy_paste'])
    img4, labels4, segments4 = copy_paste(img4, labels4, segments4, probability=self.hyp['copy_paste'])
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4

def load_mosaic9(self, index):
    # loads images in a 9-mosaic

    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    # img9, labels9, segments9 = remove_background(img9, labels9, segments9)
    img9, labels9, segments9 = copy_paste(img9, labels9, segments9, probability=self.hyp['copy_paste'])
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9

def load_samples(self, index):
    # builds a 4-mosaic, then returns object crops (labels, images, masks) sampled from it

    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Sample object crops and masks from the mosaic
    # img4, labels4, segments4 = remove_background(img4, labels4, segments4)
    sample_labels, sample_images, sample_masks = sample_segments(img4, labels4, segments4, probability=0.5)

    return sample_labels, sample_images, sample_masks

def copy_paste(img, labels, segments, probability=0.5):
    # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
    n = len(segments)
    if probability and n:
        h, w, c = img.shape  # height, width, channels
        im_new = np.zeros(img.shape, np.uint8)
        for j in random.sample(range(n), k=round(probability * n)):
            l, s = labels[j], segments[j]
            box = w - l[3], l[2], w - l[1], l[4]
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            if (ioa < 0.30).all():  # allow 30% obscuration of existing labels
                labels = np.concatenate((labels, [[l[0], *box]]), 0)
                segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
                cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)

        result = cv2.bitwise_and(src1=img, src2=im_new)
        result = cv2.flip(result, 1)  # augment segments (flip left-right)
        i = result > 0  # pixels to replace
        # i[:, :] = result.max(2).reshape(h, w, 1)  # act over ch
        img[i] = result[i]  # cv2.imwrite('debug.jpg', img)  # debug

    return img, labels, segments

def remove_background(img, labels, segments):
    # Replace everything outside the segment contours with gray (114); labels as nx5 np.array(cls, xyxy)
    n = len(segments)
    h, w, c = img.shape  # height, width, channels
    im_new = np.zeros(img.shape, np.uint8)
    img_new = np.ones(img.shape, np.uint8) * 114
    for j in range(n):
        cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)

        result = cv2.bitwise_and(src1=img, src2=im_new)

        i = result > 0  # pixels to replace
        img_new[i] = result[i]  # cv2.imwrite('debug.jpg', img)  # debug

    return img_new, labels, segments

def sample_segments(img, labels, segments, probability=0.5):
    # Sample object crops and masks from segments, for Copy-Paste style augmentation
    # https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
    n = len(segments)
    sample_labels = []
    sample_images = []
    sample_masks = []
    if probability and n:
        h, w, c = img.shape  # height, width, channels
        for j in random.sample(range(n), k=round(probability * n)):
            l, s = labels[j], segments[j]
            box = l[1].astype(int).clip(0, w - 1), l[2].astype(int).clip(0, h - 1), \
                  l[3].astype(int).clip(0, w - 1), l[4].astype(int).clip(0, h - 1)

            if (box[2] <= box[0]) or (box[3] <= box[1]):  # degenerate box
                continue

            sample_labels.append(l[0])

            mask = np.zeros(img.shape, np.uint8)
            cv2.drawContours(mask, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
            sample_masks.append(mask[box[1]:box[3], box[0]:box[2], :])

            result = cv2.bitwise_and(src1=img, src2=mask)
            i = result > 0  # pixels to replace
            mask[i] = result[i]  # mask now holds the masked object pixels
            sample_images.append(mask[box[1]:box[3], box[0]:box[2], :])

    return sample_labels, sample_images, sample_masks

def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
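
# Hedged usage sketch (not part of the original file): shapes are illustrative.
# im = np.zeros((480, 640, 3), dtype=np.uint8)  # 640x480 BGR frame
# im_out, ratio, (dw, dh) = letterbox(im, new_shape=640, stride=32)
# # auto=True pads only to the nearest stride multiple, so im_out.shape == (480, 640, 3)
# # and (dw, dh) == (0.0, 0.0) here; with auto=False it would pad to (640, 640, 3)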

def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1.1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return img, targets

def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
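
# Hedged example (not part of the original file): boxes are column vectors (4, n).
# b1 = np.array([[0], [0], [100], [100]], dtype=float)  # before augmentation
# b2 = np.array([[0], [0], [50], [50]], dtype=float)    # after augmentation
# box_candidates(b1, b2)  # -> array([ True]): 25% of the area survives (> 0.1)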

def bbox_ioa(box1, box2):
    # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
    box2 = box2.transpose()

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

    # Intersection area
    inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                 (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

    # box2 area
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

    # Intersection over box2 area
    return inter_area / box2_area
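
# Hedged example (not part of the original file):
# box1 = np.array([0, 0, 50, 50], dtype=np.float32)
# box2 = np.array([[25, 25, 75, 75]], dtype=np.float32)
# bbox_ioa(box1, box2)  # -> array([0.25]): the intersection covers 25% of box2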

def cutout(image, labels):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    h, w = image.shape[:2]

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels
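
# Hedged usage sketch (not part of the original file): labels are nx5 (cls, xyxy) in pixels.
# im = np.zeros((640, 640, 3), dtype=np.uint8)
# lb = np.array([[0, 100, 100, 300, 300]], dtype=np.float32)
# lb = cutout(im, lb)  # im gets random gray patches; heavily obscured labels are dropped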

def pastein(image, labels, sample_labels, sample_images, sample_masks):
    # Paste-in (copy-paste) augmentation: pastes sampled object crops into random locations of `image`
    h, w = image.shape[:2]

    # create random masks
    scales = [0.75] * 2 + [0.5] * 4 + [0.25] * 4 + [0.125] * 4 + [0.0625] * 6  # image size fraction
    for s in scales:
        if random.random() < 0.2:
            continue
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
        if len(labels):
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
        else:
            ioa = np.zeros(1)

        if (ioa < 0.30).all() and len(sample_labels) \
                and (xmax > xmin + 20) and (ymax > ymin + 20):  # allow 30% obscuration of existing labels
            sel_ind = random.randint(0, len(sample_labels) - 1)
            hs, ws, cs = sample_images[sel_ind].shape
            r_scale = min((ymax - ymin) / hs, (xmax - xmin) / ws)
            r_w = int(ws * r_scale)
            r_h = int(hs * r_scale)

            if (r_w > 10) and (r_h > 10):
                r_mask = cv2.resize(sample_masks[sel_ind], (r_w, r_h))
                r_image = cv2.resize(sample_images[sel_ind], (r_w, r_h))
                temp_crop = image[ymin:ymin + r_h, xmin:xmin + r_w]
                m_ind = r_mask > 0
                if m_ind.astype(int).sum() > 60:  # require at least 60 mask pixels
                    temp_crop[m_ind] = r_image[m_ind]
                    box = np.array([xmin, ymin, xmin + r_w, ymin + r_h], dtype=np.float32)
                    if len(labels):
                        labels = np.concatenate((labels, [[sample_labels[sel_ind], *box]]), 0)
                    else:
                        labels = np.array([[sample_labels[sel_ind], *box]])

                    image[ymin:ymin + r_h, xmin:xmin + r_w] = temp_crop

    return labels

class Albumentations:
    # YOLOv5 Albumentations class (optional, only used if package is installed)
    def __init__(self):
        self.transform = None
        try:
            import albumentations as A

            self.transform = A.Compose([
                A.CLAHE(p=0.01),
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.01),
                A.RandomGamma(gamma_limit=[80, 120], p=0.01),
                A.Blur(p=0.01),
                A.MedianBlur(p=0.01),
                A.ToGray(p=0.01),
                A.ImageCompression(quality_lower=75, p=0.01)],
                bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

            # logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms if x.p))
        except ImportError:  # package not installed; keep self.transform = None
            pass

    def __call__(self, im, labels, p=1.0):
        if self.transform and random.random() < p:
            new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0])  # transformed
            im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
        return im, labels
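
# Hedged usage sketch (not part of the original file): labels are nx5 (cls, xyxy) in pixels.
# aug = Albumentations()  # becomes a no-op if the albumentations package is missing
# im = np.zeros((640, 640, 3), dtype=np.uint8)
# lb = np.array([[0, 10, 10, 100, 100]], dtype=np.float32)
# im, lb = aug(im, lb)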

def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder


def flatten_recursive(path='../coco'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(path + '_flat')
    create_folder(new_path)
    for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)

def extract_boxes(path='../coco/'):  # from utils.datasets import *; extract_boxes('../coco128')
    # Convert detection dataset into classification dataset, with one directory per class

    path = Path(path)  # images dir
    shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in img_formats:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file, 'r') as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'

def autosplit(path='../coco', weights=(0.9, 0.1, 0.0), annotated_only=False):
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    Usage: from utils.datasets import *; autosplit('../coco')
    Arguments
        path:           Path to images directory
        weights:        Train, val, test weights (list)
        annotated_only: Only use images with an annotated txt file
    """
    path = Path(path)  # images dir
    files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], [])  # image files only
    n = len(files)  # number of files
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    [(path / x).unlink() for x in txt if (path / x).exists()]  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path / txt[i], 'a') as f:
                f.write(str(img) + '\n')  # add image to txt file
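
# Hedged usage sketch (not part of the original file; the directory is illustrative):
# autosplit('../coco128/images', weights=(0.8, 0.2, 0.0))
# # writes autosplit_train.txt / autosplit_val.txt inside the images directory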
    

def load_segmentations(self, index):
    # Look up pre-computed segmentations by image path (note: the dataset root below is hard-coded)
    key = '/work/handsomejw66/coco17/' + self.img_files[index]
    return self.segs[key]