Diff of /classify/val.py [000000] .. [190ca4]

Switch to unified view

a b/classify/val.py
1
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
2
"""
3
Validate a trained YOLOv5 classification model on a classification dataset
4
5
Usage:
6
    $ bash data/scripts/get_imagenet.sh --val  # download ImageNet val split (6.3G, 50000 images)
7
    $ python classify/val.py --weights yolov5m-cls.pt --data ../datasets/imagenet --img 224  # validate ImageNet
8
9
Usage - formats:
10
    $ python classify/val.py --weights yolov5s-cls.pt                 # PyTorch
11
                                       yolov5s-cls.torchscript        # TorchScript
12
                                       yolov5s-cls.onnx               # ONNX Runtime or OpenCV DNN with --dnn
13
                                       yolov5s-cls_openvino_model     # OpenVINO
14
                                       yolov5s-cls.engine             # TensorRT
15
                                       yolov5s-cls.mlmodel            # CoreML (macOS-only)
16
                                       yolov5s-cls_saved_model        # TensorFlow SavedModel
17
                                       yolov5s-cls.pb                 # TensorFlow GraphDef
18
                                       yolov5s-cls.tflite             # TensorFlow Lite
19
                                       yolov5s-cls_edgetpu.tflite     # TensorFlow Edge TPU
20
                                       yolov5s-cls_paddle_model       # PaddlePaddle
21
"""
22
23
import argparse
24
import os
25
import sys
26
from pathlib import Path
27
28
import torch
29
from tqdm import tqdm
30
31
FILE = Path(__file__).resolve()
32
ROOT = FILE.parents[1]  # YOLOv5 root directory
33
if str(ROOT) not in sys.path:
34
    sys.path.append(str(ROOT))  # add ROOT to PATH
35
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
36
37
from models.common import DetectMultiBackend
38
from utils.dataloaders import create_classification_dataloader
39
from utils.general import (LOGGER, TQDM_BAR_FORMAT, Profile, check_img_size, check_requirements, colorstr,
40
                           increment_path, print_args)
41
from utils.torch_utils import select_device, smart_inference_mode
42
43
44
@smart_inference_mode()
45
def run(
46
    data=ROOT / '../datasets/mnist',  # dataset dir
47
    weights=ROOT / 'yolov5s-cls.pt',  # model.pt path(s)
48
    batch_size=128,  # batch size
49
    imgsz=224,  # inference size (pixels)
50
    device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
51
    workers=8,  # max dataloader workers (per RANK in DDP mode)
52
    verbose=False,  # verbose output
53
    project=ROOT / 'runs/val-cls',  # save to project/name
54
    name='exp',  # save to project/name
55
    exist_ok=False,  # existing project/name ok, do not increment
56
    half=False,  # use FP16 half-precision inference
57
    dnn=False,  # use OpenCV DNN for ONNX inference
58
    model=None,
59
    dataloader=None,
60
    criterion=None,
61
    pbar=None,
62
):
63
    # Initialize/load model and set device
64
    training = model is not None
65
    if training:  # called by train.py
66
        device, pt, jit, engine = next(model.parameters()).device, True, False, False  # get model device, PyTorch model
67
        half &= device.type != 'cpu'  # half precision only supported on CUDA
68
        model.half() if half else model.float()
69
    else:  # called directly
70
        device = select_device(device, batch_size=batch_size)
71
72
        # Directories
73
        save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
74
        save_dir.mkdir(parents=True, exist_ok=True)  # make dir
75
76
        # Load model
77
        model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
78
        stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
79
        imgsz = check_img_size(imgsz, s=stride)  # check image size
80
        half = model.fp16  # FP16 supported on limited backends with CUDA
81
        if engine:
82
            batch_size = model.batch_size
83
        else:
84
            device = model.device
85
            if not (pt or jit):
86
                batch_size = 1  # export.py models default to batch-size 1
87
                LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
88
89
        # Dataloader
90
        data = Path(data)
91
        test_dir = data / 'test' if (data / 'test').exists() else data / 'val'  # data/test or data/val
92
        dataloader = create_classification_dataloader(path=test_dir,
93
                                                      imgsz=imgsz,
94
                                                      batch_size=batch_size,
95
                                                      augment=False,
96
                                                      rank=-1,
97
                                                      workers=workers)
98
99
    model.eval()
100
    pred, targets, loss, dt = [], [], 0, (Profile(device=device), Profile(device=device), Profile(device=device))
101
    n = len(dataloader)  # number of batches
102
    action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing'
103
    desc = f'{pbar.desc[:-36]}{action:>36}' if pbar else f'{action}'
104
    bar = tqdm(dataloader, desc, n, not training, bar_format=TQDM_BAR_FORMAT, position=0)
105
    with torch.cuda.amp.autocast(enabled=device.type != 'cpu'):
106
        for images, labels in bar:
107
            with dt[0]:
108
                images, labels = images.to(device, non_blocking=True), labels.to(device)
109
110
            with dt[1]:
111
                y = model(images)
112
113
            with dt[2]:
114
                pred.append(y.argsort(1, descending=True)[:, :5])
115
                targets.append(labels)
116
                if criterion:
117
                    loss += criterion(y, labels)
118
119
    loss /= n
120
    pred, targets = torch.cat(pred), torch.cat(targets)
121
    correct = (targets[:, None] == pred).float()
122
    acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1)  # (top1, top5) accuracy
123
    top1, top5 = acc.mean(0).tolist()
124
125
    if pbar:
126
        pbar.desc = f'{pbar.desc[:-36]}{loss:>12.3g}{top1:>12.3g}{top5:>12.3g}'
127
    if verbose:  # all classes
128
        LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}")
129
        LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}")
130
        for i, c in model.names.items():
131
            acc_i = acc[targets == i]
132
            top1i, top5i = acc_i.mean(0).tolist()
133
            LOGGER.info(f'{c:>24}{acc_i.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}')
134
135
        # Print results
136
        t = tuple(x.t / len(dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
137
        shape = (1, 3, imgsz, imgsz)
138
        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
139
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
140
141
    return top1, top5, loss
142
143
144
def parse_opt():
145
    parser = argparse.ArgumentParser()
146
    parser.add_argument('--data', type=str, default=ROOT / '../datasets/mnist', help='dataset path')
147
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model.pt path(s)')
148
    parser.add_argument('--batch-size', type=int, default=128, help='batch size')
149
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='inference size (pixels)')
150
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
151
    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
152
    parser.add_argument('--verbose', nargs='?', const=True, default=True, help='verbose output')
153
    parser.add_argument('--project', default=ROOT / 'runs/val-cls', help='save to project/name')
154
    parser.add_argument('--name', default='exp', help='save to project/name')
155
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
156
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
157
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
158
    opt = parser.parse_args()
159
    print_args(vars(opt))
160
    return opt
161
162
163
def main(opt):
164
    check_requirements(ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
165
    run(**vars(opt))
166
167
168
if __name__ == '__main__':
169
    opt = parse_opt()
170
    main(opt)