Diff of /utils/plots.py [000000] .. [190ca4]

Switch to unified view

a b/utils/plots.py
1
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
2
"""
3
Plotting utils
4
"""
5
6
import contextlib
7
import math
8
import os
9
from copy import copy
10
from pathlib import Path
11
12
import cv2
13
import matplotlib
14
import matplotlib.pyplot as plt
15
import numpy as np
16
import pandas as pd
17
import seaborn as sn
18
import torch
19
from PIL import Image, ImageDraw
20
from scipy.ndimage.filters import gaussian_filter1d
21
from ultralytics.utils.plotting import Annotator
22
23
from utils import TryExcept, threaded
24
from utils.general import LOGGER, clip_boxes, increment_path, xywh2xyxy, xyxy2xywh
25
from utils.metrics import fitness
26
27
# Settings
28
RANK = int(os.getenv('RANK', -1))
29
matplotlib.rc('font', **{'size': 11})
30
matplotlib.use('Agg')  # for writing to files only
31
32
33
class Colors:
34
    # Ultralytics color palette https://ultralytics.com/
35
    def __init__(self):
36
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
37
        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
38
                '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
39
        self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
40
        self.n = len(self.palette)
41
42
    def __call__(self, i, bgr=False):
43
        c = self.palette[int(i) % self.n]
44
        return (c[2], c[1], c[0]) if bgr else c
45
46
    @staticmethod
47
    def hex2rgb(h):  # rgb order (PIL)
48
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
49
50
51
colors = Colors()  # create instance for 'from utils.plots import colors'
52
53
54
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
55
    """
56
    x:              Features to be visualized
57
    module_type:    Module type
58
    stage:          Module stage within model
59
    n:              Maximum number of feature maps to plot
60
    save_dir:       Directory to save results
61
    """
62
    if ('Detect'
63
            not in module_type) and ('Segment'
64
                                     not in module_type):  # 'Detect' for Object Detect task,'Segment' for Segment task
65
        batch, channels, height, width = x.shape  # batch, channels, height, width
66
        if height > 1 and width > 1:
67
            #f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png"
68
            f = str(stage)+"_"+"features.png"   # filename
69
70
            blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0, block by channels
71
            n = min(n, channels)  # number of plots
72
            fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # 8 rows x n/8 cols
73
            ax = ax.ravel()
74
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
75
            for i in range(n):
76
                ax[i].imshow(blocks[i].detach().numpy().squeeze()*256)  # cmap='gray'
77
                ax[i].axis('off')
78
79
            LOGGER.info(f'Saving {f}... ({n}/{channels})')
80
            plt.savefig(f, dpi=300, bbox_inches='tight')
81
            plt.close()
82
            #np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy())  # npy save
83
84
85
def hist2d(x, y, n=100):
86
    # 2d histogram used in labels.png and evolve.png
87
    xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
88
    hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
89
    xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
90
    yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
91
    return np.log(hist[xidx, yidx])
92
93
94
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
95
    from scipy.signal import butter, filtfilt
96
97
    # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
98
    def butter_lowpass(cutoff, fs, order):
99
        nyq = 0.5 * fs
100
        normal_cutoff = cutoff / nyq
101
        return butter(order, normal_cutoff, btype='low', analog=False)
102
103
    b, a = butter_lowpass(cutoff, fs, order=order)
104
    return filtfilt(b, a, data)  # forward-backward filter
105
106
107
def output_to_target(output, max_det=300):
108
    # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
109
    targets = []
110
    for i, o in enumerate(output):
111
        box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
112
        j = torch.full((conf.shape[0], 1), i)
113
        targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
114
    return torch.cat(targets, 0).numpy()
115
116
117
@threaded
118
def plot_images(images, targets, paths=None, fname='images.jpg', names=None):
119
    # Plot image grid with labels
120
    if isinstance(images, torch.Tensor):
121
        images = images.cpu().float().numpy()
122
    if isinstance(targets, torch.Tensor):
123
        targets = targets.cpu().numpy()
