# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Train a YOLOv5 segment model on a segment dataset
Models and datasets download automatically from the latest YOLOv5 release.

Usage - Single-GPU training:
    $ python segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --img 640  # from pretrained (recommended)
    $ python segment/train.py --data coco128-seg.yaml --weights '' --cfg yolov5s-seg.yaml --img 640  # from scratch

Usage - Multi-GPU DDP training:
    $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --img 640 --device 0,1,2,3
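
Usage - Resume an interrupted run (illustrative paths; --resume alone picks up the most recent last.pt):
    $ python segment/train.py --resume                                       # resume latest training
    $ python segment/train.py --resume runs/train-seg/exp/weights/last.pt    # resume a specific checkpoint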

Models:     https://github.com/ultralytics/yolov5/tree/master/models
Datasets:   https://github.com/ultralytics/yolov5/tree/master/data
Tutorial:   https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
"""

import argparse
import math
import os
import random
import subprocess
import sys
import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import yaml
from torch.optim import lr_scheduler
from tqdm import tqdm

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

import segment.val as validate  # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import SegmentationModel
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.downloads import attempt_download, is_url
from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_info,
                           check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr,
                           get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
                           labels_to_image_weights, one_cycle, print_args, print_mutation, strip_optimizer, yaml_save)
from utils.loggers import GenericLogger
from utils.plots import plot_evolve, plot_labels
from utils.segment.dataloaders import create_dataloader
from utils.segment.loss import ComputeLoss
from utils.segment.metrics import KEYS, fitness
from utils.segment.plots import plot_images_and_masks, plot_results_with_masks
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
                               smart_resume, torch_distributed_zero_first)

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
GIT_INFO = check_git_info()


def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, mask_ratio = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
        opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze, opt.mask_ratio
    # callbacks.run('on_pretrain_routine_start')

    # Directories
    w = save_dir / 'weights'  # weights dir
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
    last, best = w / 'last.pt', w / 'best.pt'

    # Hyperparameters
    if isinstance(hyp, str):
        with open(hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
    LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
    opt.hyp = hyp.copy()  # for saving hyps to checkpoints

    # Save run settings
    if not evolve:
        yaml_save(save_dir / 'hyp.yaml', hyp)
        yaml_save(save_dir / 'opt.yaml', vars(opt))

    # Loggers
    data_dict = None
    if RANK in {-1, 0}:
        logger = GenericLogger(opt=opt, console_logger=LOGGER)

    # Config
    plots = not evolve and not opt.noplots  # create plots
    overlap = not opt.no_overlap
    cuda = device.type != 'cpu'
    init_seeds(opt.seed + 1 + RANK, deterministic=True)
    with torch_distributed_zero_first(LOCAL_RANK):
        data_dict = data_dict or check_dataset(data)  # check if None
    train_path, val_path = data_dict['train'], data_dict['val']
    nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
    names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt')  # COCO dataset

    # Model
    check_suffix(weights, '.pt')  # check weights
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(LOCAL_RANK):
            weights = attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
        model = SegmentationModel(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
        exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(csd, strict=False)  # load
        LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
    else:
        model = SegmentationModel(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
    amp = check_amp(model)  # check AMP

    # Freeze
    freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze):
            LOGGER.info(f'freezing {k}')
            v.requires_grad = False

    # Image size
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple

    # Batch size
    if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
        batch_size = check_train_batch_size(model, imgsz, amp)
        logger.update_params({'batch_size': batch_size})
        # loggers.on_params_update({"batch_size": batch_size})

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
    optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])

    # Scheduler
    if opt.cos_lr:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    else:
        lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)

    # EMA
    ema = ModelEMA(model) if RANK in {-1, 0} else None

    # Resume
    best_fitness, start_epoch = 0.0, 0
    if pretrained:
        if resume:
            best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
        del ckpt, csd

    # DP mode
    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
        LOGGER.warning(
            'WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
            'See Multi-GPU Tutorial at https://docs.ultralytics.com/yolov5/tutorials/multi_gpu_training to get started.'
        )
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and RANK != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        LOGGER.info('Using SyncBatchNorm()')

    # Trainloader
    train_loader, dataset = create_dataloader(
        train_path,
        imgsz,
        batch_size // WORLD_SIZE,
        gs,
        single_cls,
        hyp=hyp,
        augment=True,
        cache=None if opt.cache == 'val' else opt.cache,
        rect=opt.rect,
        rank=LOCAL_RANK,
        workers=workers,
        image_weights=opt.image_weights,
        quad=opt.quad,
        prefix=colorstr('train: '),
        shuffle=True,
        mask_downsample_ratio=mask_ratio,
        overlap_mask=overlap,
    )
    labels = np.concatenate(dataset.labels, 0)
    mlc = int(labels[:, 0].max())  # max label class
    assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

    # Process 0
    if RANK in {-1, 0}:
        val_loader = create_dataloader(val_path,
                                       imgsz,
                                       batch_size // WORLD_SIZE * 2,
                                       gs,
                                       single_cls,
                                       hyp=hyp,
                                       cache=None if noval else opt.cache,
                                       rect=True,
                                       rank=-1,
                                       workers=workers * 2,
                                       pad=0.5,
                                       mask_downsample_ratio=mask_ratio,
                                       overlap_mask=overlap,
                                       prefix=colorstr('val: '))[0]

        if not resume:
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)  # run AutoAnchor
            model.half().float()  # pre-reduce anchor precision

            if plots:
                plot_labels(labels, names, save_dir)
        # callbacks.run('on_pretrain_routine_end', labels, names)

    # DDP mode
    if cuda and RANK != -1:
        model = smart_DDP(model)

    # Model attributes
    nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
    hyp['box'] *= 3 / nl  # scale to layers
    hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
    hyp['label_smoothing'] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nb = len(train_loader)  # number of batches
    nw = max(round(hyp['warmup_epochs'] * nb), 100)  # number of warmup iterations, max(3 epochs, 100 iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    last_opt_step = -1
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)  # box(P, R, mAP@.5, mAP@.5-.95), mask(P, R, mAP@.5, mAP@.5-.95), val_loss(box, seg, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
    stopper, stop = EarlyStopping(patience=opt.patience), False
    compute_loss = ComputeLoss(model, overlap=overlap)  # init loss class
    # callbacks.run('on_train_start')
    LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
                f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
                f"Logging results to {colorstr('bold', save_dir)}\n"
                f'Starting training for {epochs} epochs...')
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        # callbacks.run('on_train_epoch_start')
        model.train()

        # Update image weights (optional, single-GPU only)
        if opt.image_weights:
            cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
            iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
            dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx

        # Update mosaic border (optional)
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if RANK != -1:
            train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(train_loader)
        LOGGER.info(('\n' + '%11s' * 8) %
                    ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Instances', 'Size'))
        if RANK in {-1, 0}:
            pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _, masks) in pbar:  # batch ------------------------------------------------------
            # callbacks.run('on_train_batch_start')
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            with torch.cuda.amp.autocast(amp):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
                if RANK != -1:
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # Backward
            scaler.scale(loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            if ni - last_opt_step >= accumulate:
                scaler.unscale_(optimizer)  # unscale gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
                last_opt_step = ni

            # Log
            if RANK in {-1, 0}:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                pbar.set_description(('%11s' * 2 + '%11.4g' * 6) %
                                     (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
                # callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths)
                # if callbacks.stop_training:
                #    return

                # Mosaic plots
                if plots:
                    if ni < 3:
                        plot_images_and_masks(imgs, targets, masks, paths, save_dir / f'train_batch{ni}.jpg')
                    if ni == 10:
                        files = sorted(save_dir.glob('train*.jpg'))
                        logger.log_images(files, 'Mosaics', epoch)
            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
        scheduler.step()

        if RANK in {-1, 0}:
            # mAP
            # callbacks.run('on_train_epoch_end', epoch=epoch)
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
            final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
            if not noval or final_epoch:  # Calculate mAP
                results, maps, _ = validate.run(data_dict,
                                                batch_size=batch_size // WORLD_SIZE * 2,
                                                imgsz=imgsz,
                                                half=amp,
                                                model=ema.ema,
                                                single_cls=single_cls,
                                                dataloader=val_loader,
                                                save_dir=save_dir,
                                                plots=False,
                                                callbacks=callbacks,
                                                compute_loss=compute_loss,
                                                mask_downsample_ratio=mask_ratio,
                                                overlap=overlap)

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            stop = stopper(epoch=epoch, fitness=fi)  # early stop check
            if fi > best_fitness:
                best_fitness = fi
            log_vals = list(mloss) + list(results) + lr
            # callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
            # Log val metrics and media
            metrics_dict = dict(zip(KEYS, log_vals))
            logger.log_metrics(metrics_dict, epoch)

            # Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
                ckpt = {
                    'epoch': epoch,
                    'best_fitness': best_fitness,
                    'model': deepcopy(de_parallel(model)).half(),
                    'ema': deepcopy(ema.ema).half(),
                    'updates': ema.updates,
                    'optimizer': optimizer.state_dict(),
                    'opt': vars(opt),
                    'git': GIT_INFO,  # {remote, branch, commit} if a git repo
                    'date': datetime.now().isoformat()}

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if opt.save_period > 0 and epoch % opt.save_period == 0:
                    torch.save(ckpt, w / f'epoch{epoch}.pt')
                    logger.log_model(w / f'epoch{epoch}.pt')
                del ckpt
                # callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)

        # EarlyStopping
        if RANK != -1:  # if DDP training
            broadcast_list = [stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            if RANK != 0:
                stop = broadcast_list[0]
        if stop:
            break  # must break all DDP ranks

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training -----------------------------------------------------------------------------------------------------
    if RANK in {-1, 0}:
        LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f'\nValidating {f}...')
                    results, _, _ = validate.run(
                        data_dict,
                        batch_size=batch_size // WORLD_SIZE * 2,
                        imgsz=imgsz,
                        model=attempt_load(f, device).half(),
                        iou_thres=0.65 if is_coco else 0.60,  # best pycocotools at iou 0.65
                        single_cls=single_cls,
                        dataloader=val_loader,
                        save_dir=save_dir,
                        save_json=is_coco,
                        verbose=True,
                        plots=plots,
                        callbacks=callbacks,
                        compute_loss=compute_loss,
                        mask_downsample_ratio=mask_ratio,
                        overlap=overlap)  # val best model with plots
                    if is_coco:
                        # callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
                        metrics_dict = dict(zip(KEYS, list(mloss) + list(results) + lr))
                        logger.log_metrics(metrics_dict, epoch)

        # callbacks.run('on_train_end', last, best, epoch, results)
        # on train end callback using genericLogger
        logger.log_metrics(dict(zip(KEYS[4:16], results)), epochs)
        if not opt.evolve:
            logger.log_model(best, epoch)
        if plots:
            plot_results_with_masks(file=save_dir / 'results.csv')  # save results.png
            files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
            files = [(save_dir / f) for f in files if (save_dir / f).exists()]  # filter
            LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
            logger.log_images(files, 'Results', epoch + 1)
            logger.log_images(sorted(save_dir.glob('val*.jpg')), 'Validation', epoch + 1)
    torch.cuda.empty_cache()
    return results


def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s-seg.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128-seg.yaml', help='dataset.yaml path')
    parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--noval', action='store_true', help='only validate final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
    parser.add_argument('--noplots', action='store_true', help='save no plot files')
    parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
    parser.add_argument('--project', default=ROOT / 'runs/train-seg', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
    parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
    parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
    parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
    parser.add_argument('--seed', type=int, default=0, help='Global training seed')
    parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')

    # Instance Segmentation Args
    parser.add_argument('--mask-ratio', type=int, default=4, help='Downsample ratio for the ground-truth masks (saves memory)')
    parser.add_argument('--no-overlap', action='store_true', help='Disable overlapping masks (overlapping masks train faster at slightly lower mAP)')

    return parser.parse_known_args()[0] if known else parser.parse_args()


def main(opt, callbacks=Callbacks()):
    # Checks
    if RANK in {-1, 0}:
        print_args(vars(opt))
        check_git_status()
        check_requirements(ROOT / 'requirements.txt')

    # Resume
    if opt.resume and not opt.evolve:  # resume from specified or most recent last.pt
        last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
        opt_yaml = last.parent.parent / 'opt.yaml'  # train options yaml
        opt_data = opt.data  # original dataset
        if opt_yaml.is_file():
            with open(opt_yaml, errors='ignore') as f:
                d = yaml.safe_load(f)
        else:
            d = torch.load(last, map_location='cpu')['opt']
        opt = argparse.Namespace(**d)  # replace
        opt.cfg, opt.weights, opt.resume = '', str(last), True  # reinstate
        if is_url(opt_data):
            opt.data = check_file(opt_data)  # avoid HUB resume auth timeout
    else:
        opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
            check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        if opt.evolve:
            if opt.project == str(ROOT / 'runs/train-seg'):  # if default project name, rename to runs/evolve-seg
                opt.project = str(ROOT / 'runs/evolve-seg')
            opt.exist_ok, opt.resume = opt.resume, False  # pass resume to exist_ok and disable resume
        if opt.name == 'cfg':
            opt.name = Path(opt.cfg).stem  # use model.yaml as name
        opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))

    # DDP mode
    device = select_device(opt.device, batch_size=opt.batch_size)
    if LOCAL_RANK != -1:
        msg = 'is not compatible with YOLOv5 Multi-GPU DDP training'
        assert not opt.image_weights, f'--image-weights {msg}'
        assert not opt.evolve, f'--evolve {msg}'
        assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
        assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
        torch.cuda.set_device(LOCAL_RANK)
        device = torch.device('cuda', LOCAL_RANK)
        dist.init_process_group(backend='nccl' if dist.is_nccl_available() else 'gloo')

    # Train
    if not opt.evolve:
        train(opt.hyp, opt, device, callbacks)

    # Evolve hyperparameters (optional)
    else:
        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
        meta = {
            'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
            'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
            'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
            'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
            'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
            'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
            'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
            'box': (1, 0.02, 0.2),  # box loss gain
            'cls': (1, 0.2, 4.0),  # cls loss gain
            'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
            'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
            'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
            'iou_t': (0, 0.1, 0.7),  # IoU training threshold
            'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
            'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
            'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
            'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
            'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
            'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
            'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
            'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
            'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
            'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
            'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
            'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
            'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
            'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
            'mixup': (1, 0.0, 1.0),  # image mixup (probability)
            'copy_paste': (1, 0.0, 1.0)}  # segment copy-paste (probability)

        with open(opt.hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
            if 'anchors' not in hyp:  # anchors commented in hyp.yaml
                hyp['anchors'] = 3
        if opt.noautoanchor:
            del hyp['anchors'], meta['anchors']
        opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir)  # only val/save final epoch
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
        evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
        if opt.bucket:
            # download evolve.csv if exists
            subprocess.run([
                'gsutil',
                'cp',
                f'gs://{opt.bucket}/evolve.csv',
                str(evolve_csv), ])

        for _ in range(opt.evolve):  # generations to evolve
            if evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
                # Select parent(s)
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
                n = min(5, len(x))  # number of previous results to consider
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                w = fitness(x) - fitness(x).min() + 1E-6  # weights (sum > 0)
                if parent == 'single' or len(x) == 1:
                    # x = x[random.randint(0, n - 1)]  # random selection
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate
                mp, s = 0.8, 0.2  # mutation probability, sigma
                npr = np.random
                npr.seed(int(time.time()))
                g = np.array([meta[k][0] for k in hyp.keys()])  # gains 0-1
                ng = len(meta)
                v = np.ones(ng)
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
                    hyp[k] = float(x[i + 12] * v[i])  # mutate

            # Constrain to limits
            for k, v in meta.items():
                hyp[k] = max(hyp[k], v[1])  # lower limit
                hyp[k] = min(hyp[k], v[2])  # upper limit
                hyp[k] = round(hyp[k], 5)  # significant digits

            # Train mutation
            results = train(hyp.copy(), opt, device, callbacks)
            callbacks = Callbacks()
            # Write mutation results
            print_mutation(KEYS[4:16], results, hyp.copy(), save_dir, opt.bucket)

        # Plot results
        plot_evolve(evolve_csv)
        LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
                    f"Results saved to {colorstr('bold', save_dir)}\n"
                    f'Usage example: $ python train.py --hyp {evolve_yaml}')


def run(**kwargs):
    # Usage: import train; train.run(data='coco128-seg.yaml', imgsz=320, weights='yolov5m-seg.pt')
    opt = parse_opt(True)
    for k, v in kwargs.items():
        setattr(opt, k, v)
    main(opt)
    return opt


if __name__ == '__main__':
    opt = parse_opt()
    main(opt)