Diff of /train.py [000000] .. [190ca4]

Switch to unified view

a b/train.py
1
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
2
"""
3
Train a YOLOv5 model on a custom dataset.
4
Models and datasets download automatically from the latest YOLOv5 release.
5
6
Usage - Single-GPU training:
7
    $ python train.py --data coco128.yaml --weights yolov5s.pt --img 640  # from pretrained (recommended)
8
    $ python train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640  # from scratch
9
10
Usage - Multi-GPU DDP training:
11
    $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 train.py --data coco128.yaml --weights yolov5s.pt --img 640 --device 0,1,2,3
12
13
Models:     https://github.com/ultralytics/yolov5/tree/master/models
14
Datasets:   https://github.com/ultralytics/yolov5/tree/master/data
15
Tutorial:   https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
16
"""
17
18
import argparse
19
import math
20
import os
21
import random
22
import subprocess
23
import sys
24
import time
25
from copy import deepcopy
26
from datetime import datetime, timedelta
27
from pathlib import Path
28
import torch.nn.functional as F
29
from utils.general import xywh2xyxy,get_fixed_xyxy
30
#from utils import custom_classifierCustomClassifier, train_and_evaluate, evaluate_classifier
31
32
33
try:
34
    import comet_ml  # must be imported before torch (if installed)
35
except ImportError:
36
    comet_ml = None
37
38
import numpy as np
39
import torch
40
import torch.distributed as dist
41
import torch.nn as nn
42
import yaml
43
from torch.optim import lr_scheduler
44
from tqdm import tqdm
45
from torchvision.ops import roi_align
46
from utils.general import get_object_level_feature_maps
47
48
FILE = Path(__file__).resolve()
49
ROOT = FILE.parents[0]  # YOLOv5 root directory
50
if str(ROOT) not in sys.path:
51
    sys.path.append(str(ROOT))  # add ROOT to PATH
52
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
53
54
import val as validate  # for end-of-epoch mAP
55
from models.experimental import attempt_load
56
from models.yolo import Model
57
from utils.autoanchor import check_anchors
58
from utils.autobatch import check_train_batch_size
59
from utils.callbacks import Callbacks
60
from utils.dataloaders import create_dataloader
61
from utils.downloads import attempt_download, is_url
62
from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_info,
63
                           check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr,
64
                           get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
65
                           labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer,
66
                           yaml_save,plot_multi_channel_feature_map_with_boxes,xywh_to_xyxy)
67
from utils.loggers import LOGGERS, Loggers
68
from utils.loggers.comet.comet_utils import check_comet_resume
69
from utils.loss import ComputeLoss
70
from utils.metrics import fitness
71
from utils.plots import plot_evolve
72
from utils.custom_classifier import CustomClassifier, train_model_once
73
from utils.my_model import MyCNN,cell_training
74
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
75
                               smart_resume, torch_distributed_zero_first)
76
77
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
78
RANK = int(os.getenv('RANK', -1))
79
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
80
GIT_INFO = check_git_info()
81
82
83
84
def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
85
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
86
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
87
        opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
88
    callbacks.run('on_pretrain_routine_start')
89
90
91
    cell_attribute_model = MyCNN(num_classes=12, dropout_prob=0.5, in_channels=480).to(device)
92
    # cell_attribute_model.load_state_dict(torch.load('Attribute_model/best_weights_0.8056662588308221_51.pth'))
93
    #cell_attribute_model.train() 
94
    
95
    #step_size = 5
96
   # gamma = 0.01
97
   # scheduler_cell_model = lr_scheduler.StepLR(optimizer_cell_model, step_size=step_size, gamma=gamma)
98
99
    # Directories
100
    w = save_dir / 'weights'  # weights dir
101
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
102
    last, best = w / 'last.pt', w / 'best.pt'
103
104
    # Hyperparameters
105
    if isinstance(hyp, str):
106
        with open(hyp, errors='ignore') as f:
107
            hyp = yaml.safe_load(f)  # load hyps dict
108
    LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
109
    opt.hyp = hyp.copy()  # for saving hyps to checkpoints
110
111
    # Save run settings
112
    if not evolve:
113
        yaml_save(save_dir / 'hyp.yaml', hyp)
114
        yaml_save(save_dir / 'opt.yaml', vars(opt))
115
116
    # Loggers
117
    data_dict = None
118
    if RANK in {-1, 0}:
119
        include_loggers = list(LOGGERS)
120
        if getattr(opt, 'ndjson_console', False):
