Switch to unified view

a b/landmark_extraction/utils/plots.py
1
# Plotting utils
2
3
import glob
4
import math
5
import os
6
import random
7
from copy import copy
8
from pathlib import Path
9
10
import cv2
11
import matplotlib
12
import matplotlib.pyplot as plt
13
import numpy as np
14
import pandas as pd
15
import seaborn as sns
16
import torch
17
import yaml
18
from PIL import Image, ImageDraw, ImageFont
19
from scipy.signal import butter, filtfilt
20
21
from utils.general import xywh2xyxy, xyxy2xywh
22
from utils.metrics import fitness
23
24
# Settings
25
matplotlib.rc('font', **{'size': 11})
26
matplotlib.use('Agg')  # for writing to files only
27
28
class Colors:
29
    # Ultralytics color palette https://ultralytics.com/
30
    def __init__(self):
31
        self.palette = [self.hex2rgb(c) for c in matplotlib.colors.TABLEAU_COLORS.values()]
32
        self.n = len(self.palette)
33
34
    def __call__(self, i, bgr=False):
35
        c = self.palette[int(i) % self.n]
36
        return (c[2], c[1], c[0]) if bgr else c
37
38
    @staticmethod
39
    def hex2rgb(h):  # rgb order (PIL)
40
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
41
42
43
def plot_one_box_kpt(x, im, color=None, label=None, line_thickness=3, kpt_label=False, kpts=None, steps=2, orig_shape=None):
44
    # Plots one bounding box on image 'im' using OpenCV
45
    assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
46
    tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1  # line/font thickness
47
    color = color or [random.randint(0, 255) for _ in range(3)]
48
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
49
    cv2.rectangle(im, c1, c2, (255,0,0), thickness=tl*1//3, lineType=cv2.LINE_AA)
50
    if label:
51
        if len(label.split(' ')) > 1:
52
            label = label.split(' ')[-1]
53
            tf = max(tl - 1, 1)  # font thickness
54
            t_size = cv2.getTextSize(label, 0, fontScale=tl / 6, thickness=tf)[0]
55
            c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
56
            cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA)  # filled
57
            cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 6, [225, 255, 255], thickness=tf//2, lineType=cv2.LINE_AA)
58
    if kpt_label:
59
        plot_skeleton_kpts(im, kpts, steps, orig_shape=orig_shape)
60
61
colors = Colors()  
62
63
def color_list():
64
    def hex2rgb(h):
65
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
66
67
    return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()]  # or BASE_ (8), CSS4_ (148), XKCD_ (949)
68
69
70
def hist2d(x, y, n=100):
71
    # 2d histogram used in labels.png and evolve.png
72
    xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
73
    hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
74
    xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
75
    yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
76
    return np.log(hist[xidx, yidx])
77
78
79
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
80
    # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
81
    def butter_lowpass(cutoff, fs, order):
82
        nyq = 0.5 * fs
83
        normal_cutoff = cutoff / nyq
84
        return butter(order, normal_cutoff, btype='low', analog=False)
85
86
    b, a = butter_lowpass(cutoff, fs, order=order)
87
    return filtfilt(b, a, data)  # forward-backward filter
88
89
90
def plot_one_box(x, img, color=None, label=None, line_thickness=1):
91
    # Plots one bounding box on image img
92
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 2  # line/font thickness
93
    color = color or [random.randint(0, 255) for _ in range(3)]
94
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
95
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
96
    if label:
97
        tf = max(tl - 1, 1)  # font thickness
98
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
99
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
100
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
101
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
102
103
104
def plot_one_box_PIL(box, img, color=None, label=None, line_thickness=None):
105
    img = Image.fromarray(img)
106
    draw = ImageDraw.Draw(img)
107
    line_thickness = line_thickness or max(int(min(img.size) / 200), 2)
108
    draw.rectangle(box, width=line_thickness, outline=tuple(color))  # plot
