Diff of /darkflow/net/help.py [000000] .. [d34869]

Switch to unified view

a b/darkflow/net/help.py
1
"""
2
tfnet secondary (helper) methods
3
"""
4
from ..utils.loader import create_loader
5
from time import time as timer
6
import tensorflow as tf
7
import numpy as np
8
import sys
9
import cv2
10
import os
11
12
old_graph_msg = 'Resolving old graph def {} (no guarantee)'
13
14
15
def build_train_op(self):
16
    self.framework.loss(self.out)
17
    self.say('Building {} train op'.format(self.meta['model']))
18
    optimizer = self._TRAINER[self.FLAGS.trainer](self.FLAGS.lr)
19
    gradients = optimizer.compute_gradients(self.framework.loss)
20
    self.train_op = optimizer.apply_gradients(gradients)
21
22
23
def load_from_ckpt(self):
24
    if self.FLAGS.load < 0:  # load lastest ckpt
25
        with open(os.path.join(self.FLAGS.backup, 'checkpoint'), 'r') as f:
26
            last = f.readlines()[-1].strip()
27
            load_point = last.split(' ')[1]
28
            load_point = load_point.split('"')[1]
29
            load_point = load_point.split('-')[-1]
30
            self.FLAGS.load = int(load_point)
31
32
    load_point = os.path.join(self.FLAGS.backup, self.meta['name'])
33
    load_point = '{}-{}'.format(load_point, self.FLAGS.load)
34
    self.say('Loading from {}'.format(load_point))
35
    try:
36
        self.saver.restore(self.sess, load_point)
37
    except:
38
        load_old_graph(self, load_point)
39
40
41
def say(self, *msgs):
42
    if not self.FLAGS.verbalise:
43
        return
44
    msgs = list(msgs)
45
    for msg in msgs:
46
        if msg is None: continue
47
        print(msg)
48
49
50
def load_old_graph(self, ckpt):
51
    ckpt_loader = create_loader(ckpt)
52
    self.say(old_graph_msg.format(ckpt))
53
54
    for var in tf.global_variables():
55
        name = var.name.split(':')[0]
56
        args = [name, var.get_shape()]
57
        val = ckpt_loader(args)
58
        assert val is not None, \
59
            'Cannot find and load {}'.format(var.name)
60
        shp = val.shape
61
        plh = tf.placeholder(tf.float32, shp)
62
        op = tf.assign(var, plh)
63
        self.sess.run(op, {plh: val})
64
65
66
def _get_fps(self, frame):
67
    elapsed = int()
68
    start = timer()
69
    preprocessed = self.framework.preprocess(frame)
70
    feed_dict = {self.inp: [preprocessed]}
71
    net_out = self.sess.run(self.out, feed_dict)[0]
72
    processed = self.framework.postprocess(net_out, frame, False)
73
    return timer() - start
74
75
76
def camera(self):
77
    file = self.FLAGS.demo
78
    SaveVideo = self.FLAGS.saveVideo
79
80
    if file == 'camera':
81
        file = 0
82
    else:
83
        assert os.path.isfile(file), \
84
            'file {} does not exist'.format(file)
85
86
    camera = cv2.VideoCapture(file)
87
88
    if file == 0:
89
        self.say('Press [ESC] to quit demo')
90
91
    assert camera.isOpened(), \
92
        'Cannot capture source'
93
94
    if file == 0:  # camera window
95
        cv2.namedWindow('', 0)
96
        _, frame = camera.read()
97
        height, width, _ = frame.shape
98
        cv2.resizeWindow('', width, height)
99
    else:
100
        _, frame = camera.read()
101
        height, width, _ = frame.shape
102
103
    if SaveVideo:
104
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
105
        if file == 0:  # camera window
106
            fps = 1 / self._get_fps(frame)
107
            if fps < 1:
108
                fps = 1
109
        else:
110
            fps = round(camera.get(cv2.CAP_PROP_FPS))
111
        videoWriter = cv2.VideoWriter(
112
            'video.avi', fourcc, fps, (width, height))
113
114
    # buffers for demo in batch
115
    buffer_inp = list()
116
    buffer_pre = list()
117
118
    elapsed = int()
119
    start = timer()
120
    self.say('Press [ESC] to quit demo')
121
    # Loop through frames
122
    while camera.isOpened():
123
        elapsed += 1
124
        _, frame = camera.read()
125
        if frame is None:
126
            print('\nEnd of Video')
127
            break
128
        preprocessed = self.framework.preprocess(frame)
129
        buffer_inp.append(frame)
130
        buffer_pre.append(preprocessed)
131
132
        # Only process and imshow when queue is full
133
        if elapsed % self.FLAGS.queue == 0:
134
            feed_dict = {self.inp: buffer_pre}
135
            net_out = self.sess.run(self.out, feed_dict)
136
            for img, single_out in zip(buffer_inp, net_out):
137
                postprocessed = self.framework.postprocess(
138
                    single_out, img, False)
139
                if SaveVideo:
140
                    videoWriter.write(postprocessed)
141
                if file == 0:  # camera window
142
                    cv2.imshow('', postprocessed)
143
            # Clear Buffers
144
            buffer_inp = list()
145
            buffer_pre = list()
146
147
        if elapsed % 5 == 0:
148
            sys.stdout.write('\r')
149
            sys.stdout.write('{0:3.3f} FPS'.format(
150
                elapsed / (timer() - start)))
151
            sys.stdout.flush()
152
        if file == 0:  # camera window
153
            choice = cv2.waitKey(1)
154
            if choice == 27: break
155
156
    sys.stdout.write('\n')
157
    if SaveVideo:
158
        videoWriter.release()
159
    camera.release()
160
    if file == 0:  # camera window
161
        cv2.destroyAllWindows()
162
163
164
def to_darknet(self):
165
    darknet_ckpt = self.darknet
166
167
    with self.graph.as_default() as g:
168
        for var in tf.global_variables():
169
            name = var.name.split(':')[0]
170
            var_name = name.split('-')
171
            l_idx = int(var_name[0])
172
            w_sig = var_name[1].split('/')[-1]
173
            l = darknet_ckpt.layers[l_idx]
174
            l.w[w_sig] = var.eval(self.sess)
175
176
    for layer in darknet_ckpt.layers:
177
        for ph in layer.h:
178
            layer.h[ph] = None
179
180
    return darknet_ckpt