121
            include_loggers.append('ndjson_console')
122
        if getattr(opt, 'ndjson_file', False):
123
            include_loggers.append('ndjson_file')
124
125
        loggers = Loggers(
126
            save_dir=save_dir,
127
            weights=weights,
128
            opt=opt,
129
            hyp=hyp,
130
            logger=LOGGER,
131
            include=tuple(include_loggers),
132
        )
133
134
        # Register actions
135
        for k in methods(loggers):
136
            callbacks.register_action(k, callback=getattr(loggers, k))
137
138
        # Process custom dataset artifact link
139
        data_dict = loggers.remote_dataset
140
        if resume:  # If resuming runs from remote artifact
141
            weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size
142
143
    # Config
144
    plots = not evolve and not opt.noplots  # create plots
145
    cuda = device.type != 'cpu'
146
    init_seeds(opt.seed + 1 + RANK, deterministic=True)
147
    with torch_distributed_zero_first(LOCAL_RANK):
148
        data_dict = data_dict or check_dataset(data)  # check if None
149
    train_path, val_path = data_dict['train'], data_dict['val']
150
    nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
151
    names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
152
    is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt')  # COCO dataset
153
154
    # Model
155
    check_suffix(weights, '.pt')  # check weights
156
    pretrained = weights.endswith('.pt')
157
    if pretrained:
158
        with torch_distributed_zero_first(LOCAL_RANK):
159
            weights = attempt_download(weights)  # download if not found locally
160
        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
161
        model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
162
        exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
163
        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
164
        csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
165
        model.load_state_dict(csd, strict=False)  # load
166
        LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
167
    else:
168
        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
169
    amp = check_amp(model)  # check AMP
170
171
    # Freeze
172
    freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
173
    for k, v in model.named_parameters():
174
        v.requires_grad = True  # train all layers
175
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
176
        if any(x in k for x in freeze):
177
            LOGGER.info(f'freezing {k}')
178
            v.requires_grad = False
179
180
    # Image size
181
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
182
    imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple
183
184
    # Batch size
185
    if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
186
        batch_size = check_train_batch_size(model, imgsz, amp)
187
        loggers.on_params_update({'batch_size': batch_size})
188
189
    # Optimizer
190
    nbs = 64  # nominal batch size
191
    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
192
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
193
    optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])
194
    #optimizer_cell_model = torch.optim.Adam(cell_attribute_model.parameters(), opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])
195
    optimizer_cell_model = torch.optim.SGD(cell_attribute_model.parameters(), lr=hyp['lr0'],momentum= hyp['momentum'], weight_decay=hyp['weight_decay'])
196
197
    # Scheduler
198
    if opt.cos_lr:
199
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
200
    else:
201
        lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
202
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
203
    scheduler_cell_model = lr_scheduler.LambdaLR(optimizer_cell_model, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)
204
205
    
206
207
    # EMA
208
    ema = ModelEMA(model) if RANK in {-1, 0} else None
209
210
    # Resume
211
    best_fitness, start_epoch = 0.0, 0
212
    if pretrained:
213
        if resume:
214
            best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
215
        del ckpt, csd
216
217
    # DP mode
218
    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
219
        LOGGER.warning(
220
            'WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
221
            'See Multi-GPU Tutorial at https://docs.ultralytics.com/yolov5/tutorials/multi_gpu_training to get started.'
222
        )
223
        model = torch.nn.DataParallel(model)
224
225
    # SyncBatchNorm
226
    if opt.sync_bn and cuda and RANK != -1:
227
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
228
        LOGGER.info('Using SyncBatchNorm()')
229
230
    # Trainloader
231
    train_loader, dataset = create_dataloader(train_path,
232
                                              imgsz,
233
                                              batch_size // WORLD_SIZE,
234
                                              gs,
235
                                              single_cls,
236
                                              hyp=hyp,
237
                                              augment=True,
238
                                              cache=None if opt.cache == 'val' else opt.cache,
239
                                              rect=opt.rect,
240
                                              rank=LOCAL_RANK,
241
                                              workers=workers,
242
                                              image_weights=opt.image_weights,
243
                                              quad=opt.quad,
244
                                              prefix=colorstr('train: '),
245
                                              shuffle=True,
246
                                              seed=opt.seed)
247
    labels = np.concatenate(dataset.labels, 0)
248
    mlc = int(labels[:, 0].max())  # max label class
249
    assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
250
251
    # Process 0