109
    if label:
110
        fontsize = max(round(max(img.size) / 40), 12)
111
        font = ImageFont.truetype("Arial.ttf", fontsize)
112
        txt_width, txt_height = font.getsize(label)
113
        draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color))
114
        draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
115
    return np.asarray(img)
116
117
118
def plot_wh_methods():  # from utils.plots import *; plot_wh_methods()
119
    # Compares the two methods for width-height anchor multiplication
120
    # https://github.com/ultralytics/yolov3/issues/168
121
    x = np.arange(-4.0, 4.0, .1)
122
    ya = np.exp(x)
123
    yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
124
125
    fig = plt.figure(figsize=(6, 3), tight_layout=True)
126
    plt.plot(x, ya, '.-', label='YOLOv3')
127
    plt.plot(x, yb ** 2, '.-', label='YOLOR ^2')
128
    plt.plot(x, yb ** 1.6, '.-', label='YOLOR ^1.6')
129
    plt.xlim(left=-4, right=4)
130
    plt.ylim(bottom=0, top=6)
131
    plt.xlabel('input')
132
    plt.ylabel('output')
133
    plt.grid()
134
    plt.legend()
135
    fig.savefig('comparison.png', dpi=200)
136
137
138
def output_to_target(output):
139
    # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
140
    targets = []
141
    for i, o in enumerate(output):
142
        for *box, conf, cls in o.cpu().numpy():
143
            targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
144
    return np.array(targets)
145
146
147
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
148
    # Plot image grid with labels
149
150
    if isinstance(images, torch.Tensor):
151
        images = images.cpu().float().numpy()
152
    if isinstance(targets, torch.Tensor):
153
        targets = targets.cpu().numpy()
154
155
    # un-normalise
156
    if np.max(images[0]) <= 1:
157
        images *= 255
158
159
    tl = 3  # line thickness
160
    tf = max(tl - 1, 1)  # font thickness
161
    bs, _, h, w = images.shape  # batch size, _, height, width
162
    bs = min(bs, max_subplots)  # limit plot images
163
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)
164
165
    # Check if we should resize
166
    scale_factor = max_size / max(h, w)
167
    if scale_factor < 1:
168
        h = math.ceil(scale_factor * h)
169
        w = math.ceil(scale_factor * w)
170
171
    colors = color_list()  # list of colors
172
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
173
    for i, img in enumerate(images):
174
        if i == max_subplots:  # if last batch has fewer images than we expect
175
            break
