Diff of /darkflow/utils/loader.py [000000] .. [d34869]

Switch to side-by-side view

--- a
+++ b/darkflow/utils/loader.py
@@ -0,0 +1,164 @@
+import tensorflow as tf
+import os
+from .. import dark
+import numpy as np
+from os.path import basename
+
+tf = tf.compat.v1
+tf.disable_v2_behavior()
+
+
+class loader(object):
+    """
+    interface to work with both .weights and .ckpt files
+    in loading / recollecting / resolving mode
+    """
+    VAR_LAYER = ['convolutional', 'connected', 'local',
+                 'select', 'conv-select',
+                 'extract', 'conv-extract']
+
+    def __init__(self, *args):
+        self.src_key = list()
+        self.vals = list()
+        self.load(*args)
+
+    def __call__(self, key):
+        for idx in range(len(key)):
+            val = self.find(key, idx)
+            if val is not None: return val
+        return None
+
+    def find(self, key, idx):
+        up_to = min(len(self.src_key), 4)
+        for i in range(up_to):
+            key_b = self.src_key[i]
+            if key_b[idx:] == key[idx:]:
+                return self.yields(i)
+        return None
+
+    def yields(self, idx):
+        del self.src_key[idx]
+        temp = self.vals[idx]
+        del self.vals[idx]
+        return temp
+
+
+class weights_loader(loader):
+    """one who understands .weights files"""
+
+    _W_ORDER = dict({  # order of param flattened into .weights file
+        'convolutional': [
+            'biases', 'gamma', 'moving_mean', 'moving_variance', 'kernel'
+        ],
+        'connected': ['biases', 'weights'],
+        'local': ['biases', 'kernels']
+    })
+
+    def load(self, path, src_layers):
+        self.src_layers = src_layers
+        walker = weights_walker(path)
+
+        for i, layer in enumerate(src_layers):
+            if layer.type not in self.VAR_LAYER: continue
+            self.src_key.append([layer])
+
+            if walker.eof:
+                new = None
+            else:
+                args = layer.signature
+                new = dark.darknet.create_darkop(*args)
+            self.vals.append(new)
+
+            if new is None: continue
+            order = self._W_ORDER[new.type]
+            for par in order:
+                if par not in new.wshape: continue
+                val = walker.walk(new.wsize[par])
+                new.w[par] = val
+            new.finalize(walker.transpose)
+
+        if walker.path is not None:
+            assert walker.offset == walker.size, \
+                'expect {} bytes, found {}'.format(
+                    walker.offset, walker.size)
+            print('Successfully identified {} bytes'.format(
+                walker.offset))
+
+
+class checkpoint_loader(loader):
+    """
+    one who understands .ckpt files, very much
+    """
+
+    def load(self, ckpt, ignore):
+        meta = ckpt + '.meta'
+        with tf.Graph().as_default() as graph:
+            with tf.Session().as_default() as sess:
+                saver = tf.train.import_meta_graph(meta)
+                saver.restore(sess, ckpt)
+                for var in tf.global_variables():
+                    name = var.name.split(':')[0]
+                    packet = [name, var.get_shape().as_list()]
+                    self.src_key += [packet]
+                    self.vals += [var.eval(sess)]
+
+
+def create_loader(path, cfg=None):
+    if path is None:
+        load_type = weights_loader
+    elif '.weights' in path:
+        load_type = weights_loader
+    else:
+        load_type = checkpoint_loader
+
+    return load_type(path, cfg)
+
+
+class weights_walker(object):
+    """incremental reader of float32 binary files"""
+
+    def __init__(self, path):
+        self.eof = False  # end of file
+        self.path = path  # current pos
+        if path is None:
+            self.eof = True
+            return
+        else:
+            self.size = os.path.getsize(path)  # save the path
+            major, minor, revision, seen = np.memmap(path,
+                                                     shape=(), mode='r', offset=0,
+                                                     dtype='({})i4,'.format(4))
+            self.transpose = major > 1000 or minor > 1000
+            self.offset = 16
+
+    def walk(self, size):
+        if self.eof: return None
+        end_point = self.offset + 4 * size
+        assert end_point <= self.size, \
+            'Over-read {}'.format(self.path)
+
+        float32_1D_array = np.memmap(
+            self.path, shape=(), mode='r',
+            offset=self.offset,
+            dtype='({})float32,'.format(size)
+        )
+
+        self.offset = end_point
+        if end_point == self.size:
+            self.eof = True
+        return float32_1D_array
+
+
+def model_name(file_path):
+    file_name = basename(file_path)
+    ext = str()
+    if '.' in file_name:  # exclude extension
+        file_name = file_name.split('.')
+        ext = file_name[-1]
+        file_name = '.'.join(file_name[:-1])
+    if ext == str() or ext == 'meta':  # ckpt file
+        file_name = file_name.split('-')
+        num = int(file_name[-1])
+        return '-'.join(file_name[:-1])
+    if ext == 'weights':
+        return file_name