124
125
    max_size = 1920  # max image size
126
    max_subplots = 16  # max image subplots, i.e. 4x4
127
    bs, _, h, w = images.shape  # batch size, _, height, width
128
    bs = min(bs, max_subplots)  # limit plot images
129
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)
130
    if np.max(images[0]) <= 1:
131
        images *= 255  # de-normalise (optional)
132
133
    # Build Image
134
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
135
    for i, im in enumerate(images):
136
        if i == max_subplots:  # if last batch has fewer images than we expect
137
            break
138
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
139
        im = im.transpose(1, 2, 0)
140
        mosaic[y:y + h, x:x + w, :] = im
141
142
    # Resize (optional)
143
    scale = max_size / ns / max(h, w)
144
    if scale < 1:
145
        h = math.ceil(scale * h)
146
        w = math.ceil(scale * w)
147
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
148
149
    # Annotate
150
    fs = int((h + w) * ns * 0.01)  # font size
151
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
152
    for i in range(i + 1):
153
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
154
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
155
        if paths:
156
            annotator.text([x + 5, y + 5], text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
157
        if len(targets) > 0:
158
            ti = targets[targets[:, 0] == i]  # image targets
159
            boxes = xywh2xyxy(ti[:, 2:6]).T
160
            classes = ti[:, 1].astype('int')
161
            labels = ti.shape[1] == 6  # labels if no conf column
162
            conf = None if labels else ti[:, 6]  # check for confidence presence (label vs pred)
163
164
            if boxes.shape[1]:
165
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
166
                    boxes[[0, 2]] *= w  # scale to pixels
167
                    boxes[[1, 3]] *= h
168
                elif scale < 1:  # absolute coords need scale if image scales
169
                    boxes *= scale
170
            boxes[[0, 2]] += x
171
            boxes[[1, 3]] += y
172
            for j, box in enumerate(boxes.T.tolist()):
173
                cls = classes[j]
174
                color = colors(cls)
175
                cls = names[cls] if names else cls
176
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
177
                    label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
178
                    annotator.box_label(box, label, color=color)
179
    annotator.im.save(fname)  # save
180
181
182
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
183
    # Plot LR simulating training for full epochs
184
    optimizer, scheduler = copy(optimizer), copy(scheduler)  # do not modify originals
185
    y = []
186
    for _ in range(epochs):
187
        scheduler.step()
188
        y.append(optimizer.param_groups[0]['lr'])
189
    plt.plot(y, '.-', label='LR')
190
    plt.xlabel('epoch')
191
    plt.ylabel('LR')
192
    plt.grid()
193
    plt.xlim(0, epochs)
194
    plt.ylim(0)
195
    plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
196
    plt.close()
197
198
199
def plot_val_txt():  # from utils.plots import *; plot_val()
200
    # Plot val.txt histograms
201
    x = np.loadtxt('val.txt', dtype=np.float32)
202
    box = xyxy2xywh(x[:, :4])
203
    cx, cy = box[:, 0], box[:, 1]
204
205
    fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
206
    ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
207
    ax.set_aspect('equal')
208
    plt.savefig('hist2d.png', dpi=300)
209
210
    fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
211
    ax[0].hist(cx, bins=600)
212
    ax[1].hist(cy, bins=600)
213
    plt.savefig('hist1d.png', dpi=200)
214
215
216
def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
217
    # Plot targets.txt histograms
218
    x = np.loadtxt('targets.txt', dtype=np.float32).T
219
    s = ['x targets', 'y targets', 'width targets', 'height targets']
220
    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
221
    ax = ax.ravel()
222
    for i in range(4):
223
        ax[i].hist(x[i], bins=100, label=f'{x[i].mean():.3g} +/- {x[i].std():.3g}')
224
        ax[i].legend()
225
        ax[i].set_title(s[i])
226
    plt.savefig('targets.jpg', dpi=200)
227
228
229
def plot_val_study(file='', dir='', x=None):  # from utils.plots import *; plot_val_study()
230
    # Plot file=study.txt generated by val.py (or plot all study*.txt in dir)
231
    save_dir = Path(file).parent if file else Path(dir)
232
    plot2 = False  # plot additional results
233
    if plot2:
234
        ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()
235
236
    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
237
    # for f in [save_dir / f'study_coco_{x}.txt' for x in ['yolov5n6', 'yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
238
    for f in sorted(save_dir.glob('study*.txt')):
239
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
240
        x = np.arange(y.shape[1]) if x is None else np.array(x)
241
        if plot2:
242
            s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
243
            for i in range(7):
244
                ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
245
                ax[i].set_title(s[i])
246
247
        j = y[3].argmax() + 1
248
        ax2.plot(y[5, 1:j],
249
                 y[3, 1:j] * 1E2,
250
                 '.-',
251
                 linewidth=2,
252
                 markersize=8,
253
                 label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
254
255
    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
256
             'k.-',
257
             linewidth=2,
258
             markersize=8,
259
             alpha=.25,
260
             label='EfficientDet')
261
262
    ax2.grid(alpha=0.2)
263
    ax2.set_yticks(np.arange(20, 60, 5))
264
    ax2.set_xlim(0, 57)
265
    ax2.set_ylim(25, 55)
266
    ax2.set_xlabel('GPU Speed (ms/img)')
267
    ax2.set_ylabel('COCO AP val')
268
    ax2.legend(loc='lower right')
269
    f = save_dir / 'study.png'
270
    print(f'Saving {f}...')
271
    plt.savefig(f, dpi=300)
272
273
274
@TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
275
def plot_labels(labels, names=(), save_dir=Path('')):
276
    # plot dataset labels
277
    LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
278
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
279
    nc = int(c.max() + 1)  # number of classes
280
    x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
281
282
    # seaborn correlogram
283
    sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
284
    plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
285
    plt.close()
286
287
    # matplotlib labels
288
    matplotlib.use('svg')  # faster
289
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
290
    y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
291
    with contextlib.suppress(Exception):  # color histogram bars by class
292
        [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # known issue #3195
293
    ax[0].set_ylabel('instances')
294
    if 0 < len(names) < 30:
295
        ax[0].set_xticks(range(len(names)))
296
        ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
297
    else:
298
        ax[0].set_xlabel('classes')
299
    sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
300
    sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
301
302
    # rectangles
303
    labels[:, 1:3] = 0.5  # center
304
    labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
305
    img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
306
    for cls, *box in labels[:1000]:
307
        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # plot
308
    ax[1].imshow(img)
309
    ax[1].axis('off')
310
311
    for a in [0, 1, 2, 3]:
312
        for s in ['top', 'right', 'left', 'bottom']:
313
            ax[a].spines[s].set_visible(False)
314
315
    plt.savefig(save_dir / 'labels.jpg', dpi=200)
316
    matplotlib.use('Agg')
317
    plt.close()
318
319
320
def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
321
    # Show classification image grid with labels (optional) and predictions (optional)
322
    from utils.augmentations import denormalize
323
324
    names = names or [f'class{i}' for i in range(1000)]
325
    blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
326
                         dim=0)  # select batch index 0, block by channels
327
    n = min(len(blocks), nmax)  # number of plots
328
    m = min(8, round(n ** 0.5))  # 8 x 8 default
329
    fig, ax = plt.subplots(math.ceil(n / m), m)  # 8 rows x n/8 cols
330
    ax = ax.ravel() if m > 1 else [ax]
331
    # plt.subplots_adjust(wspace=0.05, hspace=0.05)
332
    for i in range(n):
333
        ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
334
        ax[i].axis('off')
335
        if labels is not None:
336
            s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
337
            ax[i].set_title(s, fontsize=8, verticalalignment='top')
338
    plt.savefig(f, dpi=300, bbox_inches='tight')
339
    plt.close()
340
    if verbose:
341
        LOGGER.info(f'Saving {f}')
342
        if labels is not None:
343
            LOGGER.info('True:     ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
344
        if pred is not None:
345
            LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
346
    return f
347
348
349
def plot_evolve(evolve_csv='path/to/evolve.csv'):  # from utils.plots import *; plot_evolve()
350
    # Plot evolve.csv hyp evolution results
351
    evolve_csv = Path(evolve_csv)
352
    data = pd.read_csv(evolve_csv)
353
    keys = [x.strip() for x in data.columns]
354
    x = data.values
355
    f = fitness(x)
356
    j = np.argmax(f)  # max fitness index
357
    plt.figure(figsize=(10, 12), tight_layout=True)
358
    matplotlib.rc('font', **{'size': 8})
359
    print(f'Best results from row {j} of {evolve_csv}:')
360
    for i, k in enumerate(keys[7:]):
361
        v = x[:, 7 + i]
362
        mu = v[j]  # best single result
363
        plt.subplot(6, 5, i + 1)
364
        plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
365
        plt.plot(mu, f.max(), 'k+', markersize=15)
366
        plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9})  # limit to 40 characters
