|
a |
|
b/darkflow/net/yolo/predict.py |
|
|
1 |
from ...utils.im_transform import imcv2_recolor, imcv2_affine_trans |
|
|
2 |
from ...utils.box import BoundBox, box_iou, prob_compare |
|
|
3 |
import numpy as np |
|
|
4 |
import cv2 |
|
|
5 |
import os |
|
|
6 |
import json |
|
|
7 |
from ...cython_utils.cy_yolo_findboxes import yolo_box_constructor |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
def _fix(obj, dims, scale, offs):
    """Rescale and clamp the four box coordinates of `obj` (indices 1..4) in place.

    Coordinate i alternates between the two entries of `dims` / `offs`:
    odd indices use element 0, even indices use element 1.
    """
    for idx in (1, 2, 3, 4):
        axis = (idx + 1) % 2
        bound = dims[axis]
        shift = offs[axis]
        scaled = int(obj[idx] * scale - shift)
        # clamp into [0, bound]
        obj[idx] = min(max(scaled, 0), bound)
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def resize_input(self, im):
    """Resize a BGR image to the network's input size, scale pixel values
    into [0, 1] and reverse the channel order (BGR -> RGB)."""
    height, width, _ = self.meta['inp_size']
    resized = cv2.resize(im, (width, height))
    normed = resized / 255.
    return normed[:, :, ::-1]
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def process_box(self, b, h, w, threshold): |
|
|
27 |
max_indx = np.argmax(b.probs) |
|
|
28 |
max_prob = b.probs[max_indx] |
|
|
29 |
label = self.meta['labels'][max_indx] |
|
|
30 |
if max_prob > threshold: |
|
|
31 |
left = int((b.x - b.w / 2.) * w) |
|
|
32 |
right = int((b.x + b.w / 2.) * w) |
|
|
33 |
top = int((b.y - b.h / 2.) * h) |
|
|
34 |
bot = int((b.y + b.h / 2.) * h) |
|
|
35 |
if left < 0: left = 0 |
|
|
36 |
if right > w - 1: right = w - 1 |
|
|
37 |
if top < 0: top = 0 |
|
|
38 |
if bot > h - 1: bot = h - 1 |
|
|
39 |
mess = '{}'.format(label) |
|
|
40 |
return (left, right, top, bot, mess, max_indx, max_prob) |
|
|
41 |
return None |
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def findboxes(self, net_out):
    """Decode raw network output into candidate bounding boxes.

    Parameters
    ----------
    net_out : the network's output tensor for a single image.

    Returns the list of boxes produced by the Cython `yolo_box_constructor`
    for detections above the configured probability threshold.
    """
    meta, FLAGS = self.meta, self.FLAGS
    threshold = FLAGS.threshold
    # The constructor builds the full list itself; the previous
    # `boxes = []` pre-seed was dead code and has been removed.
    return yolo_box_constructor(meta, net_out, threshold)
|
|
52 |
|
|
|
53 |
|
|
|
54 |
def preprocess(self, im, allobj=None):
    """
    Take an image (ndarray or file path) and return it as a numpy tensor
    ready to be fed into tfnet.  When an accompanying parsed annotation list
    `allobj` is given (i.e. this call serves the training process), the image
    is augmented with a random affine transform plus recolouring, and each
    entry of `allobj` is adjusted in place to stay consistent.
    """
    if type(im) is not np.ndarray:
        im = cv2.imread(im)

    if allobj is not None:  # training mode: augment image + annotations
        im, dims, (scale, offs, flip) = imcv2_affine_trans(im)
        for obj in allobj:
            _fix(obj, dims, scale, offs)
            if flip:
                # mirror the box horizontally to match the flipped image
                obj[1], obj[3] = dims[0] - obj[3], dims[0] - obj[1]
        im = imcv2_recolor(im)

    im = self.resize_input(im)
    if allobj is None:
        return im
    return im  # , np.array(im) # for unit testing
|
|
81 |
|
|
|
82 |
|
|
|
83 |
def postprocess(self, net_out, im, save=True):
    """
    Decode net output for one image, draw the predictions (or collect them
    as JSON when FLAGS.json is set), and optionally save to <imgdir>/out.
    """
    meta, FLAGS = self.meta, self.FLAGS
    threshold = FLAGS.threshold
    colors, labels = meta['colors'], meta['labels']

    boxes = self.findboxes(net_out)

    # accept either a file path or an already-loaded image
    imgcv = im if type(im) is np.ndarray else cv2.imread(im)
    h, w, _ = imgcv.shape

    resultsForJSON = []
    for box in boxes:
        processed = self.process_box(box, h, w, threshold)
        if processed is None:
            continue
        left, right, top, bot, mess, max_indx, confidence = processed
        thick = int((h + w) // 300)
        if self.FLAGS.json:
            resultsForJSON.append({
                "label": mess,
                "confidence": float('%.2f' % confidence),
                "topleft": {"x": left, "y": top},
                "bottomright": {"x": right, "y": bot},
            })
            continue
        color = self.meta['colors'][max_indx]
        cv2.rectangle(imgcv, (left, top), (right, bot), color, thick)
        cv2.putText(imgcv, mess, (left, top - 12),
                    0, 1e-3 * h, color, thick // 3)

    if not save:
        return imgcv

    outfolder = os.path.join(self.FLAGS.imgdir, 'out')
    img_name = os.path.join(outfolder, os.path.basename(im))
    if self.FLAGS.json:
        textFile = os.path.splitext(img_name)[0] + ".json"
        with open(textFile, 'w') as f:
            f.write(json.dumps(resultsForJSON))
        return
    cv2.imwrite(img_name, imgcv)