Diff of /darkflow/net/flow.py [000000] .. [d34869]

import math
import os
import pickle
import time
from multiprocessing.pool import ThreadPool

import numpy as np
import tensorflow as tf

train_stats = (
    'Training statistics: \n'
    '\tLearning rate : {}\n'
    '\tBatch size    : {}\n'
    '\tEpoch number  : {}\n'
    '\tBackup every  : {}'
)
pool = ThreadPool()


def _save_ckpt(self, step, loss_profile):
    """Pickle the loss profile and write a TensorFlow checkpoint at `step`."""
    file = '{}-{}{}'
    model = self.meta['name']

    # Dump the accumulated (loss, moving-average loss) history
    profile = file.format(model, step, '.profile')
    profile = os.path.join(self.FLAGS.backup, profile)
    with open(profile, 'wb') as profile_ckpt:
        pickle.dump(loss_profile, profile_ckpt)

    # Save the model variables themselves
    ckpt = file.format(model, step, '')
    ckpt = os.path.join(self.FLAGS.backup, ckpt)
    self.say('Checkpoint at step {}'.format(step))
    self.saver.save(self.sess, ckpt)

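# Sketch: the '.profile' file written above can be unpickled for inspection;
# `backup_dir`, `model`, and `step` below are illustrative placeholders.
#
#     path = os.path.join(backup_dir, '{}-{}.profile'.format(model, step))
#     with open(path, 'rb') as f:
#         history = pickle.load(f)  # list of (loss, moving_avg_loss) pairs
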
def train(self):
    """Run the training loop, logging loss and checkpointing periodically."""
    loss_ph = self.framework.placeholders
    loss_mva = None
    profile = list()
    ckpt = 0  # guards the final save when the batch stream is empty

    batches = self.framework.shuffle()
    loss_op = self.framework.loss

    for i, (x_batch, datum) in enumerate(batches):
        if not i: self.say(train_stats.format(
            self.FLAGS.lr, self.FLAGS.batch,
            self.FLAGS.epoch, self.FLAGS.save
        ))

        # Bind ground-truth tensors, the image batch, and any extra feeds
        feed_dict = {
            loss_ph[key]: datum[key]
            for key in loss_ph}
        feed_dict[self.inp] = x_batch
        feed_dict.update(self.feed)

        fetches = [self.train_op, loss_op]

        if self.FLAGS.summary:
            fetches.append(self.summary_op)

        fetched = self.sess.run(fetches, feed_dict)
        loss = fetched[1]

        # Exponential moving average of the loss
        if loss_mva is None: loss_mva = loss
        loss_mva = .9 * loss_mva + .1 * loss
        step_now = self.FLAGS.load + i + 1

        if self.FLAGS.summary:
            self.writer.add_summary(fetched[2], step_now)

        form = 'step {} - loss {} - moving ave loss {}'
        self.say(form.format(step_now, loss, loss_mva))
        profile += [(loss, loss_mva)]

        # Checkpoint every FLAGS.save examples,
        # i.e. every FLAGS.save // FLAGS.batch steps
        ckpt = (i + 1) % (self.FLAGS.save // self.FLAGS.batch)
        args = [step_now, profile]
        if not ckpt: _save_ckpt(self, *args)

    if ckpt: _save_ckpt(self, *args)

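# Training is normally driven from the command line; an illustrative
# invocation in the style of the darkflow README (paths are placeholders):
#
#     flow --model cfg/yolo-new.cfg --train \
#          --dataset "path/to/images" --annotation "path/to/annotations"
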
def return_predict(self, im):
    """Run a forward pass on one image array and return detection dicts."""
    assert isinstance(im, np.ndarray), \
        'Image is not a np.ndarray'
    h, w, _ = im.shape
    im = self.framework.resize_input(im)
    this_inp = np.expand_dims(im, 0)
    feed_dict = {self.inp: this_inp}

    out = self.sess.run(self.out, feed_dict)[0]
    boxes = self.framework.findboxes(out)
    threshold = self.FLAGS.threshold
    boxesInfo = list()
    for box in boxes:
        # process_box returns (left, right, top, bot, label, class_idx,
        # confidence), or None when the box falls below `threshold`
        tmpBox = self.framework.process_box(box, h, w, threshold)
        if tmpBox is None:
            continue
        boxesInfo.append({
            "label": tmpBox[4],
            "confidence": tmpBox[6],
            "topleft": {
                "x": tmpBox[0],
                "y": tmpBox[2]},
            "bottomright": {
                "x": tmpBox[1],
                "y": tmpBox[3]}
        })
    return boxesInfo

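# Usage sketch for return_predict, following the darkflow README (option
# values and the sample image path are illustrative):
#
#     import cv2
#     from darkflow.net.build import TFNet
#
#     options = {"model": "cfg/yolo.cfg", "load": "bin/yolo.weights",
#                "threshold": 0.1}
#     tfnet = TFNet(options)
#     imgcv = cv2.imread("./sample_img/sample_dog.jpg")
#     result = tfnet.return_predict(imgcv)  # list of dicts as built above
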
def predict(self):
    """Run detection over every valid image in FLAGS.imgdir, in batches."""
    inp_path = self.FLAGS.imgdir
    all_inps = os.listdir(inp_path)
    all_inps = [i for i in all_inps if self.framework.is_inp(i)]
    if not all_inps:
        msg = 'Failed to find any images in {}.'
        exit('Error: {}'.format(msg.format(inp_path)))

    batch = min(self.FLAGS.batch, len(all_inps))

    # predict in batches
    n_batch = int(math.ceil(len(all_inps) / batch))
    for j in range(n_batch):
        from_idx = j * batch
        to_idx = min(from_idx + batch, len(all_inps))

        # collect the images in this batch: preprocess each in a pool
        # thread and give it a leading batch dimension
        this_batch = all_inps[from_idx:to_idx]
        inp_feed = pool.map(lambda inp: (
            np.expand_dims(self.framework.preprocess(
                os.path.join(inp_path, inp)), 0)), this_batch)

        # Feed to the net
        feed_dict = {self.inp: np.concatenate(inp_feed, 0)}
        self.say('Forwarding {} inputs ...'.format(len(inp_feed)))
        start = time.time()
        out = self.sess.run(self.out, feed_dict)
        stop = time.time()
        last = stop - start
        self.say('Total time = {}s / {} inps = {} ips'.format(
            last, len(inp_feed), len(inp_feed) / last))

        # Post processing: let the framework write out each prediction
        self.say('Post processing {} inputs ...'.format(len(inp_feed)))
        start = time.time()
        pool.map(lambda p: (lambda i, prediction:
                            self.framework.postprocess(
                                prediction, os.path.join(inp_path, this_batch[i])))(*p),
                 enumerate(out))
        stop = time.time()
        last = stop - start

        # Timing
        self.say('Total time = {}s / {} inps = {} ips'.format(
            last, len(inp_feed), len(inp_feed) / last))
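
# predict() backs the batch command-line path; an illustrative invocation in
# the style of the darkflow README (paths are placeholders):
#
#     flow --imgdir sample_img/ --model cfg/yolo.cfg --load bin/yolo.weights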