Diff of /classify/train.py [000000] .. [190ca4]


b/classify/train.py
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Train a YOLOv5 classifier model on a classification dataset

Usage - Single-GPU training:
    $ python classify/train.py --model yolov5s-cls.pt --data imagenette160 --epochs 5 --img 224

Usage - Multi-GPU DDP training:
    $ python -m torch.distributed.run --nproc_per_node 4 --master_port 2022 classify/train.py --model yolov5s-cls.pt --data imagenet --epochs 5 --img 224 --device 0,1,2,3

Datasets:           --data mnist, fashion-mnist, cifar10, cifar100, imagenette, imagewoof, imagenet, or 'path/to/data'
YOLOv5-cls models:  --model yolov5n-cls.pt, yolov5s-cls.pt, yolov5m-cls.pt, yolov5l-cls.pt, yolov5x-cls.pt
Torchvision models: --model resnet50, efficientnet_b0, etc. See https://pytorch.org/vision/stable/models.html
"""

import argparse
import os
import subprocess
import sys
import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import torch
import torch.distributed as dist
import torch.hub as hub
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
from torch.cuda import amp
from tqdm import tqdm

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from classify import val as validate
from models.experimental import attempt_load
from models.yolo import ClassificationModel, DetectionModel
from utils.dataloaders import create_classification_dataloader
from utils.general import (DATASETS_DIR, LOGGER, TQDM_BAR_FORMAT, WorkingDirectory, check_git_info, check_git_status,
                           check_requirements, colorstr, download, increment_path, init_seeds, print_args, yaml_save)
from utils.loggers import GenericLogger
from utils.plots import imshow_cls
from utils.torch_utils import (ModelEMA, de_parallel, model_info, reshape_classifier_output, select_device, smart_DDP,
                               smart_optimizer, smartCrossEntropyLoss, torch_distributed_zero_first)

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
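# Note: these env vars are set by torch.distributed.run. RANK is the global process index, LOCAL_RANK the
# GPU index on the current node, and WORLD_SIZE the total process count; the defaults above (-1, -1, 1)
# correspond to plain single-process (single-GPU or CPU) training.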
GIT_INFO = check_git_info()


def train(opt, device):
    init_seeds(opt.seed + 1 + RANK, deterministic=True)
    save_dir, data, bs, epochs, nw, imgsz, pretrained = \
        opt.save_dir, Path(opt.data), opt.batch_size, opt.epochs, min(os.cpu_count() - 1, opt.workers), \
        opt.imgsz, str(opt.pretrained).lower() == 'true'
    cuda = device.type != 'cpu'

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last, best = wdir / 'last.pt', wdir / 'best.pt'

    # Save run settings
    yaml_save(save_dir / 'opt.yaml', vars(opt))

    # Logger
    logger = GenericLogger(opt=opt, console_logger=LOGGER) if RANK in {-1, 0} else None

    # Download Dataset
    with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT):
        data_dir = data if data.is_dir() else (DATASETS_DIR / data)
        if not data_dir.is_dir():
            LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
            t = time.time()
            if str(data) == 'imagenet':
                subprocess.run(['bash', str(ROOT / 'data/scripts/get_imagenet.sh')], check=True)
            else:
                url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{data}.zip'
                download(url, dir=data_dir.parent)
            s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
            LOGGER.info(s)

    # Dataloaders
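    # Expected layout (inferred from the directory scan below): data_dir/train/<class>/<images> plus a
    # matching data_dir/test/ or data_dir/val/ split; nc is the number of train/ subfolders. Note that
    # --batch-size is the total over all GPUs, so each rank loads bs // WORLD_SIZE images per step.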
    nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
    trainloader = create_classification_dataloader(path=data_dir / 'train',
                                                   imgsz=imgsz,
                                                   batch_size=bs // WORLD_SIZE,
                                                   augment=True,
                                                   cache=opt.cache,
                                                   rank=LOCAL_RANK,
                                                   workers=nw)

    test_dir = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val'  # data/test or data/val
    if RANK in {-1, 0}:
        testloader = create_classification_dataloader(path=test_dir,
                                                      imgsz=imgsz,
                                                      batch_size=bs // WORLD_SIZE * 2,
                                                      augment=False,
                                                      cache=opt.cache,
                                                      rank=-1,
                                                      workers=nw)

    # Model
    with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT):
        if Path(opt.model).is_file() or opt.model.endswith('.pt'):
            model = attempt_load(opt.model, device='cpu', fuse=False)
        elif opt.model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
            model = torchvision.models.__dict__[opt.model](weights='IMAGENET1K_V1' if pretrained else None)
        else:
            m = hub.list('ultralytics/yolov5')  # + hub.list('pytorch/vision')  # models
            raise ModuleNotFoundError(f'--model {opt.model} not found. Available models are: \n' + '\n'.join(m))
        if isinstance(model, DetectionModel):
            LOGGER.warning("WARNING ⚠️ pass YOLOv5 classifier model with '-cls' suffix, i.e. '--model yolov5s-cls.pt'")
            model = ClassificationModel(model=model, nc=nc, cutoff=opt.cutoff or 10)  # convert to classification model
        reshape_classifier_output(model, nc)  # update class count
    for m in model.modules():
        if not pretrained and hasattr(m, 'reset_parameters'):
            m.reset_parameters()
        if isinstance(m, torch.nn.Dropout) and opt.dropout is not None:
            m.p = opt.dropout  # set dropout
    for p in model.parameters():
        p.requires_grad = True  # for training
    model = model.to(device)

    # Info
    if RANK in {-1, 0}:
        model.names = trainloader.dataset.classes  # attach class names
        model.transforms = testloader.dataset.torch_transforms  # attach inference transforms
        model_info(model)
        if opt.verbose:
            LOGGER.info(model)
        images, labels = next(iter(trainloader))
        file = imshow_cls(images[:25], labels[:25], names=model.names, f=save_dir / 'train_images.jpg')
        logger.log_images(file, name='Train Examples')
        logger.log_graph(model, imgsz)  # log model

    # Optimizer
    optimizer = smart_optimizer(model, opt.optimizer, opt.lr0, momentum=0.9, decay=opt.decay)

    # Scheduler
    lrf = 0.01  # final lr (fraction of lr0)
    # lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf  # cosine
    lf = lambda x: (1 - x / epochs) * (1 - lrf) + lrf  # linear
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr0, total_steps=epochs, pct_start=0.1,
    #                                    final_div_factor=1 / 25 / lrf)
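    # With the linear lambda above, the learning rate at epoch x is roughly
    #   lr(x) = lr0 * ((1 - x / epochs) * (1 - lrf) + lrf)
    # i.e. it decays linearly from lr0 towards lr0 * lrf (1% of lr0) by the final epoch.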

    # EMA
    ema = ModelEMA(model) if RANK in {-1, 0} else None
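    # The EMA weights (not the raw model) are what get validated each epoch and stored in checkpoints below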

    # DDP mode
    if cuda and RANK != -1:
        model = smart_DDP(model)

    # Train
    t0 = time.time()
    criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing)  # loss function
    best_fitness = 0.0
    scaler = amp.GradScaler(enabled=cuda)
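    # Mixed precision (the GradScaler above and autocast in the loop below) is only active on CUDA;
    # on CPU the forward/backward passes run in full FP32.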
    val = test_dir.stem  # 'val' or 'test'
    LOGGER.info(f'Image sizes {imgsz} train, {imgsz} test\n'
                f'Using {nw * WORLD_SIZE} dataloader workers\n'
                f"Logging results to {colorstr('bold', save_dir)}\n"
                f'Starting {opt.model} training on {data} dataset with {nc} classes for {epochs} epochs...\n\n'
                f"{'Epoch':>10}{'GPU_mem':>10}{'train_loss':>12}{f'{val}_loss':>12}{'top1_acc':>12}{'top5_acc':>12}")
    for epoch in range(epochs):  # loop over the dataset multiple times
        tloss, vloss, fitness = 0.0, 0.0, 0.0  # train loss, val loss, fitness
        model.train()
        if RANK != -1:
            trainloader.sampler.set_epoch(epoch)
        pbar = enumerate(trainloader)
        if RANK in {-1, 0}:
            pbar = tqdm(enumerate(trainloader), total=len(trainloader), bar_format=TQDM_BAR_FORMAT)
        for i, (images, labels) in pbar:  # progress bar
            images, labels = images.to(device, non_blocking=True), labels.to(device)

            # Forward
            with amp.autocast(enabled=cuda):  # stability issues when enabled
                loss = criterion(model(images), labels)

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            scaler.unscale_(optimizer)  # unscale gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if ema:
                ema.update(model)

            if RANK in {-1, 0}:
                # Print
                tloss = (tloss * i + loss.item()) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                pbar.desc = f"{f'{epoch + 1}/{epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36

                # Test
                if i == len(pbar) - 1:  # last batch
                    top1, top5, vloss = validate.run(model=ema.ema,
                                                     dataloader=testloader,
                                                     criterion=criterion,
                                                     pbar=pbar)  # test accuracy, loss
                    fitness = top1  # define fitness as top1 accuracy

        # Scheduler
        scheduler.step()

        # Log metrics
        if RANK in {-1, 0}:
            # Best fitness
            if fitness > best_fitness:
                best_fitness = fitness

            # Log
            metrics = {
                'train/loss': tloss,
                f'{val}/loss': vloss,
                'metrics/accuracy_top1': top1,
                'metrics/accuracy_top5': top5,
                'lr/0': optimizer.param_groups[0]['lr']}  # learning rate
            logger.log_metrics(metrics, epoch)

            # Save model
            final_epoch = epoch + 1 == epochs
            if (not opt.nosave) or final_epoch:
                ckpt = {
                    'epoch': epoch,
                    'best_fitness': best_fitness,
                    'model': deepcopy(ema.ema).half(),  # deepcopy(de_parallel(model)).half(),
                    'ema': None,  # deepcopy(ema.ema).half(),
                    'updates': ema.updates,
                    'optimizer': None,  # optimizer.state_dict(),
                    'opt': vars(opt),
                    'git': GIT_INFO,  # {remote, branch, commit} if a git repo
                    'date': datetime.now().isoformat()}

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fitness:
                    torch.save(ckpt, best)
                del ckpt
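                # Note: 'model' above holds the half-precision EMA weights; to reuse a saved checkpoint
                # for inference, something like torch.load(best)['model'].float() (or attempt_load(best))
                # should recover the trained classifier.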

    # Train complete
    if RANK in {-1, 0} and final_epoch:
        LOGGER.info(f'\nTraining complete ({(time.time() - t0) / 3600:.3f} hours)'
                    f"\nResults saved to {colorstr('bold', save_dir)}"
                    f'\nPredict:         python classify/predict.py --weights {best} --source im.jpg'
                    f'\nValidate:        python classify/val.py --weights {best} --data {data_dir}'
                    f'\nExport:          python export.py --weights {best} --include onnx'
                    f"\nPyTorch Hub:     model = torch.hub.load('ultralytics/yolov5', 'custom', '{best}')"
                    f'\nVisualize:       https://netron.app\n')

        # Plot examples
        images, labels = (x[:25] for x in next(iter(testloader)))  # first 25 images and labels
        pred = torch.max(ema.ema(images.to(device)), 1)[1]
        file = imshow_cls(images, labels, pred, de_parallel(model).names, verbose=False, f=save_dir / 'test_images.jpg')

        # Log results
        meta = {'epochs': epochs, 'top1_acc': best_fitness, 'date': datetime.now().isoformat()}
        logger.log_images(file, name='Test Examples (true-predicted)', epoch=epoch)
        logger.log_model(best, epochs, metadata=meta)


def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='yolov5s-cls.pt', help='initial weights path')
    parser.add_argument('--data', type=str, default='imagenette160', help='cifar10, cifar100, mnist, imagenet, ...')
    parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
    parser.add_argument('--batch-size', type=int, default=64, help='total batch size for all GPUs')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
    parser.add_argument('--project', default=ROOT / 'runs/train-cls', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--pretrained', nargs='?', const=True, default=True, help='start from pretrained weights, pass --pretrained False to disable')
    parser.add_argument('--optimizer', choices=['SGD', 'Adam', 'AdamW', 'RMSProp'], default='Adam', help='optimizer')
    parser.add_argument('--lr0', type=float, default=0.001, help='initial learning rate')
    parser.add_argument('--decay', type=float, default=5e-5, help='weight decay')
    parser.add_argument('--label-smoothing', type=float, default=0.1, help='Label smoothing epsilon')
    parser.add_argument('--cutoff', type=int, default=None, help='Model layer cutoff index for Classify() head')
    parser.add_argument('--dropout', type=float, default=None, help='Dropout (fraction)')
    parser.add_argument('--verbose', action='store_true', help='Verbose mode')
    parser.add_argument('--seed', type=int, default=0, help='Global training seed')
    parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
    return parser.parse_known_args()[0] if known else parser.parse_args()


def main(opt):
    # Checks
    if RANK in {-1, 0}:
        print_args(vars(opt))
        check_git_status()
        check_requirements(ROOT / 'requirements.txt')

    # DDP mode
    device = select_device(opt.device, batch_size=opt.batch_size)
    if LOCAL_RANK != -1:
        assert opt.batch_size != -1, 'AutoBatch is coming soon for classification, please pass a valid --batch-size'
        assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
        torch.cuda.set_device(LOCAL_RANK)
        device = torch.device('cuda', LOCAL_RANK)
        dist.init_process_group(backend='nccl' if dist.is_nccl_available() else 'gloo')
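        # From here each process drives one GPU (LOCAL_RANK); gradients are synchronized by the
        # DDP wrapper applied inside train().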

    # Parameters
    opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)  # increment run

    # Train
    train(opt, device)


def run(**kwargs):
    # Usage: from yolov5 import classify; classify.train.run(data='mnist', imgsz=320, model='yolov5m-cls.pt')
    opt = parse_opt(True)
    for k, v in kwargs.items():
        setattr(opt, k, v)
    main(opt)
    return opt


if __name__ == '__main__':
    opt = parse_opt()
    main(opt)
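
# Example (sketch) - training on a custom dataset laid out as path/to/data/{train,val}/<class>/<images>:
#     $ python classify/train.py --model yolov5s-cls.pt --data path/to/data --epochs 10 --img 224
# or, programmatically (see run() above):
#     from yolov5 import classify; classify.train.run(data='cifar10', model='yolov5s-cls.pt', epochs=3)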