|
a |
|
b/darkflow/utils/loader.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
import os |
|
|
3 |
from .. import dark |
|
|
4 |
import numpy as np |
|
|
5 |
from os.path import basename |
|
|
6 |
|
|
|
# Run this module against the TF1 API: rebind `tf` to tf.compat.v1 and
# disable v2 behavior (eager execution), since the loaders below use
# graph/session-style code (tf.Session, import_meta_graph, global_variables).
tf = tf.compat.v1
tf.disable_v2_behavior()
|
|
9 |
|
|
|
10 |
|
|
|
class loader(object):
    """
    Common interface for reading variables out of both .weights and
    .ckpt files in loading / recollecting / resolving mode; subclasses
    implement `load` to populate the parallel key/value lists.
    """
    # Layer types that carry trainable variables.
    VAR_LAYER = ['convolutional', 'connected', 'local',
                 'select', 'conv-select',
                 'extract', 'conv-extract']

    def __init__(self, *args):
        self.src_key = list()
        self.vals = list()
        self.load(*args)

    def __call__(self, key):
        # Try progressively looser matches: ignore one more leading
        # component of `key` each round until something lines up.
        result = None
        for start in range(len(key)):
            result = self.find(key, start)
            if result is not None:
                break
        return result

    def find(self, key, idx):
        # Only the first few pending entries are candidates.
        limit = min(len(self.src_key), 4)
        for pos in range(limit):
            candidate = self.src_key[pos]
            if candidate[idx:] == key[idx:]:
                return self.yields(pos)
        return None

    def yields(self, idx):
        # Consume entry `idx`: drop its key and hand back its value.
        self.src_key.pop(idx)
        return self.vals.pop(idx)
|
44 |
|
|
|
45 |
|
|
|
class weights_loader(loader):
    """one who understands .weights files"""

    # Order in which each layer type's parameters are flattened
    # into the .weights binary.
    _W_ORDER = dict({
        'convolutional': [
            'biases', 'gamma', 'moving_mean', 'moving_variance', 'kernel'
        ],
        'connected': ['biases', 'weights'],
        'local': ['biases', 'kernels']
    })

    def load(self, path, src_layers):
        """Walk the .weights binary at `path`, creating one darkop per
        variable layer in `src_layers` and filling its parameters in
        file order.

        `path` may be None, in which case the walker reports eof
        immediately and every recorded value is None.
        """
        self.src_layers = src_layers
        walker = weights_walker(path)

        for layer in src_layers:
            # Layers without trainable variables carry no weights.
            if layer.type not in self.VAR_LAYER: continue
            self.src_key.append([layer])

            if walker.eof:
                new = None
            else:
                args = layer.signature
                new = dark.darknet.create_darkop(*args)
            self.vals.append(new)

            if new is None: continue
            order = self._W_ORDER[new.type]
            for par in order:
                # Some params (e.g. batch-norm stats) may be absent.
                if par not in new.wshape: continue
                val = walker.walk(new.wsize[par])
                new.w[par] = val
            new.finalize(walker.transpose)

        if walker.path is not None:
            # The whole file must have been consumed exactly.
            # Fix: expected byte count is the file size; what we
            # actually consumed is the walker offset (was swapped).
            assert walker.offset == walker.size, \
                'expect {} bytes, found {}'.format(
                    walker.size, walker.offset)
            print('Successfully identified {} bytes'.format(
                walker.offset))
|
86 |
|
|
|
87 |
|
|
|
class checkpoint_loader(loader):
    """
    one who understands .ckpt files, very much
    """

    def load(self, ckpt, ignore):
        # Restore the checkpoint into a throwaway graph/session and
        # snapshot every global variable's name, shape and value.
        meta_path = ckpt + '.meta'
        with tf.Graph().as_default() as graph:
            with tf.Session().as_default() as sess:
                saver = tf.train.import_meta_graph(meta_path)
                saver.restore(sess, ckpt)
                for variable in tf.global_variables():
                    var_name = variable.name.split(':')[0]
                    shape = variable.get_shape().as_list()
                    self.src_key.append([var_name, shape])
                    self.vals.append(variable.eval(sess))
|
104 |
|
|
|
105 |
|
|
|
def create_loader(path, cfg=None):
    """Pick and build the right loader for `path`.

    A missing path or a .weights path gets the weights loader;
    anything else is treated as a TF checkpoint.
    """
    if path is not None and '.weights' not in path:
        loader_cls = checkpoint_loader
    else:
        loader_cls = weights_loader
    return loader_cls(path, cfg)
|
115 |
|
|
|
116 |
|
|
|
class weights_walker(object):
    """incremental reader of float32 binary files"""

    def __init__(self, path):
        self.eof = False  # end of file reached?
        self.path = path  # source file, or None for "no weights"
        if path is None:
            self.eof = True
            return
        self.size = os.path.getsize(path)  # total bytes available
        # The file starts with four int32s of darknet header:
        # version fields plus the images-seen counter.
        header = np.memmap(path, shape=(), mode='r', offset=0,
                           dtype='({})i4,'.format(4))
        major, minor, revision, seen = header
        # Certain darknet versions store transposed weight matrices.
        self.transpose = major > 1000 or minor > 1000
        self.offset = 16  # skip past the header

    def walk(self, size):
        """Return the next `size` float32s as a 1-D memmap, or None at eof."""
        if self.eof: return None
        stop = self.offset + 4 * size
        assert stop <= self.size, \
            'Over-read {}'.format(self.path)

        chunk = np.memmap(
            self.path, shape=(), mode='r',
            offset=self.offset,
            dtype='({})float32,'.format(size)
        )

        self.offset = stop
        if stop == self.size:
            self.eof = True
        return chunk
|
150 |
|
|
|
151 |
|
|
|
def model_name(file_path):
    """Derive the model name from a weights/checkpoint file path."""
    name = basename(file_path)
    extension = str()
    if '.' in name:
        # Strip only the last extension, keeping earlier dots intact.
        pieces = name.split('.')
        extension = pieces[-1]
        name = '.'.join(pieces[:-1])
    if extension in (str(), 'meta'):
        # Checkpoint files end with a step counter: <name>-<step>[.meta].
        parts = name.split('-')
        int(parts[-1])  # raises ValueError if the suffix isn't a step number
        return '-'.join(parts[:-1])
    if extension == 'weights':
        return name