a b/darkflow/net/yolo/predict.py
1
from ...utils.im_transform import imcv2_recolor, imcv2_affine_trans
2
from ...utils.box import BoundBox, box_iou, prob_compare
3
import numpy as np
4
import cv2
5
import os
6
import json
7
from ...cython_utils.cy_yolo_findboxes import yolo_box_constructor
8
9
10
def _fix(obj, dims, scale, offs):
11
    for i in range(1, 5):
12
        dim = dims[(i + 1) % 2]
13
        off = offs[(i + 1) % 2]
14
        obj[i] = int(obj[i] * scale - off)
15
        obj[i] = max(min(obj[i], dim), 0)
16
17
18
def resize_input(self, im):
    """
    Resize an image to the network input size and scale its values.

    The target height and width are read from self.meta['inp_size'],
    which is unpacked as (height, width, channels). The resized image is
    divided by 255 and returned with its last axis reversed.
    """
    height, width, _ = self.meta['inp_size']
    scaled = cv2.resize(im, (width, height)) / 255.
    # Reverse the channel axis; presumably BGR (cv2 default) -> RGB.
    return scaled[:, :, ::-1]
24
25
26
def process_box(self, b, h, w, threshold):
    """
    Turn one predicted box into pixel coordinates plus its best label.

    The class with the highest entry in b.probs is selected. If that
    probability is not greater than `threshold`, None is returned.
    Otherwise b.x / b.y / b.w / b.h (presumably a normalized centre and
    size) are scaled by the image width `w` and height `h`; left and top
    are floored at 0, right and bot are capped at w - 1 and h - 1.

    Returns a tuple
    (left, right, top, bot, label_text, class_index, probability).
    """
    best = np.argmax(b.probs)
    score = b.probs[best]
    name = self.meta['labels'][best]

    # Written as "not >" so a NaN score is rejected, exactly like the
    # positive test `score > threshold` failing.
    if not score > threshold:
        return None

    half_w, half_h = b.w / 2., b.h / 2.
    left = max(int((b.x - half_w) * w), 0)
    right = min(int((b.x + half_w) * w), w - 1)
    top = max(int((b.y - half_h) * h), 0)
    bot = min(int((b.y + half_h) * h), h - 1)
    return (left, right, top, bot, '{}'.format(name), best, score)
42
43
44
def findboxes(self, net_out):
    """
    Build the list of boxes for one network output.

    Hands self.meta, `net_out` and the threshold from self.FLAGS to
    yolo_box_constructor and returns whatever it produces, unchanged.
    """
    return yolo_box_constructor(self.meta, net_out, self.FLAGS.threshold)
52
53
54
def preprocess(self, im, allobj=None):
    """
    Turn an image (or a path to one) into the array that is fed to the
    net.

    `im` is loaded with cv2.imread unless it is already a numpy array.
    When `allobj` (the parsed annotation) is supplied, the call is
    serving training: the image goes through imcv2_affine_trans and
    imcv2_recolor for augmentation, and every entry of `allobj` is
    updated in place to follow the same scale, offset and optional
    horizontal flip. In both modes the image is finally passed through
    self.resize_input and that result is returned.
    """
    if type(im) is not np.ndarray:
        im = cv2.imread(im)

    training = allobj is not None
    if training:
        im, dims, (scale, offs, flip) = imcv2_affine_trans(im)
        for obj in allobj:
            _fix(obj, dims, scale, offs)
            if flip:
                # Mirror slots 1 and 3 around dims[0]; the two swap roles.
                obj[1], obj[3] = dims[0] - obj[3], dims[0] - obj[1]
        im = imcv2_recolor(im)

    return self.resize_input(im)
81
82
83
def postprocess(self, net_out, im, save=True):
    """
    Takes net output, draws predictions and optionally saves them to disk.

    Arguments:
        net_out: network output, forwarded to self.findboxes.
        im: a numpy image, or a path that is read with cv2.imread. When
            `save` is true it must be a path, because its basename is
            used to name the output file.
        save: when false, nothing is written and the image is returned.

    Behaviour:
        Every box accepted by self.process_box is either appended to a
        JSON result list (when self.FLAGS.json is set) or drawn onto the
        image as a rectangle plus its label text.

        With `save` true, output goes to <FLAGS.imgdir>/out, which is
        created if it does not exist yet: a .json file named after the
        image in JSON mode, otherwise the annotated image itself.

    Returns:
        The image array when `save` is false (unannotated in JSON mode),
        otherwise None.
    """
    flags = self.FLAGS
    threshold = flags.threshold

    boxes = self.findboxes(net_out)

    if type(im) is not np.ndarray:
        imgcv = cv2.imread(im)
    else:
        imgcv = im

    h, w, _ = imgcv.shape
    # Line thickness depends only on the image size, so compute it once
    # instead of once per box.
    thick = int((h + w) // 300)

    resultsForJSON = []
    for b in boxes:
        boxResults = self.process_box(b, h, w, threshold)
        if boxResults is None:
            continue
        left, right, top, bot, mess, max_indx, confidence = boxResults

        if flags.json:
            resultsForJSON.append({
                "label": mess,
                "confidence": float('%.2f' % confidence),
                "topleft": {"x": left, "y": top},
                "bottomright": {"x": right, "y": bot},
            })
            continue

        color = self.meta['colors'][max_indx]
        cv2.rectangle(imgcv, (left, top), (right, bot), color, thick)
        cv2.putText(
            imgcv, mess, (left, top - 12),
            0, 1e-3 * h, color, thick // 3)

    if not save:
        return imgcv

    outfolder = os.path.join(flags.imgdir, 'out')
    # Resolve the file name first so an invalid `im` fails before any
    # directory is created.
    img_name = os.path.join(outfolder, os.path.basename(im))
    # Make sure the destination exists; otherwise open() raises and
    # cv2.imwrite fails to write the file.
    if not os.path.isdir(outfolder):
        os.makedirs(outfolder)

    if flags.json:
        textJSON = json.dumps(resultsForJSON)
        textFile = os.path.splitext(img_name)[0] + ".json"
        with open(textFile, 'w') as f:
            f.write(textJSON)
        return

    cv2.imwrite(img_name, imgcv)