176
177
        block_x = int(w * (i // ns))
178
        block_y = int(h * (i % ns))
179
180
        img = img.transpose(1, 2, 0)
181
        if scale_factor < 1:
182
            img = cv2.resize(img, (w, h))
183
184
        mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
185
        if len(targets) > 0:
186
            image_targets = targets[targets[:, 0] == i]
187
            boxes = xywh2xyxy(image_targets[:, 2:6]).T
188
            classes = image_targets[:, 1].astype('int')
189
            labels = image_targets.shape[1] == 6  # labels if no conf column
190
            conf = None if labels else image_targets[:, 6]  # check for confidence presence (label vs pred)
191
192
            if boxes.shape[1]:
193
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
194
                    boxes[[0, 2]] *= w  # scale to pixels
195
                    boxes[[1, 3]] *= h
196
                elif scale_factor < 1:  # absolute coords need scale if image scales
197
                    boxes *= scale_factor
198
            boxes[[0, 2]] += block_x
199
            boxes[[1, 3]] += block_y
200
            for j, box in enumerate(boxes.T):
201
                cls = int(classes[j])
202
                color = colors[cls % len(colors)]
203
                cls = names[cls] if names else cls
204
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
205
                    label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
206
                    plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
207
208
        # Draw image filename labels
209
        if paths:
210
            label = Path(paths[i]).name[:40]  # trim to 40 char
211
            t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
212
            cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
213
                        lineType=cv2.LINE_AA)
214
215
        # Image border
216
        cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
217
218
    if fname:
219
        r = min(1280. / max(h, w) / ns, 1.0)  # ratio to limit image size
220
        mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
221
        # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))  # cv2 save
222
        Image.fromarray(mosaic).save(fname)  # PIL save
223
    return mosaic
224
225
226
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
227
    # Plot LR simulating training for full epochs
228
    optimizer, scheduler = copy(optimizer), copy(scheduler)  # do not modify originals
229
    y = []
230
    for _ in range(epochs):
231
        scheduler.step()
232
        y.append(optimizer.param_groups[0]['lr'])
233
    plt.plot(y, '.-', label='LR')
234
    plt.xlabel('epoch')
235
    plt.ylabel('LR')
236
    plt.grid()
237
    plt.xlim(0, epochs)
238
    plt.ylim(0)
239
    plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
240
    plt.close()
241
242
243
def plot_test_txt():  # from utils.plots import *; plot_test()
244
    # Plot test.txt histograms
245
    x = np.loadtxt('test.txt', dtype=np.float32)
246
    box = xyxy2xywh(x[:, :4])
247
    cx, cy = box[:, 0], box[:, 1]
248
249
    fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
250
    ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
251
    ax.set_aspect('equal')
252
    plt.savefig('hist2d.png', dpi=300)
253
254
    fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
255
    ax[0].hist(cx, bins=600)
256
    ax[1].hist(cy, bins=600)
257
    plt.savefig('hist1d.png', dpi=200)
258
259
260
def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
261
    # Plot targets.txt histograms
262
    x = np.loadtxt('targets.txt', dtype=np.float32).T
263
    s = ['x targets', 'y targets', 'width targets', 'height targets']
264
    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
265
    ax = ax.ravel()
266
    for i in range(4):
267
        ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
268
        ax[i].legend()
269
        ax[i].set_title(s[i])
270
    plt.savefig('targets.jpg', dpi=200)
271
272
273
def plot_study_txt(path='', x=None):  # from utils.plots import *; plot_study_txt()
274
    # Plot study.txt generated by test.py
275
    fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
276
    # ax = ax.ravel()
277
278
    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
279
    # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolor-p6', 'yolor-w6', 'yolor-e6', 'yolor-d6']]:
280
    for f in sorted(Path(path).glob('study*.txt')):
281
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
282
        x = np.arange(y.shape[1]) if x is None else np.array(x)
283
        s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
284
        # for i in range(7):
285
        #     ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
286
        #     ax[i].set_title(s[i])
287
288
        j = y[3].argmax() + 1
289
        ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
290
                 label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
291
292
    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
293
             'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
294
295
    ax2.grid(alpha=0.2)
296
    ax2.set_yticks(np.arange(20, 60, 5))
297
    ax2.set_xlim(0, 57)
298
    ax2.set_ylim(30, 55)
299
    ax2.set_xlabel('GPU Speed (ms/img)')
300
    ax2.set_ylabel('COCO AP val')
301
    ax2.legend(loc='lower right')
302
    plt.savefig(str(Path(path).name) + '.png', dpi=300)
303
304
305
def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
306
    # plot dataset labels
307
    print('Plotting labels... ')
308
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
309
    nc = int(c.max() + 1)  # number of classes
310
    colors = color_list()
311
    x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
312
313
    # seaborn correlogram
314
    sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
315
    plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
316
    plt.close()
317
318
    # matplotlib labels
319
    matplotlib.use('svg')  # faster
320
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
321
    ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
322
    ax[0].set_ylabel('instances')
323
    if 0 < len(names) < 30:
324
        ax[0].set_xticks(range(len(names)))
325
        ax[0].set_xticklabels(names, rotation=90, fontsize=10)
326
    else:
327
        ax[0].set_xlabel('classes')
328
    sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
329
    sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
330
331
    # rectangles
332
    labels[:, 1:3] = 0.5  # center
333
    labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
334
    img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
335
    for cls, *box in labels[:1000]:
336
        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10])  # plot