252
    if RANK in {-1, 0}:
253
        val_loader = create_dataloader(val_path,
254
                                       imgsz,
255
                                       batch_size // WORLD_SIZE * 2,
256
                                       gs,
257
                                       single_cls,
258
                                       hyp=hyp,
259
                                       cache=None if noval else opt.cache,
260
                                       rect=True,
261
                                       rank=-1,
262
                                       workers=workers * 2,
263
                                       pad=0.5,
264
                                       prefix=colorstr('val: '))[0]
265
266
        if not resume:
267
            if not opt.noautoanchor:
268
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)  # run AutoAnchor
269
            model.half().float()  # pre-reduce anchor precision
270
271
        callbacks.run('on_pretrain_routine_end', labels, names)
272
273
    # DDP mode
274
    if cuda and RANK != -1:
275
        model = smart_DDP(model)
276
277
    # Model attributes
278
    nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
279
    hyp['box'] *= 3 / nl  # scale to layers
280
    hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
281
    hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
282
    hyp['label_smoothing'] = opt.label_smoothing
283
    model.nc = nc  # attach number of classes to model
284
    model.hyp = hyp  # attach hyperparameters to model
285
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
286
    model.names = names
287
288
    # Start training
289
    t0 = time.time()
290
    nb = len(train_loader)  # number of batches
291
    nw = max(round(hyp['warmup_epochs'] * nb), 100)  # number of warmup iterations, max(3 epochs, 100 iterations)
292
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
293
    last_opt_step = -1
294
    maps = np.zeros(nc)  # mAP per class
295
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
296
    scheduler.last_epoch = start_epoch - 1  # do not move
297
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
298
    stopper, stop = EarlyStopping(patience=opt.patience), False
299
    compute_loss = ComputeLoss(model)  # init loss class
300
    callbacks.run('on_train_start')
301
    LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
302
                f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
303
                f"Logging results to {colorstr('bold', save_dir)}\n"
304
                f'Starting training for {epochs} epochs...')
305
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
306
        callbacks.run('on_train_epoch_start')
307
        model.train()
308
        cell_attribute_model.train() 
309
310
        # Update image weights (optional, single-GPU only)
311
        if opt.image_weights:
312
            cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
313
            iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
314
            dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
315
316
        # Update mosaic border (optional)
317
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
318
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
319
320
        mloss = torch.zeros(3, device=device)  # mean losses
321
        if RANK != -1:
322
            train_loader.sampler.set_epoch(epoch)
323
        pbar = enumerate(train_loader)
324
        LOGGER.info(('\n' + '%11s' * 8) % ('Epoch', 'GPU_mem', 'box_loss', 'obj_loss', 'cls_loss', 'attr_loss', 'Instances', 'Size'))
325
        if RANK in {-1, 0}:
326
            pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
327
        optimizer.zero_grad()
328
        avg_attribute_loss= 0
329
        length_of_data=0
330
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
331
            
332
            callbacks.run('on_train_batch_start')
333
            ni = i + nb * epoch  # number integrated batches (since train start)
334
            imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0
335
336
            # Warmup
337
            if ni <= nw:
338
                xi = [0, nw]  # x interp
339
                # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
340
                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
341
                for j, x in enumerate(optimizer.param_groups):
342
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
343
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
344
                    if 'momentum' in x:
345
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
346
347
            # Multi-scale
348
            if opt.multi_scale:
349
                sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + gs) // gs * gs  # size
350
                sf = sz / max(imgs.shape[2:])  # scale factor
351
                if sf != 1:
352
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
353
                    imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
354
355
            # Forward
356
            with torch.cuda.amp.autocast(amp):
357
                pred,int_feat = model(imgs)  # forward
358
                
359
                #batch_obj = 0
360
                Num_targets = len(targets)
361
                pooled_feature_map_batch = []
362
                optimizer_cell_model.zero_grad()  
363
364
                for i in range(Num_targets):
365
                    img_num = int(targets[i,0].item())
366
367
                    p2_feature_map =int_feat[0][img_num] # imgs[img_num] 
368
                    p3_feature_map = int_feat[1][img_num]
369
370
                    x_center = targets[i, 2]
371
                    y_center = targets[i, 3]
372
                    width = targets[i, 4]
373
                    height = targets[i, 5]
374
                    bb = [round(x_center.item(),4), round(y_center.item(),4), round(width.item(),4), round(height.item(),4)]
