Diff of /yolov5/train.py [000000] .. [f26a44]


# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Train a YOLOv5 model on a custom dataset

Usage:
    $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
"""
import argparse
import math
import os
import random
import sys
import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD, Adam, lr_scheduler
from tqdm import tqdm

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

import val  # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.datasets import create_dataloader
from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
                           check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
                           intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
                           print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
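# LOCAL_RANK is this process's GPU index on its node, RANK its global index across all
# nodes, and WORLD_SIZE the total process count; all default to single-process values
# when the DDP launcher has not set them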


def train(hyp,  # path/to/hyp.yaml or hyp dictionary
          opt,
          device,
          callbacks
          ):
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
        opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze

    # Directories
    w = save_dir / 'weights'  # weights dir
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
    last, best = w / 'last.pt', w / 'best.pt'

    # Hyperparameters
    if isinstance(hyp, str):
        with open(hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
    LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))

    # Save run settings
    if not evolve:
        with open(save_dir / 'hyp.yaml', 'w') as f:
            yaml.safe_dump(hyp, f, sort_keys=False)
        with open(save_dir / 'opt.yaml', 'w') as f:
            yaml.safe_dump(vars(opt), f, sort_keys=False)

    # Loggers
    data_dict = None
    if RANK in [-1, 0]:
        loggers = Loggers(save_dir, weights, opt, hyp, LOGGER)  # loggers instance
        if loggers.wandb:
            data_dict = loggers.wandb.data_dict
            if resume:
                weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp

        # Register actions
        for k in methods(loggers):
            callbacks.register_action(k, callback=getattr(loggers, k))

    # Config
    plots = not evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(1 + RANK)
    with torch_distributed_zero_first(LOCAL_RANK):
        data_dict = data_dict or check_dataset(data)  # check if None
    train_path, val_path = data_dict['train'], data_dict['val']
    nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
    is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt')  # COCO dataset

    # Model
    check_suffix(weights, '.pt')  # check weights
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(LOCAL_RANK):
            weights = attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
        exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(csd, strict=False)  # load
        LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
    else:
        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create

    # Freeze
    freeze = [f'model.{x}.' for x in range(freeze)]  # layers to freeze
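    # e.g. --freeze 10 yields ['model.0.', 'model.1.', ..., 'model.9.'], matching the
    # parameter-name prefixes of the first 10 modules (the backbone, per the --freeze help)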
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            LOGGER.info(f'freezing {k}')
            v.requires_grad = False

    # Image size
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple

    # Batch size
    if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
        batch_size = check_train_batch_size(model, imgsz)
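    # autobatch: with --batch-size -1, check_train_batch_size profiles CUDA memory use at
    # this image size to estimate the largest batch that fits on the single GPU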

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
    LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")

    g0, g1, g2 = [], [], []  # optimizer parameter groups
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
            g2.append(v.bias)
        if isinstance(v, nn.BatchNorm2d):  # weight (no decay)
            g0.append(v.weight)
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
            g1.append(v.weight)

    if opt.adam:
        optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']})  # add g1 with weight_decay
    optimizer.add_param_group({'params': g2})  # add g2 (biases)
    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
                f"{len(g0)} weight (no decay), {len(g1)} weight, {len(g2)} bias")
    del g0, g1, g2

    # Scheduler
    if opt.linear_lr:
        lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    else:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)
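    # LambdaLR multiplies each group's initial_lr by lf(epoch); one_cycle decays the
    # factor from 1.0 to hyp['lrf'] along a cosine, and lf is reused during warmup below
    # to set the per-iteration lr target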

    # EMA
    ema = ModelEMA(model) if RANK in [-1, 0] else None
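    # the EMA keeps an exponential moving average of the weights on the rank-0 process;
    # the averaged copy (ema.ema) is what gets validated and saved to checkpoints below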

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # EMA
        if ema and ckpt.get('ema'):
            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
            ema.updates = ckpt['updates']

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if resume:
            assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
        if epochs < start_epoch:
            LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, csd

    # DP mode
    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
        LOGGER.warning('WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
                       'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and RANK != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        LOGGER.info('Using SyncBatchNorm()')

    # Trainloader
    train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                              hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
                                              workers=workers, image_weights=opt.image_weights, quad=opt.quad,
                                              prefix=colorstr('train: '), shuffle=True)
    mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label class
    nb = len(train_loader)  # number of batches
    assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

    # Process 0
    if RANK in [-1, 0]:
        val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
                                       hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1,
                                       workers=workers, pad=0.5,
                                       prefix=colorstr('val: '))[0]

        if not resume:
            labels = np.concatenate(dataset.labels, 0)
            # c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, names, save_dir)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
            model.half().float()  # pre-reduce anchor precision

        callbacks.run('on_pretrain_routine_end')

    # DDP mode
    if cuda and RANK != -1:
        model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
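    # DDP averages gradients across the WORLD_SIZE processes on backward(); the loss is
    # multiplied by WORLD_SIZE in the forward pass below to undo that averaging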

    # Model attributes
    nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
    hyp['box'] *= 3 / nl  # scale to layers
    hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
    hyp['label_smoothing'] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    last_opt_step = -1
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
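    # mixed precision: GradScaler scales the loss before backward() so FP16 gradients do
    # not underflow, then unscales before each optimizer step (paired with amp.autocast)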
    stopper = EarlyStopping(patience=opt.patience)
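    # early stopping tracks the best fitness seen and requests a stop after opt.patience
    # consecutive epochs without improvement (applied single-GPU only, see below)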
    compute_loss = ComputeLoss(model)  # init loss class
    LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
                f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
                f"Logging results to {colorstr('bold', save_dir)}\n"
                f'Starting training for {epochs} epochs...')
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional, single-GPU only)
        if opt.image_weights:
            cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
            iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
            dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx

        # Update mosaic border (optional)
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(3, device=device)  # mean losses
        if RANK != -1:
            train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(train_loader)
        LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
        if RANK in [-1, 0]:
            pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + gs)) // gs * gs  # size (int() avoids deprecated float args to randrange)
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                if RANK != -1:
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni - last_opt_step >= accumulate:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
                last_opt_step = ni

            # Log
            if RANK in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
                    f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
                callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
        scheduler.step()

        if RANK in [-1, 0]:
            # mAP
            callbacks.run('on_train_epoch_end', epoch=epoch)
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
            final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
            if not noval or final_epoch:  # Calculate mAP
                results, maps, _ = val.run(data_dict,
                                           batch_size=batch_size // WORLD_SIZE * 2,
                                           imgsz=imgsz,
                                           model=ema.ema,
                                           single_cls=single_cls,
                                           dataloader=val_loader,
                                           save_dir=save_dir,
                                           plots=False,
                                           callbacks=callbacks,
                                           compute_loss=compute_loss)

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
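            # fitness() applies default weights [0.0, 0.0, 0.1, 0.9] to [P, R, mAP@.5,
            # mAP@.5-.95], so model selection is driven almost entirely by mAP@.5-.95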
            if fi > best_fitness:
                best_fitness = fi
            log_vals = list(mloss) + list(results) + lr
            callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)

            # Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        'model': deepcopy(de_parallel(model)).half(),
                        'ema': deepcopy(ema.ema).half(),
                        'updates': ema.updates,
                        'optimizer': optimizer.state_dict(),
                        'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
                        'date': datetime.now().isoformat()}
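                # both the raw model and its EMA are stored in FP16 to halve checkpoint
                # size; strip_optimizer() below removes optimizer state from final files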

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
                    torch.save(ckpt, w / f'epoch{epoch}.pt')
                del ckpt
                callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)

            # Stop Single-GPU
            if RANK == -1 and stopper(epoch=epoch, fitness=fi):
                break

            # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576
            # stop = stopper(epoch=epoch, fitness=fi)
            # if RANK == 0:
            #    dist.broadcast_object_list([stop], 0)  # broadcast 'stop' to all ranks

        # Stop DDP
        # with torch_distributed_zero_first(RANK):
        # if stop:
        #    break  # must break all DDP ranks

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training -----------------------------------------------------------------------------------------------------
    if RANK in [-1, 0]:
        LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f'\nValidating {f}...')
                    results, _, _ = val.run(data_dict,
                                            batch_size=batch_size // WORLD_SIZE * 2,
                                            imgsz=imgsz,
                                            model=attempt_load(f, device).half(),
                                            iou_thres=0.65 if is_coco else 0.60,  # best pycocotools results at 0.65
                                            single_cls=single_cls,
                                            dataloader=val_loader,
                                            save_dir=save_dir,
                                            save_json=is_coco,
                                            verbose=True,
                                            plots=True,
                                            callbacks=callbacks,
                                            compute_loss=compute_loss)  # val best model with plots
                    if is_coco:
                        callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)

        callbacks.run('on_train_end', last, best, plots, epoch, results)
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

    torch.cuda.empty_cache()
    return results


def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
    parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--noval', action='store_true', help='only validate final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
    parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    parser.add_argument('--linear-lr', action='store_true', help='linear LR')
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
    parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
    parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
    parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')

    # Weights & Biases arguments
    parser.add_argument('--entity', default=None, help='W&B: Entity')
    parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
    parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
    parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')

    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt


def main(opt, callbacks=Callbacks()):
    # Checks
    if RANK in [-1, 0]:
        print_args(FILE.stem, opt)
        check_git_status()
        check_requirements(exclude=['thop'])

    # Resume
    if opt.resume and not check_wandb_resume(opt) and not opt.evolve:  # resume an interrupted run
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
            opt = argparse.Namespace(**yaml.safe_load(f))  # replace
        opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
        LOGGER.info(f'Resuming training from {ckpt}')
    else:
        opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
            check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        if opt.evolve:
            opt.project = str(ROOT / 'runs/evolve')
            opt.exist_ok, opt.resume = opt.resume, False  # pass resume to exist_ok and disable resume
        opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))

    # DDP mode
    device = select_device(opt.device, batch_size=opt.batch_size)
    if LOCAL_RANK != -1:
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
        assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
        assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
        assert not opt.evolve, '--evolve argument is not compatible with DDP training'
        torch.cuda.set_device(LOCAL_RANK)
        device = torch.device('cuda', LOCAL_RANK)
        dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
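        # NCCL is the preferred backend for GPU collectives; gloo is a CPU-based fallback
        # for PyTorch builds without NCCL support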

    # Train
    if not opt.evolve:
        train(opt.hyp, opt, device, callbacks)
        if WORLD_SIZE > 1 and RANK == 0:
            LOGGER.info('Destroying process group... ')
            dist.destroy_process_group()

    # Evolve hyperparameters (optional)
    else:
        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
        meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                'box': (1, 0.02, 0.2),  # box loss gain
                'cls': (1, 0.2, 4.0),  # cls loss gain
                'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
                'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
                'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
                'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
                'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
                'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
                'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
                'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
                'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                'mixup': (1, 0.0, 1.0),  # image mixup (probability)
                'copy_paste': (1, 0.0, 1.0)}  # segment copy-paste (probability)

        with open(opt.hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
            if 'anchors' not in hyp:  # anchors commented in hyp.yaml
                hyp['anchors'] = 3
        opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir)  # only val/save final epoch
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
        evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
        if opt.bucket:
            os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}')  # download evolve.csv if exists

        for _ in range(opt.evolve):  # generations to evolve
            if evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
                # Select parent(s)
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
                n = min(5, len(x))  # number of previous results to consider
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                w = fitness(x) - fitness(x).min() + 1E-6  # weights (sum > 0)
                if parent == 'single' or len(x) == 1:
                    # x = x[random.randint(0, n - 1)]  # random selection
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate
                mp, s = 0.8, 0.2  # mutation probability, sigma
                npr = np.random
                npr.seed(int(time.time()))
                g = np.array([meta[k][0] for k in hyp.keys()])  # gains 0-1
                ng = len(meta)
                v = np.ones(ng)
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
                    hyp[k] = float(x[i + 7] * v[i])  # mutate
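                    # the first 7 columns of evolve.csv are result metrics (P, R, mAP@.5,
                    # mAP@.5-.95 and the three val losses), so hyp values start at column 7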

            # Constrain to limits
            for k, v in meta.items():
                hyp[k] = max(hyp[k], v[1])  # lower limit
                hyp[k] = min(hyp[k], v[2])  # upper limit
                hyp[k] = round(hyp[k], 5)  # significant digits

            # Train mutation
            results = train(hyp.copy(), opt, device, callbacks)

            # Write mutation results
            print_mutation(results, hyp.copy(), save_dir, opt.bucket)

        # Plot results
        plot_evolve(evolve_csv)
        LOGGER.info(f'Hyperparameter evolution finished\n'
                    f"Results saved to {colorstr('bold', save_dir)}\n"
                    f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}')


def run(**kwargs):
    # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
    opt = parse_opt(True)
    for k, v in kwargs.items():
        setattr(opt, k, v)
    main(opt)


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)