337
    ax[1].imshow(img)
338
    ax[1].axis('off')
339
340
    for a in [0, 1, 2, 3]:
341
        for s in ['top', 'right', 'left', 'bottom']:
342
            ax[a].spines[s].set_visible(False)
343
344
    plt.savefig(save_dir / 'labels.jpg', dpi=200)
345
    matplotlib.use('Agg')
346
    plt.close()
347
348
    # loggers
349
    for k, v in loggers.items() or {}:
350
        if k == 'wandb' and v:
351
            v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
352
353
354
def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots import *; plot_evolution()
355
    # Plot hyperparameter evolution results in evolve.txt
356
    with open(yaml_file) as f:
357
        hyp = yaml.load(f, Loader=yaml.SafeLoader)
358
    x = np.loadtxt('evolve.txt', ndmin=2)
359
    f = fitness(x)
360
    # weights = (f - f.min()) ** 2  # for weighted results
361
    plt.figure(figsize=(10, 12), tight_layout=True)
362
    matplotlib.rc('font', **{'size': 8})
363
    for i, (k, v) in enumerate(hyp.items()):
364
        y = x[:, i + 7]
365
        # mu = (y * weights).sum() / weights.sum()  # best weighted result
366
        mu = y[f.argmax()]  # best single result
367
        plt.subplot(6, 5, i + 1)
368
        plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
369
        plt.plot(mu, f.max(), 'k+', markersize=15)
370
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})  # limit to 40 characters
371
        if i % 5 != 0:
372
            plt.yticks([])
373
        print('%15s: %.3g' % (k, mu))
374
    plt.savefig('evolve.png', dpi=200)
375
    print('\nPlot saved as evolve.png')
376
377
378
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
379
    # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
380
    ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
381
    s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
382
    files = list(Path(save_dir).glob('frames*.txt'))
383
    for fi, f in enumerate(files):
384
        try:
385
            results = np.loadtxt(f, ndmin=2).T[:, 90:-30]  # clip first and last rows
386
            n = results.shape[1]  # number of rows
387
            x = np.arange(start, min(stop, n) if stop else n)
388
            results = results[:, x]
389
            t = (results[0] - results[0].min())  # set t0=0s
390
            results[0] = x
391
            for i, a in enumerate(ax):
392
                if i < len(results):
393
                    label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
394
                    a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
395
                    a.set_title(s[i])
396
                    a.set_xlabel('time (s)')
397
                    # if fi == len(files) - 1:
398
                    #     a.set_ylim(bottom=0)
399
                    for side in ['top', 'right']:
400
                        a.spines[side].set_visible(False)
401
                else:
402
                    a.remove()
403
        except Exception as e:
404
            print('Warning: Plotting error for %s; %s' % (f, e))
405
406
    ax[1].legend()
407
    plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
408
409
410
def plot_results_overlay(start=0, stop=0):  # from utils.plots import *; plot_results_overlay()
411
    # Plot training 'results*.txt', overlaying train and val losses
412
    s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
413
    t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
414
    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
415
        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
416
        n = results.shape[1]  # number of rows
417
        x = range(start, min(stop, n) if stop else n)
418
        fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
419
        ax = ax.ravel()
420
        for i in range(5):
421
            for j in [i, i + 5]:
422
                y = results[j, x]
423
                ax[i].plot(x, y, marker='.', label=s[j])
424
                # y_smooth = butter_lowpass_filtfilt(y)
425
                # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
426
427
            ax[i].set_title(t[i])
428
            ax[i].legend()
429
            ax[i].set_ylabel(f) if i == 0 else None  # add filename
430
        fig.savefig(f.replace('.txt', '.png'), dpi=200)
