Diff of /detect.py [000000] .. [d34869]

Switch to unified view

a b/detect.py
1
import cv2
2
import time
3
from utils import iou
4
from scipy import spatial
5
from darkflow.net.build import TFNet
6
7
options = {'model': 'cfg/tiny-yolo-voc-3c.cfg',
8
           'load': 3750,
9
           'threshold': 0.1,
10
           'gpu': 0.7}
11
12
tfnet = TFNet(options)
13
14
pred_bb = []  # predicted bounding box
15
pred_cls = []  # predicted class
16
pred_conf = []  # predicted class confidence
17
18
19
def blood_cell_count(file_name):
20
    rbc = 0
21
    wbc = 0
22
    platelets = 0
23
24
    cell = []
25
    cls = []
26
    conf = []
27
28
    record = []
29
    tl_ = []
30
    br_ = []
31
    iou_ = []
32
    iou_value = 0
33
34
    tic = time.time()
35
    image = cv2.imread('data/' + file_name)
36
    output = tfnet.return_predict(image)
37
38
    for prediction in output:
39
        label = prediction['label']
40
        confidence = prediction['confidence']
41
        tl = (prediction['topleft']['x'], prediction['topleft']['y'])
42
        br = (prediction['bottomright']['x'], prediction['bottomright']['y'])
43
44
        if label == 'RBC' and confidence < .5:
45
            continue
46
        if label == 'WBC' and confidence < .25:
47
            continue
48
        if label == 'Platelets' and confidence < .25:
49
            continue
50
51
        # clearing up overlapped same platelets
52
        if label == 'Platelets':
53
            if record:
54
                tree = spatial.cKDTree(record)
55
                index = tree.query(tl)[1]
56
                iou_value = iou(tl + br, tl_[index] + br_[index])
57
                iou_.append(iou_value)
58
59
            if iou_value > 0.1:
60
                continue
61
62
            record.append(tl)
63
            tl_.append(tl)
64
            br_.append(br)
65
66
        center_x = int((tl[0] + br[0]) / 2)
67
        center_y = int((tl[1] + br[1]) / 2)
68
        center = (center_x, center_y)
69
70
        if label == 'RBC':
71
            color = (255, 0, 0)
72
            rbc = rbc + 1
73
        if label == 'WBC':
74
            color = (0, 255, 0)
75
            wbc = wbc + 1
76
        if label == 'Platelets':
77
            color = (0, 0, 255)
78
            platelets = platelets + 1
79
80
        radius = int((br[0] - tl[0]) / 2)
81
        image = cv2.circle(image, center, radius, color, 2)
82
        font = cv2.FONT_HERSHEY_COMPLEX
83
        image = cv2.putText(image, label, (center_x - 15, center_y + 5), font, .5, color, 1)
84
        cell.append([tl[0], tl[1], br[0], br[1]])
85
86
        if label == 'RBC':
87
            cls.append(0)
88
        if label == 'WBC':
89
            cls.append(1)
90
        if label == 'Platelets':
91
            cls.append(2)
92
93
        conf.append(confidence)
94
95
    toc = time.time()
96
    pred_bb.append(cell)
97
    pred_cls.append(cls)
98
    pred_conf.append(conf)
99
    avg_time = (toc - tic) * 1000
100
    print('{0:.5}'.format(avg_time), 'ms')
101
102
    cv2.imwrite('output/' + file_name, image)
103
    cv2.imshow('Total RBC: ' + str(rbc) + ', WBC: ' + str(wbc) + ', Platelets: ' + str(platelets), image)
104
    print('Press "ESC" to close . . .')
105
    if cv2.waitKey(0) & 0xff == 27:
106
        cv2.destroyAllWindows()
107
108
109
image_name = 'image_001.jpg'
110
blood_cell_count(image_name)
111
112
print('All Done!')