--- a +++ b/darkflow/net/flow.py @@ -0,0 +1,155 @@ +import os +import time +import numpy as np +import tensorflow as tf +import pickle +from multiprocessing.pool import ThreadPool + +train_stats = ( + 'Training statistics: \n' + '\tLearning rate : {}\n' + '\tBatch size : {}\n' + '\tEpoch number : {}\n' + '\tBackup every : {}' +) +pool = ThreadPool() + + +def _save_ckpt(self, step, loss_profile): + file = '{}-{}{}' + model = self.meta['name'] + + profile = file.format(model, step, '.profile') + profile = os.path.join(self.FLAGS.backup, profile) + with open(profile, 'wb') as profile_ckpt: + pickle.dump(loss_profile, profile_ckpt) + + ckpt = file.format(model, step, '') + ckpt = os.path.join(self.FLAGS.backup, ckpt) + self.say('Checkpoint at step {}'.format(step)) + self.saver.save(self.sess, ckpt) + + +def train(self): + loss_ph = self.framework.placeholders + loss_mva = None; + profile = list() + + batches = self.framework.shuffle() + loss_op = self.framework.loss + + for i, (x_batch, datum) in enumerate(batches): + if not i: self.say(train_stats.format( + self.FLAGS.lr, self.FLAGS.batch, + self.FLAGS.epoch, self.FLAGS.save + )) + + feed_dict = { + loss_ph[key]: datum[key] + for key in loss_ph} + feed_dict[self.inp] = x_batch + feed_dict.update(self.feed) + + fetches = [self.train_op, loss_op] + + if self.FLAGS.summary: + fetches.append(self.summary_op) + + fetched = self.sess.run(fetches, feed_dict) + loss = fetched[1] + + if loss_mva is None: loss_mva = loss + loss_mva = .9 * loss_mva + .1 * loss + step_now = self.FLAGS.load + i + 1 + + if self.FLAGS.summary: + self.writer.add_summary(fetched[2], step_now) + + form = 'step {} - loss {} - moving ave loss {}' + self.say(form.format(step_now, loss, loss_mva)) + profile += [(loss, loss_mva)] + + ckpt = (i + 1) % (self.FLAGS.save // self.FLAGS.batch) + args = [step_now, profile] + if not ckpt: _save_ckpt(self, *args) + + if ckpt: _save_ckpt(self, *args) + + +def return_predict(self, im): + assert isinstance(im, np.ndarray), \ + 'Image is not a np.ndarray' + h, w, _ = im.shape + im = self.framework.resize_input(im) + this_inp = np.expand_dims(im, 0) + feed_dict = {self.inp: this_inp} + + out = self.sess.run(self.out, feed_dict)[0] + boxes = self.framework.findboxes(out) + threshold = self.FLAGS.threshold + boxesInfo = list() + for box in boxes: + tmpBox = self.framework.process_box(box, h, w, threshold) + if tmpBox is None: + continue + boxesInfo.append({ + "label": tmpBox[4], + "confidence": tmpBox[6], + "topleft": { + "x": tmpBox[0], + "y": tmpBox[2]}, + "bottomright": { + "x": tmpBox[1], + "y": tmpBox[3]} + }) + return boxesInfo + + +import math + + +def predict(self): + inp_path = self.FLAGS.imgdir + all_inps = os.listdir(inp_path) + all_inps = [i for i in all_inps if self.framework.is_inp(i)] + if not all_inps: + msg = 'Failed to find any images in {} .' + exit('Error: {}'.format(msg.format(inp_path))) + + batch = min(self.FLAGS.batch, len(all_inps)) + + # predict in batches + n_batch = int(math.ceil(len(all_inps) / batch)) + for j in range(n_batch): + from_idx = j * batch + to_idx = min(from_idx + batch, len(all_inps)) + + # collect images input in the batch + this_batch = all_inps[from_idx:to_idx] + inp_feed = pool.map(lambda inp: ( + np.expand_dims(self.framework.preprocess( + os.path.join(inp_path, inp)), 0)), this_batch) + + # Feed to the net + feed_dict = {self.inp: np.concatenate(inp_feed, 0)} + self.say('Forwarding {} inputs ...'.format(len(inp_feed))) + start = time.time() + out = self.sess.run(self.out, feed_dict) + stop = time.time(); + last = stop - start + self.say('Total time = {}s / {} inps = {} ips'.format( + last, len(inp_feed), len(inp_feed) / last)) + + # Post processing + self.say('Post processing {} inputs ...'.format(len(inp_feed))) + start = time.time() + pool.map(lambda p: (lambda i, prediction: + self.framework.postprocess( + prediction, os.path.join(inp_path, this_batch[i])))(*p), + enumerate(out)) + stop = time.time(); + last = stop - start + + # Timing + self.say('Total time = {}s / {} inps = {} ips'.format( + last, len(inp_feed), len(inp_feed) / last))