--- a +++ b/darkflow/utils/loader.py @@ -0,0 +1,164 @@ +import tensorflow as tf +import os +from .. import dark +import numpy as np +from os.path import basename + +tf = tf.compat.v1 +tf.disable_v2_behavior() + + +class loader(object): + """ + interface to work with both .weights and .ckpt files + in loading / recollecting / resolving mode + """ + VAR_LAYER = ['convolutional', 'connected', 'local', + 'select', 'conv-select', + 'extract', 'conv-extract'] + + def __init__(self, *args): + self.src_key = list() + self.vals = list() + self.load(*args) + + def __call__(self, key): + for idx in range(len(key)): + val = self.find(key, idx) + if val is not None: return val + return None + + def find(self, key, idx): + up_to = min(len(self.src_key), 4) + for i in range(up_to): + key_b = self.src_key[i] + if key_b[idx:] == key[idx:]: + return self.yields(i) + return None + + def yields(self, idx): + del self.src_key[idx] + temp = self.vals[idx] + del self.vals[idx] + return temp + + +class weights_loader(loader): + """one who understands .weights files""" + + _W_ORDER = dict({ # order of param flattened into .weights file + 'convolutional': [ + 'biases', 'gamma', 'moving_mean', 'moving_variance', 'kernel' + ], + 'connected': ['biases', 'weights'], + 'local': ['biases', 'kernels'] + }) + + def load(self, path, src_layers): + self.src_layers = src_layers + walker = weights_walker(path) + + for i, layer in enumerate(src_layers): + if layer.type not in self.VAR_LAYER: continue + self.src_key.append([layer]) + + if walker.eof: + new = None + else: + args = layer.signature + new = dark.darknet.create_darkop(*args) + self.vals.append(new) + + if new is None: continue + order = self._W_ORDER[new.type] + for par in order: + if par not in new.wshape: continue + val = walker.walk(new.wsize[par]) + new.w[par] = val + new.finalize(walker.transpose) + + if walker.path is not None: + assert walker.offset == walker.size, \ + 'expect {} bytes, found {}'.format( + walker.offset, walker.size) + print('Successfully identified {} bytes'.format( + walker.offset)) + + +class checkpoint_loader(loader): + """ + one who understands .ckpt files, very much + """ + + def load(self, ckpt, ignore): + meta = ckpt + '.meta' + with tf.Graph().as_default() as graph: + with tf.Session().as_default() as sess: + saver = tf.train.import_meta_graph(meta) + saver.restore(sess, ckpt) + for var in tf.global_variables(): + name = var.name.split(':')[0] + packet = [name, var.get_shape().as_list()] + self.src_key += [packet] + self.vals += [var.eval(sess)] + + +def create_loader(path, cfg=None): + if path is None: + load_type = weights_loader + elif '.weights' in path: + load_type = weights_loader + else: + load_type = checkpoint_loader + + return load_type(path, cfg) + + +class weights_walker(object): + """incremental reader of float32 binary files""" + + def __init__(self, path): + self.eof = False # end of file + self.path = path # current pos + if path is None: + self.eof = True + return + else: + self.size = os.path.getsize(path) # save the path + major, minor, revision, seen = np.memmap(path, + shape=(), mode='r', offset=0, + dtype='({})i4,'.format(4)) + self.transpose = major > 1000 or minor > 1000 + self.offset = 16 + + def walk(self, size): + if self.eof: return None + end_point = self.offset + 4 * size + assert end_point <= self.size, \ + 'Over-read {}'.format(self.path) + + float32_1D_array = np.memmap( + self.path, shape=(), mode='r', + offset=self.offset, + dtype='({})float32,'.format(size) + ) + + self.offset = end_point + if end_point == self.size: + self.eof = True + return float32_1D_array + + +def model_name(file_path): + file_name = basename(file_path) + ext = str() + if '.' in file_name: # exclude extension + file_name = file_name.split('.') + ext = file_name[-1] + file_name = '.'.join(file_name[:-1]) + if ext == str() or ext == 'meta': # ckpt file + file_name = file_name.split('-') + num = int(file_name[-1]) + return '-'.join(file_name[:-1]) + if ext == 'weights': + return file_name