Diff of /darkflow/utils/loader.py [000000] .. [d34869]

Switch to unified view

a b/darkflow/utils/loader.py
1
import tensorflow as tf
2
import os
3
from .. import dark
4
import numpy as np
5
from os.path import basename
6
7
tf = tf.compat.v1
8
tf.disable_v2_behavior()
9
10
11
class loader(object):
12
    """
13
    interface to work with both .weights and .ckpt files
14
    in loading / recollecting / resolving mode
15
    """
16
    VAR_LAYER = ['convolutional', 'connected', 'local',
17
                 'select', 'conv-select',
18
                 'extract', 'conv-extract']
19
20
    def __init__(self, *args):
21
        self.src_key = list()
22
        self.vals = list()
23
        self.load(*args)
24
25
    def __call__(self, key):
26
        for idx in range(len(key)):
27
            val = self.find(key, idx)
28
            if val is not None: return val
29
        return None
30
31
    def find(self, key, idx):
32
        up_to = min(len(self.src_key), 4)
33
        for i in range(up_to):
34
            key_b = self.src_key[i]
35
            if key_b[idx:] == key[idx:]:
36
                return self.yields(i)
37
        return None
38
39
    def yields(self, idx):
40
        del self.src_key[idx]
41
        temp = self.vals[idx]
42
        del self.vals[idx]
43
        return temp
44
45
46
class weights_loader(loader):
47
    """one who understands .weights files"""
48
49
    _W_ORDER = dict({  # order of param flattened into .weights file
50
        'convolutional': [
51
            'biases', 'gamma', 'moving_mean', 'moving_variance', 'kernel'
52
        ],
53
        'connected': ['biases', 'weights'],
54
        'local': ['biases', 'kernels']
55
    })
56
57
    def load(self, path, src_layers):
58
        self.src_layers = src_layers
59
        walker = weights_walker(path)
60
61
        for i, layer in enumerate(src_layers):
62
            if layer.type not in self.VAR_LAYER: continue
63
            self.src_key.append([layer])
64
65
            if walker.eof:
66
                new = None
67
            else:
68
                args = layer.signature
69
                new = dark.darknet.create_darkop(*args)
70
            self.vals.append(new)
71
72
            if new is None: continue
73
            order = self._W_ORDER[new.type]
74
            for par in order:
75
                if par not in new.wshape: continue
76
                val = walker.walk(new.wsize[par])
77
                new.w[par] = val
78
            new.finalize(walker.transpose)
79
80
        if walker.path is not None:
81
            assert walker.offset == walker.size, \
82
                'expect {} bytes, found {}'.format(
83
                    walker.offset, walker.size)
84
            print('Successfully identified {} bytes'.format(
85
                walker.offset))
86
87
88
class checkpoint_loader(loader):
89
    """
90
    one who understands .ckpt files, very much
91
    """
92
93
    def load(self, ckpt, ignore):
94
        meta = ckpt + '.meta'
95
        with tf.Graph().as_default() as graph:
96
            with tf.Session().as_default() as sess:
97
                saver = tf.train.import_meta_graph(meta)
98
                saver.restore(sess, ckpt)
99
                for var in tf.global_variables():
100
                    name = var.name.split(':')[0]
101
                    packet = [name, var.get_shape().as_list()]
102
                    self.src_key += [packet]
103
                    self.vals += [var.eval(sess)]
104
105
106
def create_loader(path, cfg=None):
107
    if path is None:
108
        load_type = weights_loader
109
    elif '.weights' in path:
110
        load_type = weights_loader
111
    else:
112
        load_type = checkpoint_loader
113
114
    return load_type(path, cfg)
115
116
117
class weights_walker(object):
118
    """incremental reader of float32 binary files"""
119
120
    def __init__(self, path):
121
        self.eof = False  # end of file
122
        self.path = path  # current pos
123
        if path is None:
124
            self.eof = True
125
            return
126
        else:
127
            self.size = os.path.getsize(path)  # save the path
128
            major, minor, revision, seen = np.memmap(path,
129
                                                     shape=(), mode='r', offset=0,
130
                                                     dtype='({})i4,'.format(4))
131
            self.transpose = major > 1000 or minor > 1000
132
            self.offset = 16
133
134
    def walk(self, size):
135
        if self.eof: return None
136
        end_point = self.offset + 4 * size
137
        assert end_point <= self.size, \
138
            'Over-read {}'.format(self.path)
139
140
        float32_1D_array = np.memmap(
141
            self.path, shape=(), mode='r',
142
            offset=self.offset,
143
            dtype='({})float32,'.format(size)
144
        )
145
146
        self.offset = end_point
147
        if end_point == self.size:
148
            self.eof = True
149
        return float32_1D_array
150
151
152
def model_name(file_path):
153
    file_name = basename(file_path)
154
    ext = str()
155
    if '.' in file_name:  # exclude extension
156
        file_name = file_name.split('.')
157
        ext = file_name[-1]
158
        file_name = '.'.join(file_name[:-1])
159
    if ext == str() or ext == 'meta':  # ckpt file
160
        file_name = file_name.split('-')
161
        num = int(file_name[-1])
162
        return '-'.join(file_name[:-1])
163
    if ext == 'weights':
164
        return file_name