Diff of /test.py [000000] .. [d34869]

Switch to unified view

a b/test.py
1
import os
2
import cv2
3
import time
4
from utils import iou
5
from scipy import spatial
6
from darkflow.net.build import TFNet
7
8
options = {'model': 'cfg/tiny-yolo-voc-3c.cfg',
9
           'load': 3750,
10
           'threshold': 0.1,
11
           'gpu': 0.7}
12
13
tfnet = TFNet(options)
14
15
avg_time = 0
16
pred_bb = []  # predicted bounding box
17
pred_cls = []  # predicted class
18
pred_conf = []  # predicted class confidence
19
20
directory = 'dataset/Testing/Images/'
21
22
for file_name in os.listdir(directory):
23
    tic = time.time()
24
    image = cv2.imread(directory + file_name)
25
    output = tfnet.return_predict(image)
26
27
    rbc = 0
28
    wbc = 0
29
    platelets = 0
30
31
    cell = []
32
    cls = []
33
    conf = []
34
35
    record = []
36
    tl_ = []
37
    br_ = []
38
    iou_ = []
39
    iou_value = 0
40
41
    for prediction in output:
42
        label = prediction['label']
43
        confidence = prediction['confidence']
44
45
        tl = (prediction['topleft']['x'], prediction['topleft']['y'])
46
        br = (prediction['bottomright']['x'], prediction['bottomright']['y'])
47
48
        if label == 'RBC' and confidence < .5:
49
            continue
50
        if label == 'WBC' and confidence < .25:
51
            continue
52
        if label == 'Platelets' and confidence < .25:
53
            continue
54
55
        # clearing up spurious platelets
56
        if label == 'Platelets':
57
            if record:
58
                tree = spatial.cKDTree(record)
59
                index = tree.query(tl)[1]
60
                iou_value = iou(tl + br, tl_[index] + br_[index])
61
                iou_.append(iou_value)
62
63
            if iou_value > 0.1:
64
                continue
65
66
            record.append(tl)
67
            tl_.append(tl)
68
            br_.append(br)
69
70
        # image = cv2.rectangle(image, tl, br,color, 2)
71
        center_x = int((tl[0] + br[0]) / 2)
72
        center_y = int((tl[1] + br[1]) / 2)
73
        center = (center_x, center_y)
74
75
        if label == 'RBC':
76
            color = (255, 0, 0)
77
            rbc = rbc + 1
78
        if label == 'WBC':
79
            color = (0, 255, 0)
80
            wbc = wbc + 1
81
        if label == 'Platelets':
82
            color = (0, 0, 255)
83
            platelets = platelets + 1
84
85
        radius = int((br[0] - tl[0]) / 2)
86
        image = cv2.circle(image, center, radius, color, 2)
87
        font = cv2.FONT_HERSHEY_COMPLEX
88
        image = cv2.putText(image, label, (center_x - 15, center_y + 5), font, .5, color, 1)
89
        cell.append([tl[0], tl[1], br[0], br[1]])
90
91
        if label == 'RBC':
92
            cls.append(0)
93
        if label == 'WBC':
94
            cls.append(1)
95
        if label == 'Platelets':
96
            cls.append(2)
97
98
        conf.append(confidence)
99
100
    pred_bb.append(cell)
101
    pred_cls.append(cls)
102
    pred_conf.append(conf)
103
    cv2.imwrite('output/' + file_name, image)
104
    toc = time.time()
105
    avg_time = avg_time + (toc - tic) * 1000
106
107
avg_time = avg_time / 60
108
109
print('Mean time: {0:.5}'.format(avg_time), 'ms')
110
print('All Done!')