367
        if i % 5 != 0:
368
            plt.yticks([])
369
        print(f'{k:>15}: {mu:.3g}')
370
    f = evolve_csv.with_suffix('.png')  # filename
371
    plt.savefig(f, dpi=200)
372
    plt.close()
373
    print(f'Saved {f}')
374
375
376
def plot_results(file='path/to/results.csv', dir=''):
377
    # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
378
    save_dir = Path(file).parent if file else Path(dir)
379
    fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
380
    ax = ax.ravel()
381
    files = list(save_dir.glob('results*.csv'))
382
    assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
383
    for f in files:
384
        try:
385
            data = pd.read_csv(f)
386
            s = [x.strip() for x in data.columns]
387
            x = data.values[:, 0]
388
            for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
389
                y = data.values[:, j].astype('float')
390
                # y[y == 0] = np.nan  # don't show zero values
391
                ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)  # actual results
392
                ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2)  # smoothing line
393
                ax[i].set_title(s[j], fontsize=12)
394
                # if j in [8, 9, 10]:  # share train and val loss y axes
395
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
396
        except Exception as e:
397
            LOGGER.info(f'Warning: Plotting error for {f}: {e}')
398
    ax[1].legend()
399
    fig.savefig(save_dir / 'results.png', dpi=200)
400
    plt.close()