375
                    p2_feature_shape_tensor = torch.tensor([int_feat[0][img_num].shape[1], int_feat[0][img_num].shape[2],int_feat[0][img_num].shape[1],int_feat[0][img_num].shape[2]])                        # reduce_channels_layer = torch.nn.Conv2d(1280, 250, kernel_size=1).to(device)
376
                    p3_feature_shape_tensor = torch.tensor([int_feat[1][img_num].shape[1], int_feat[1][img_num].shape[2],int_feat[1][img_num].shape[1],int_feat[1][img_num].shape[2]])                        # reduce_channels_layer = torch.nn.Conv2d(1280, 250, kernel_size=1).to(device)
377
                        # reduce_channels_layer = torch.nn.Conv2d(1280, 250, kernel_size=1).to(device)
378
379
                    p2_normalized_xyxy = xywh_to_xyxy(bb)*p2_feature_shape_tensor #imgs.shape[2]
380
                    p3_normalized_xyxy = xywh_to_xyxy(bb)*p3_feature_shape_tensor #imgs.shape[2]
381
382
383
                    p2_x_min, p2_y_min, p2_x_max, p2_y_max = get_fixed_xyxy(p2_normalized_xyxy,p2_feature_map)
384
                    p3_x_min, p3_y_min, p3_x_max, p3_y_max = get_fixed_xyxy(p3_normalized_xyxy,p3_feature_map)
385
    
386
                    batch_index = torch.tensor([0], dtype=torch.float32).to(device)
387
388
                    p2_roi = torch.tensor([p2_x_min, p2_y_min, p2_x_max, p2_y_max], device=device).float() 
389
                    p3_roi = torch.tensor([p3_x_min, p3_y_min, p3_x_max, p3_y_max], device=device).float() 
390
391
392
                    # Concatenate the batch index to the bounding box coordinates
393
                    p2_roi_with_batch_index = torch.cat([batch_index, p2_roi])
394
                    p3_roi_with_batch_index = torch.cat([batch_index, p3_roi])
395
396
                    # relevant_feature_map = p3_feature_map.unsqueeze(0)[:, :, y_min:y_max, x_min:x_max]
397
                    p2_resized_object = roi_align(p2_feature_map.unsqueeze(0), p2_roi_with_batch_index.unsqueeze(0).to(device), output_size=(24, 30))
398
                    p3_resized_object = roi_align(p3_feature_map.unsqueeze(0), p3_roi_with_batch_index.unsqueeze(0).to(device), output_size=(24, 30))
399
                    concat_box = torch.cat([p2_resized_object,p3_resized_object],dim=1)
400
401
                    
402
                    pooled_feature_map_batch.append(concat_box)
403
                cell_attribute_loss= cell_training(cell_attribute_model,pooled_feature_map_batch, targets[:,6:13].to(device))
404
                    # del concatenated_features
405
                cell_attribute_loss.backward(retain_graph=True)
406
                optimizer_cell_model.step()
407
                
408
                avg_attribute_loss+=cell_attribute_loss.item()
409
                length_of_data+=1
410
   
411
412
                loss, loss_items = compute_loss(pred, targets[:,0:6].to(device))  # loss scaled by batch_size I changed here
413
                if RANK != -1:
414
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
415
                if opt.quad:
416
                    loss *= 4.
417
418
            # Backward
419
            scaler.scale(loss).backward()
420
421
            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
422
            if ni - last_opt_step >= accumulate:
423
                scaler.unscale_(optimizer)  # unscale gradients
424
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
425
                scaler.step(optimizer)  # optimizer.step
426
                scaler.update()
427
                optimizer.zero_grad()
428
                if ema:
429
                    ema.update(model)
430
                last_opt_step = ni
431
432
            # Log
433
            if RANK in {-1, 0}:
434
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
435
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
436
                avg_attr_loss = avg_attribute_loss / length_of_data  # Calculate the average attribute loss
437
438
                pbar.set_description(('%11s' * 2 + '%11.4g' * 6) %
439
                                     (f'{epoch}/{epochs - 1}', mem, *mloss,avg_attr_loss, targets.shape[0], imgs.shape[-1]))
440
                callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths, list(mloss))
441
                if callbacks.stop_training:
442
                    return
443
            # end batch ------------------------------------------------------------------------------------------------
444
        # print("Attribute_average_loss=   ",avg_attribute_loss/length_of_data)
445
        # Scheduler
446
        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
447
        scheduler.step()
448
        scheduler_cell_model.step()
449
450
451
        if RANK in {-1, 0}:
