Diff of /darkflow/dark/darknet.py [000000] .. [d34869]

Switch to side-by-side view

--- a
+++ b/darkflow/dark/darknet.py
@@ -0,0 +1,90 @@
+from ..utils.process import cfg_yielder
+from .darkop import create_darkop
+from ..utils import loader
+import warnings
+import time
+import os
+
+
+class Darknet(object):
+    _EXT = '.weights'
+
+    def __init__(self, FLAGS):
+        self.get_weight_src(FLAGS)
+        self.modify = False
+
+        print('Parsing {}'.format(self.src_cfg))
+        src_parsed = self.parse_cfg(self.src_cfg, FLAGS)
+        self.src_meta, self.src_layers = src_parsed
+
+        if self.src_cfg == FLAGS.model:
+            self.meta, self.layers = src_parsed
+        else:
+            print('Parsing {}'.format(FLAGS.model))
+            des_parsed = self.parse_cfg(FLAGS.model, FLAGS)
+            self.meta, self.layers = des_parsed
+
+        self.load_weights()
+
+    def get_weight_src(self, FLAGS):
+        """
+        analyse FLAGS.load to know where is the 
+        source binary and what is its config.
+        can be: None, FLAGS.model, or some other
+        """
+        self.src_bin = FLAGS.model + self._EXT
+        self.src_bin = FLAGS.binary + self.src_bin
+        self.src_bin = os.path.abspath(self.src_bin)
+        exist = os.path.isfile(self.src_bin)
+
+        if FLAGS.load == str(): FLAGS.load = int()
+        if type(FLAGS.load) is int:
+            self.src_cfg = FLAGS.model
+            if FLAGS.load:
+                self.src_bin = None
+            elif not exist:
+                self.src_bin = None
+        else:
+            assert os.path.isfile(FLAGS.load), \
+                '{} not found'.format(FLAGS.load)
+            self.src_bin = FLAGS.load
+            name = loader.model_name(FLAGS.load)
+            cfg_path = os.path.join(FLAGS.config, name + '.cfg')
+            if not os.path.isfile(cfg_path):
+                warnings.warn(
+                    '{} not found, use {} instead'.format(
+                        cfg_path, FLAGS.model))
+                cfg_path = FLAGS.model
+            self.src_cfg = cfg_path
+            FLAGS.load = int()
+
+    def parse_cfg(self, model, FLAGS):
+        """
+        return a list of `layers` objects (darkop.py)
+        given path to binaries/ and configs/
+        """
+        args = [model, FLAGS.binary]
+        cfg_layers = cfg_yielder(*args)
+        meta = dict();
+        layers = list()
+        for i, info in enumerate(cfg_layers):
+            if i == 0:
+                meta = info; continue
+            else:
+                new = create_darkop(*info)
+            layers.append(new)
+        return meta, layers
+
+    def load_weights(self):
+        """
+        Use `layers` and Loader to load .weights file
+        """
+        print('Loading {} ...'.format(self.src_bin))
+        start = time.time()
+
+        args = [self.src_bin, self.src_layers]
+        wgts_loader = loader.create_loader(*args)
+        for layer in self.layers: layer.load(wgts_loader)
+
+        stop = time.time()
+        print('Finished in {}s'.format(stop - start))