Diff of /darkflow/net/build.py [000000] .. [d34869]

Switch to unified view

a b/darkflow/net/build.py
1
import tensorflow as tf
2
import time
3
from . import help
4
from . import flow
5
from .ops import op_create, identity
6
from .ops import HEADER, LINE
7
from .framework import create_framework
8
from ..dark.darknet import Darknet
9
import json
10
import os
11
12
tf = tf.compat.v1
13
tf.disable_v2_behavior()
14
15
16
class TFNet(object):
17
    _TRAINER = dict({
18
        'rmsprop': tf.train.RMSPropOptimizer,
19
        'adadelta': tf.train.AdadeltaOptimizer,
20
        'adagrad': tf.train.AdagradOptimizer,
21
        'adagradDA': tf.train.AdagradDAOptimizer,
22
        'momentum': tf.train.MomentumOptimizer,
23
        'adam': tf.train.AdamOptimizer,
24
        'ftrl': tf.train.FtrlOptimizer,
25
        'sgd': tf.train.GradientDescentOptimizer
26
    })
27
28
    # imported methods
29
    _get_fps = help._get_fps
30
    say = help.say
31
    train = flow.train
32
    camera = help.camera
33
    predict = flow.predict
34
    return_predict = flow.return_predict
35
    to_darknet = help.to_darknet
36
    build_train_op = help.build_train_op
37
    load_from_ckpt = help.load_from_ckpt
38
39
    def __init__(self, FLAGS, darknet=None):
40
        self.ntrain = 0
41
42
        if isinstance(FLAGS, dict):
43
            from ..defaults import argHandler
44
            newFLAGS = argHandler()
45
            newFLAGS.setDefaults()
46
            newFLAGS.update(FLAGS)
47
            FLAGS = newFLAGS
48
49
        self.FLAGS = FLAGS
50
        if self.FLAGS.pbLoad and self.FLAGS.metaLoad:
51
            self.say('\nLoading from .pb and .meta')
52
            self.graph = tf.Graph()
53
            device_name = FLAGS.gpuName \
54
                if FLAGS.gpu > 0.0 else None
55
            with tf.device(device_name):
56
                with self.graph.as_default() as g:
57
                    self.build_from_pb()
58
            return
59
60
        if darknet is None:
61
            darknet = Darknet(FLAGS)
62
            self.ntrain = len(darknet.layers)
63
64
        self.darknet = darknet
65
        args = [darknet.meta, FLAGS]
66
        self.num_layer = len(darknet.layers)
67
        self.framework = create_framework(*args)
68
69
        self.meta = darknet.meta
70
71
        self.say('\nBuilding net ...')
72
        start = time.time()
73
        self.graph = tf.Graph()
74
        device_name = FLAGS.gpuName \
75
            if FLAGS.gpu > 0.0 else None
76
        with tf.device(device_name):
77
            with self.graph.as_default() as g:
78
                self.build_forward()
79
                self.setup_meta_ops()
80
        self.say('Finished in {}s\n'.format(
81
            time.time() - start))
82
83
    def build_from_pb(self):
84
        with tf.gfile.FastGFile(self.FLAGS.pbLoad, "rb") as f:
85
            graph_def = tf.GraphDef()
86
            graph_def.ParseFromString(f.read())
87
88
        tf.import_graph_def(
89
            graph_def,
90
            name=""
91
        )
92
        with open(self.FLAGS.metaLoad, 'r') as fp:
93
            self.meta = json.load(fp)
94
        self.framework = create_framework(self.meta, self.FLAGS)
95
96
        # Placeholders
97
        self.inp = tf.get_default_graph().get_tensor_by_name('input:0')
98
        self.feed = dict()  # other placeholders
99
        self.out = tf.get_default_graph().get_tensor_by_name('output:0')
100
101
        self.setup_meta_ops()
102
103
    def build_forward(self):
104
        verbalise = self.FLAGS.verbalise
105
106
        # Placeholders
107
        inp_size = [None] + self.meta['inp_size']
108
        self.inp = tf.placeholder(tf.float32, inp_size, 'input')
109
        self.feed = dict()  # other placeholders
110
111
        # Build the forward pass
112
        state = identity(self.inp)
113
        roof = self.num_layer - self.ntrain
114
        self.say(HEADER, LINE)
115
        for i, layer in enumerate(self.darknet.layers):
116
            scope = '{}-{}'.format(str(i), layer.type)
117
            args = [layer, state, i, roof, self.feed]
118
            state = op_create(*args)
119
            mess = state.verbalise()
120
            self.say(mess)
121
        self.say(LINE)
122
123
        self.top = state
124
        self.out = tf.identity(state.out, name='output')
125
126
    def setup_meta_ops(self):
127
        cfg = dict({
128
            'allow_soft_placement': False,
129
            'log_device_placement': False
130
        })
131
132
        utility = min(self.FLAGS.gpu, 1.)
133
        if utility > 0.0:
134
            self.say('GPU mode with {} usage'.format(utility))
135
            cfg['gpu_options'] = tf.GPUOptions(
136
                per_process_gpu_memory_fraction=utility)
137
            cfg['allow_soft_placement'] = True
138
        else:
139
            self.say('Running entirely on CPU')
140
            cfg['device_count'] = {'GPU': 0}
141
142
        if self.FLAGS.train: self.build_train_op()
143
144
        if self.FLAGS.summary:
145
            self.summary_op = tf.summary.merge_all()
146
            self.writer = tf.summary.FileWriter(self.FLAGS.summary + 'train')
147
148
        self.sess = tf.Session(config=tf.ConfigProto(**cfg))
149
        self.sess.run(tf.global_variables_initializer())
150
151
        if not self.ntrain: return
152
        self.saver = tf.train.Saver(tf.global_variables(),
153
                                    max_to_keep=self.FLAGS.keep)
154
        if self.FLAGS.load != 0: self.load_from_ckpt()
155
156
        if self.FLAGS.summary:
157
            self.writer.add_graph(self.sess.graph)
158
159
    def savepb(self):
160
        """
161
        Create a standalone const graph def that
162
        C++ can load and run.
163
        """
164
        darknet_pb = self.to_darknet()
165
        flags_pb = self.FLAGS
166
        flags_pb.verbalise = False
167
168
        flags_pb.train = False
169
        # rebuild another tfnet. all const.
170
        tfnet_pb = TFNet(flags_pb, darknet_pb)
171
        tfnet_pb.sess = tf.Session(graph=tfnet_pb.graph)
172
        # tfnet_pb.predict() # uncomment for unit testing
173
        name = 'built_graph/{}.pb'.format(self.meta['name'])
174
        os.makedirs(os.path.dirname(name), exist_ok=True)
175
        # Save dump of everything in meta
176
        with open('built_graph/{}.meta'.format(self.meta['name']), 'w') as fp:
177
            json.dump(self.meta, fp)
178
        self.say('Saving const graph def to {}'.format(name))
179
        graph_def = tfnet_pb.sess.graph_def
180
        tf.train.write_graph(graph_def, './', name, False)