452
        #   if epoch > 50:
453
            # mAP
454
            callbacks.run('on_train_epoch_end', epoch=epoch)
455
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
456
            final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
457
            if not noval or final_epoch:  # Calculate mAP
458
                results, maps, _ = validate.run(data_dict,cell_attribute_model,
459
                                                batch_size=1,# batch_size // WORLD_SIZE * 2,
460
                                                imgsz=imgsz,
461
                                                half=amp,
462
                                                model=ema.ema,
463
                                                single_cls=single_cls,
464
                                                dataloader=val_loader,
465
                                                save_dir=save_dir,
466
                                                plots=False,
467
                                                callbacks=callbacks,
468
                                                compute_loss=compute_loss
469
                                                )
470
471
            # Update best mAP
472
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
473
            stop = stopper(epoch=epoch, fitness=fi)  # early stop check
474
            if fi > best_fitness:
475
                best_fitness = fi
476
            log_vals = list(mloss) + list(results) + lr
477
            callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
478
479
            # Save model
480
            if (not nosave) or (final_epoch and not evolve):  # if save
481
                ckpt = {
482
                    'epoch': epoch,
483
                    'best_fitness': best_fitness,
484
                    'model': deepcopy(de_parallel(model)).half(),
485
                    'ema': deepcopy(ema.ema).half(),
486
                    'updates': ema.updates,
487
                    'optimizer': optimizer.state_dict(),
488
                    'opt': vars(opt),
489
                    'git': GIT_INFO,  # {remote, branch, commit} if a git repo
490
                    'date': datetime.now().isoformat()}
491
492
                # Save last, best and delete
493
                torch.save(ckpt, last)
494
                if best_fitness == fi:
495
                    torch.save(ckpt, best)
496
                if opt.save_period > 0 and epoch % opt.save_period == 0:
497
                    torch.save(ckpt, w / f'epoch{epoch}.pt')
498
                del ckpt
499
                callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
500
501
        # EarlyStopping
502
        if RANK != -1:  # if DDP training
503
            broadcast_list = [stop if RANK == 0 else None]
504
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
505
            if RANK != 0:
506
                stop = broadcast_list[0]
507
        if stop:
508
            break  # must break all DDP ranks
509
510
        # end epoch ----------------------------------------------------------------------------------------------------
511
    # end training -----------------------------------------------------------------------------------------------------
512
    if RANK in {-1, 0}:
513
        LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
514
        for f in last, best:
515
            if f.exists():
516
                strip_optimizer(f)  # strip optimizers
517
                if f is best:
518
                    LOGGER.info(f'\nValidating {f}...')
519
                    results, _, _ = validate.run(
520
                        data_dict, cell_attribute_model,
521
                        batch_size=batch_size // WORLD_SIZE * 2,
522
                        imgsz=imgsz,
523
                        model=attempt_load(f, device).half(),
524
                        iou_thres=0.65 if is_coco else 0.60,  # best pycocotools at iou 0.65
525
                        single_cls=single_cls,
526
                        dataloader=val_loader,
527
                        save_dir=save_dir,
528
                        save_json=is_coco,
529
                        verbose=True,
530
                        plots=plots,
531
                        callbacks=callbacks,
532
                        compute_loss=compute_loss)  # val best model with plots
533
                    if is_coco:
534
                        callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
535
536
        callbacks.run('on_train_end', last, best, epoch, results)
537
538
    torch.cuda.empty_cache()
539
    return results
540
541
542
def parse_opt(known=False):
543
    parser = argparse.ArgumentParser()
544
    parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
545
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
546
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
547
    parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
548
    parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
549
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
550
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
551
    parser.add_argument('--rect', action='store_true', help='rectangular training')
552
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
553
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
554
    parser.add_argument('--noval', action='store_true', help='only validate final epoch')
555
    parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
556
    parser.add_argument('--noplots', action='store_true', help='save no plot files')
557
    parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
558
    parser.add_argument('--evolve_population',
559
                        type=str,
560
                        default=ROOT / 'data/hyps',
561
                        help='location for loading population')
562
    parser.add_argument('--resume_evolve', type=str, default=None, help='resume evolve from last generation')
563
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
564
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
565
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
566
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
567
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
568
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
569
    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
570
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
571
    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
572
    parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
573
    parser.add_argument('--name', default='exp', help='save to project/name')
574
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
575
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
576
    parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
577
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
578
    parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
579
    parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
580
    parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
