--- a +++ b/darkflow/net/help.py @@ -0,0 +1,180 @@ +""" +tfnet secondary (helper) methods +""" +from ..utils.loader import create_loader +from time import time as timer +import tensorflow as tf +import numpy as np +import sys +import cv2 +import os + +old_graph_msg = 'Resolving old graph def {} (no guarantee)' + + +def build_train_op(self): + self.framework.loss(self.out) + self.say('Building {} train op'.format(self.meta['model'])) + optimizer = self._TRAINER[self.FLAGS.trainer](self.FLAGS.lr) + gradients = optimizer.compute_gradients(self.framework.loss) + self.train_op = optimizer.apply_gradients(gradients) + + +def load_from_ckpt(self): + if self.FLAGS.load < 0: # load lastest ckpt + with open(os.path.join(self.FLAGS.backup, 'checkpoint'), 'r') as f: + last = f.readlines()[-1].strip() + load_point = last.split(' ')[1] + load_point = load_point.split('"')[1] + load_point = load_point.split('-')[-1] + self.FLAGS.load = int(load_point) + + load_point = os.path.join(self.FLAGS.backup, self.meta['name']) + load_point = '{}-{}'.format(load_point, self.FLAGS.load) + self.say('Loading from {}'.format(load_point)) + try: + self.saver.restore(self.sess, load_point) + except: + load_old_graph(self, load_point) + + +def say(self, *msgs): + if not self.FLAGS.verbalise: + return + msgs = list(msgs) + for msg in msgs: + if msg is None: continue + print(msg) + + +def load_old_graph(self, ckpt): + ckpt_loader = create_loader(ckpt) + self.say(old_graph_msg.format(ckpt)) + + for var in tf.global_variables(): + name = var.name.split(':')[0] + args = [name, var.get_shape()] + val = ckpt_loader(args) + assert val is not None, \ + 'Cannot find and load {}'.format(var.name) + shp = val.shape + plh = tf.placeholder(tf.float32, shp) + op = tf.assign(var, plh) + self.sess.run(op, {plh: val}) + + +def _get_fps(self, frame): + elapsed = int() + start = timer() + preprocessed = self.framework.preprocess(frame) + feed_dict = {self.inp: [preprocessed]} + net_out = self.sess.run(self.out, feed_dict)[0] + processed = self.framework.postprocess(net_out, frame, False) + return timer() - start + + +def camera(self): + file = self.FLAGS.demo + SaveVideo = self.FLAGS.saveVideo + + if file == 'camera': + file = 0 + else: + assert os.path.isfile(file), \ + 'file {} does not exist'.format(file) + + camera = cv2.VideoCapture(file) + + if file == 0: + self.say('Press [ESC] to quit demo') + + assert camera.isOpened(), \ + 'Cannot capture source' + + if file == 0: # camera window + cv2.namedWindow('', 0) + _, frame = camera.read() + height, width, _ = frame.shape + cv2.resizeWindow('', width, height) + else: + _, frame = camera.read() + height, width, _ = frame.shape + + if SaveVideo: + fourcc = cv2.VideoWriter_fourcc(*'XVID') + if file == 0: # camera window + fps = 1 / self._get_fps(frame) + if fps < 1: + fps = 1 + else: + fps = round(camera.get(cv2.CAP_PROP_FPS)) + videoWriter = cv2.VideoWriter( + 'video.avi', fourcc, fps, (width, height)) + + # buffers for demo in batch + buffer_inp = list() + buffer_pre = list() + + elapsed = int() + start = timer() + self.say('Press [ESC] to quit demo') + # Loop through frames + while camera.isOpened(): + elapsed += 1 + _, frame = camera.read() + if frame is None: + print('\nEnd of Video') + break + preprocessed = self.framework.preprocess(frame) + buffer_inp.append(frame) + buffer_pre.append(preprocessed) + + # Only process and imshow when queue is full + if elapsed % self.FLAGS.queue == 0: + feed_dict = {self.inp: buffer_pre} + net_out = self.sess.run(self.out, feed_dict) + for img, single_out in zip(buffer_inp, net_out): + postprocessed = self.framework.postprocess( + single_out, img, False) + if SaveVideo: + videoWriter.write(postprocessed) + if file == 0: # camera window + cv2.imshow('', postprocessed) + # Clear Buffers + buffer_inp = list() + buffer_pre = list() + + if elapsed % 5 == 0: + sys.stdout.write('\r') + sys.stdout.write('{0:3.3f} FPS'.format( + elapsed / (timer() - start))) + sys.stdout.flush() + if file == 0: # camera window + choice = cv2.waitKey(1) + if choice == 27: break + + sys.stdout.write('\n') + if SaveVideo: + videoWriter.release() + camera.release() + if file == 0: # camera window + cv2.destroyAllWindows() + + +def to_darknet(self): + darknet_ckpt = self.darknet + + with self.graph.as_default() as g: + for var in tf.global_variables(): + name = var.name.split(':')[0] + var_name = name.split('-') + l_idx = int(var_name[0]) + w_sig = var_name[1].split('/')[-1] + l = darknet_ckpt.layers[l_idx] + l.w[w_sig] = var.eval(self.sess) + + for layer in darknet_ckpt.layers: + for ph in layer.h: + layer.h[ph] = None + + return darknet_ckpt