401
402
403
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
404
    # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
405
    ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
406
    s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
407
    files = list(Path(save_dir).glob('frames*.txt'))
408
    for fi, f in enumerate(files):
409
        try:
410
            results = np.loadtxt(f, ndmin=2).T[:, 90:-30]  # clip first and last rows
411
            n = results.shape[1]  # number of rows
412
            x = np.arange(start, min(stop, n) if stop else n)
413
            results = results[:, x]
414
            t = (results[0] - results[0].min())  # set t0=0s
415
            results[0] = x
416
            for i, a in enumerate(ax):
417
                if i < len(results):
418
                    label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
419
                    a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
420
                    a.set_title(s[i])
421
                    a.set_xlabel('time (s)')
422
                    # if fi == len(files) - 1:
423
                    #     a.set_ylim(bottom=0)
424
                    for side in ['top', 'right']:
425
                        a.spines[side].set_visible(False)
426
                else:
427
                    a.remove()
428
        except Exception as e:
429
            print(f'Warning: Plotting error for {f}; {e}')
430
    ax[1].legend()
431
    plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
432
433
434
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
435
    # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
436
    xyxy = torch.tensor(xyxy).view(-1, 4)
437
    b = xyxy2xywh(xyxy)  # boxes
438
    if square:
439
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
440
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
441
    xyxy = xywh2xyxy(b).long()
442
    clip_boxes(xyxy, im.shape)
443
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
444
    if save:
445
        file.parent.mkdir(parents=True, exist_ok=True)  # make directory
446
        f = str(increment_path(file).with_suffix('.jpg'))
447
        # cv2.imwrite(f, crop)  # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
448
        Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # save RGB
449
    return crop