581
    parser.add_argument('--seed', type=int, default=0, help='Global training seed')
582
    parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
583
584
    # Logger arguments
585
    parser.add_argument('--entity', default=None, help='Entity')
586
    parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='Upload data, "val" option')
587
    parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval')
588
    parser.add_argument('--artifact_alias', type=str, default='latest', help='Version of dataset artifact to use')
589
590
    # NDJSON logging
591
    parser.add_argument('--ndjson-console', action='store_true', help='Log ndjson to console')
592
    parser.add_argument('--ndjson-file', action='store_true', help='Log ndjson to file')
593
594
    return parser.parse_known_args()[0] if known else parser.parse_args()
595
596
597
def main(opt, callbacks=Callbacks()):
598
    # Checks
599
    if RANK in {-1, 0}:
600
        print_args(vars(opt))
601
        check_git_status()
602
        check_requirements(ROOT / 'requirements.txt')
603
604
    # Resume (from specified or most recent last.pt)
605
    if opt.resume and not check_comet_resume(opt) and not opt.evolve:
606
        last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
607
        opt_yaml = last.parent.parent / 'opt.yaml'  # train options yaml
608
        opt_data = opt.data  # original dataset
609
        if opt_yaml.is_file():
610
            with open(opt_yaml, errors='ignore') as f:
611
                d = yaml.safe_load(f)
612
        else:
613
            d = torch.load(last, map_location='cpu')['opt']
614
        opt = argparse.Namespace(**d)  # replace
615
        opt.cfg, opt.weights, opt.resume = '', str(last), True  # reinstate
616
        if is_url(opt_data):
617
            opt.data = check_file(opt_data)  # avoid HUB resume auth timeout
618
    else:
619
        opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
620
            check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
621
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
622
        if opt.evolve:
623
            if opt.project == str(ROOT / 'runs/train'):  # if default project name, rename to runs/evolve
624
                opt.project = str(ROOT / 'runs/evolve')
625
            opt.exist_ok, opt.resume = opt.resume, False  # pass resume to exist_ok and disable resume
626
        if opt.name == 'cfg':
627
            opt.name = Path(opt.cfg).stem  # use model.yaml as name
628
        opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
629
630
    # DDP mode
631
    device = select_device(opt.device, batch_size=opt.batch_size)
632
    if LOCAL_RANK != -1:
633
        msg = 'is not compatible with YOLOv5 Multi-GPU DDP training'
634
        assert not opt.image_weights, f'--image-weights {msg}'
635
        assert not opt.evolve, f'--evolve {msg}'
636
        assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
637
        assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
638
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
639
        torch.cuda.set_device(LOCAL_RANK)
640
        device = torch.device('cuda', LOCAL_RANK)
641
        dist.init_process_group(backend='nccl' if dist.is_nccl_available() else 'gloo',
642
                                timeout=timedelta(seconds=10800))
643
644
    # Train
645
    if not opt.evolve:
646
        train(opt.hyp, opt, device, callbacks)
647
648
    # Evolve hyperparameters (optional)
649
    else:
650
        # Hyperparameter evolution metadata (including this hyperparameter True-False, lower_limit, upper_limit)