431
432
433
def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
434
    # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
435
    fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
436
    ax = ax.ravel()
437
    s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
438
         'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
439
    if bucket:
440
        # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
441
        files = ['results%g.txt' % x for x in id]
442
        c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
443
        os.system(c)
444
    else:
445
        files = list(Path(save_dir).glob('results*.txt'))
446
    assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
447
    for fi, f in enumerate(files):
448
        try:
449
            results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
450
            n = results.shape[1]  # number of rows
451
            x = range(start, min(stop, n) if stop else n)
452
            for i in range(10):
453
                y = results[i, x]
454
                if i in [0, 1, 2, 5, 6, 7]:
455
                    y[y == 0] = np.nan  # don't show zero loss values
456
                    # y /= y[0]  # normalize
457
                label = labels[fi] if len(labels) else f.stem
458
                ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
459
                ax[i].set_title(s[i])
460
                # if i in [5, 6, 7]:  # share train and val loss y axes
461
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
462
        except Exception as e:
463
            print('Warning: Plotting error for %s; %s' % (f, e))
464
465
    ax[1].legend()
466
    fig.savefig(Path(save_dir) / 'results.png', dpi=200)
467
    
468
    
469
def output_to_keypoint(output):
470
    # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
471
    targets = []
472
    for i, o in enumerate(output):
473
        kpts = o[:,6:]
474
        o = o[:,:6]
475
        for index, (*box, conf, cls) in enumerate(o.cpu().numpy()):
476
            targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf, *list(kpts.cpu().numpy()[index])])
477
    return np.array(targets)
478
479
480
def plot_skeleton_kpts(im, kpts, steps, orig_shape=None):
481
    #Plot the skeleton and keypointsfor coco datatset
482
    palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
483
                        [230, 230, 0], [255, 153, 255], [153, 204, 255],
484
                        [255, 102, 255], [255, 51, 255], [102, 178, 255],
485
                        [51, 153, 255], [255, 153, 153], [255, 102, 102],
486
                        [255, 51, 51], [153, 255, 153], [102, 255, 102],
487
                        [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0],
488
                        [255, 255, 255]])
489
490
    skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12],
491
                [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3],
492
                [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
493
494
    pose_limb_color = palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
495
    pose_kpt_color = palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
496
    radius = 5
497
    num_kpts = len(kpts) // steps
498
499
    for kid in range(num_kpts):
500
        r, g, b = pose_kpt_color[kid]
501
        x_coord, y_coord = kpts[steps * kid], kpts[steps * kid + 1]
502
        if not (x_coord % 640 == 0 or y_coord % 640 == 0):
503
            if steps == 3:
504
                conf = kpts[steps * kid + 2]
505
                if conf < 0.5:
506
                    continue
507
            cv2.circle(im, (int(x_coord), int(y_coord)), radius, (int(r), int(g), int(b)), -1)
508
509
    for sk_id, sk in enumerate(skeleton):
510
        r, g, b = pose_limb_color[sk_id]
511
        pos1 = (int(kpts[(sk[0]-1)*steps]), int(kpts[(sk[0]-1)*steps+1]))
512
        pos2 = (int(kpts[(sk[1]-1)*steps]), int(kpts[(sk[1]-1)*steps+1]))
513
        if steps == 3:
514
            conf1 = kpts[(sk[0]-1)*steps+2]
515
            conf2 = kpts[(sk[1]-1)*steps+2]
516
            if conf1<0.5 or conf2<0.5:
517
                continue
518
        if pos1[0]%640 == 0 or pos1[1]%640==0 or pos1[0]<0 or pos1[1]<0:
519
            continue
520
        if pos2[0] % 640 == 0 or pos2[1] % 640 == 0 or pos2[0]<0 or pos2[1]<0:
521
            continue
522
        cv2.line(im, pos1, pos2, (int(r), int(g), int(b)), thickness=2)