651
        meta = {
652
            'lr0': (False, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
653
            'lrf': (False, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
654
            'momentum': (False, 0.6, 0.98),  # SGD momentum/Adam beta1
655
            'weight_decay': (False, 0.0, 0.001),  # optimizer weight decay
656
            'warmup_epochs': (False, 0.0, 5.0),  # warmup epochs (fractions ok)
657
            'warmup_momentum': (False, 0.0, 0.95),  # warmup initial momentum
658
            'warmup_bias_lr': (False, 0.0, 0.2),  # warmup initial bias lr
659
            'box': (False, 0.02, 0.2),  # box loss gain
660
            'cls': (False, 0.2, 4.0),  # cls loss gain
661
            'cls_pw': (False, 0.5, 2.0),  # cls BCELoss positive_weight
662
            'obj': (False, 0.2, 4.0),  # obj loss gain (scale with pixels)
663
            'obj_pw': (False, 0.5, 2.0),  # obj BCELoss positive_weight
664
            'iou_t': (False, 0.1, 0.7),  # IoU training threshold
665
            'anchor_t': (False, 2.0, 8.0),  # anchor-multiple threshold
666
            'anchors': (False, 2.0, 10.0),  # anchors per output grid (0 to ignore)
667
            'fl_gamma': (False, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
668
            'hsv_h': (True, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
669
            'hsv_s': (True, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
670
            'hsv_v': (True, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
671
            'degrees': (True, 0.0, 45.0),  # image rotation (+/- deg)
672
            'translate': (True, 0.0, 0.9),  # image translation (+/- fraction)
673
            'scale': (True, 0.0, 0.9),  # image scale (+/- gain)
674
            'shear': (True, 0.0, 10.0),  # image shear (+/- deg)
675
            'perspective': (True, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
676
            'flipud': (True, 0.0, 1.0),  # image flip up-down (probability)
677
            'fliplr': (True, 0.0, 1.0),  # image flip left-right (probability)
678
            'mosaic': (True, 0.0, 1.0),  # image mixup (probability)
679
            'mixup': (True, 0.0, 1.0),  # image mixup (probability)
680
            'copy_paste': (True, 0.0, 1.0)}  # segment copy-paste (probability)
681
682
        # GA configs
683
        pop_size = 50
684
        mutation_rate_min = 0.01
685
        mutation_rate_max = 0.5
686
        crossover_rate_min = 0.5
687
        crossover_rate_max = 1
688
        min_elite_size = 2
689
        max_elite_size = 5
690
        tournament_size_min = 2
691
        tournament_size_max = 10
692
693
        with open(opt.hyp, errors='ignore') as f:
694
            hyp = yaml.safe_load(f)  # load hyps dict
695
            if 'anchors' not in hyp:  # anchors commented in hyp.yaml
696
                hyp['anchors'] = 3
697
        if opt.noautoanchor:
698
            del hyp['anchors'], meta['anchors']
699
        opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir)  # only val/save final epoch
700
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
701
        evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
702
        if opt.bucket:
703
            # download evolve.csv if exists
704
            subprocess.run([
705
                'gsutil',
706
                'cp',
707
                f'gs://{opt.bucket}/evolve.csv',
708
                str(evolve_csv), ])
709
710
        # Delete the items in meta dictionary whose first value is False
711
        del_ = []
712
        for item in meta.keys():
713
            if meta[item][0] is False:
714
                del_.append(item)
715
        hyp_GA = hyp.copy()  # Make a copy of hyp dictionary
716
        for item in del_:
717
            del meta[item]  # Remove the item from meta dictionary
718
            del hyp_GA[item]  # Remove the item from hyp_GA dictionary
719
720
        # Set lower_limit and upper_limit arrays to hold the search space boundaries
721
        lower_limit = np.array([meta[k][1] for k in hyp_GA.keys()])
722
        upper_limit = np.array([meta[k][2] for k in hyp_GA.keys()])
723
724
        # Create gene_ranges list to hold the range of values for each gene in the population
725
        gene_ranges = []
726
        for i in range(len(upper_limit)):
727
            gene_ranges.append((lower_limit[i], upper_limit[i]))
728
729
        # Initialize the population with initial_values or random values
730
        initial_values = []
731
732
        # If resuming evolution from a previous checkpoint
733
        if opt.resume_evolve is not None:
734
            assert os.path.isfile(ROOT / opt.resume_evolve), 'evolve population path is wrong!'
735
            with open(ROOT / opt.resume_evolve, errors='ignore') as f:
736
                evolve_population = yaml.safe_load(f)
737
                for value in evolve_population.values():
738
                    value = np.array([value[k] for k in hyp_GA.keys()])
739
                    initial_values.append(list(value))
740
741
        # If not resuming from a previous checkpoint, generate initial values from .yaml files in opt.evolve_population
742
        else:
743
            yaml_files = [f for f in os.listdir(opt.evolve_population) if f.endswith('.yaml')]
744
            for file_name in yaml_files:
745
                with open(os.path.join(opt.evolve_population, file_name)) as yaml_file:
746
                    value = yaml.safe_load(yaml_file)
747
                    value = np.array([value[k] for k in hyp_GA.keys()])
748
                    initial_values.append(list(value))
749
750
        # Generate random values within the search space for the rest of the population
751
        if (initial_values is None):
752
            population = [generate_individual(gene_ranges, len(hyp_GA)) for i in range(pop_size)]
753
        else:
754
            if (pop_size > 1):
755
                population = [
756
                    generate_individual(gene_ranges, len(hyp_GA)) for i in range(pop_size - len(initial_values))]
757
                for initial_value in initial_values:
758
                    population = [initial_value] + population
759
760
        # Run the genetic algorithm for a fixed number of generations
761
        list_keys = list(hyp_GA.keys())
762
        for generation in range(opt.evolve):
763
            if (generation >= 1):
764
                save_dict = {}
765
                for i in range(len(population)):
766
                    little_dict = {}
767
                    for j in range(len(population[i])):
768
                        little_dict[list_keys[j]] = float(population[i][j])
769
                    save_dict['gen' + str(generation) + 'number' + str(i)] = little_dict
770
771
                with open(save_dir / 'evolve_population.yaml', 'w') as outfile:
772
                    yaml.dump(save_dict, outfile, default_flow_style=False)
773
774
            # Adaptive elite size
775
            elite_size = min_elite_size + int((max_elite_size - min_elite_size) * (generation / opt.evolve))
776
            # Evaluate the fitness of each individual in the population
777
            fitness_scores = []
778
            for individual in population:
779
                for key, value in zip(hyp_GA.keys(), individual):
780
                    hyp_GA[key] = value
781
                hyp.update(hyp_GA)
782
                results = train(hyp.copy(), opt, device, callbacks)
783
                callbacks = Callbacks()
784
                # Write mutation results
785
                keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
786
                        'val/box_loss', 'val/obj_loss', 'val/cls_loss')
787
                print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)
788
                fitness_scores.append(results[2])
789
790
            # Select the fittest individuals for reproduction using adaptive tournament selection
791
            selected_indices = []
792
            for i in range(pop_size - elite_size):
793
                # Adaptive tournament size
794
                tournament_size = max(max(2, tournament_size_min),
795
                                      int(min(tournament_size_max, pop_size) - (generation / (opt.evolve / 10))))
796
                # Perform tournament selection to choose the best individual
797
                tournament_indices = random.sample(range(pop_size), tournament_size)
798
                tournament_fitness = [fitness_scores[j] for j in tournament_indices]
799
                winner_index = tournament_indices[tournament_fitness.index(max(tournament_fitness))]
800
                selected_indices.append(winner_index)
801
802
            # Add the elite individuals to the selected indices
803
            elite_indices = [i for i in range(pop_size) if fitness_scores[i] in sorted(fitness_scores)[-elite_size:]]
804
            selected_indices.extend(elite_indices)
805
            # Create the next generation through crossover and mutation
806
            next_generation = []
807
            for i in range(pop_size):
808
                parent1_index = selected_indices[random.randint(0, pop_size - 1)]
809
                parent2_index = selected_indices[random.randint(0, pop_size - 1)]
810
                # Adaptive crossover rate
811
                crossover_rate = max(crossover_rate_min,
812
                                     min(crossover_rate_max, crossover_rate_max - (generation / opt.evolve)))
813
                if random.uniform(0, 1) < crossover_rate:
814
                    crossover_point = random.randint(1, len(hyp_GA) - 1)
815
                    child = population[parent1_index][:crossover_point] + population[parent2_index][crossover_point:]
816
                else:
817
                    child = population[parent1_index]
818
                # Adaptive mutation rate
819
                mutation_rate = max(mutation_rate_min,
820
                                    min(mutation_rate_max, mutation_rate_max - (generation / opt.evolve)))
821
                for j in range(len(hyp_GA)):
822
                    if random.uniform(0, 1) < mutation_rate:
823
                        child[j] += random.uniform(-0.1, 0.1)
824
                        child[j] = min(max(child[j], gene_ranges[j][0]), gene_ranges[j][1])
825
                next_generation.append(child)
826
            # Replace the old population with the new generation
827
            population = next_generation
828
        # Print the best solution found
829
        best_index = fitness_scores.index(max(fitness_scores))
830
        best_individual = population[best_index]
831
        print('Best solution found:', best_individual)
832
        # Plot results
833
        plot_evolve(evolve_csv)
834
        LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
835
                    f"Results saved to {colorstr('bold', save_dir)}\n"
836
                    f'Usage example: $ python train.py --hyp {evolve_yaml}')
837
838
839
def generate_individual(input_ranges, individual_length):
840
    individual = []
841
    for i in range(individual_length):
842
        lower_bound, upper_bound = input_ranges[i]
843
        individual.append(random.uniform(lower_bound, upper_bound))
844
    return individual
845
846
847
def run(**kwargs):
848
    # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
849
    opt = parse_opt(True)
850
    for k, v in kwargs.items():
851
        setattr(opt, k, v)
852
    main(opt)
853
    return opt
854
855
856
if __name__ == '__main__':
857
    opt = parse_opt()
858
    main(opt)