Diff of /darkflow/dark/layer.py [000000] .. [d34869]

Switch to unified view

a b/darkflow/dark/layer.py
1
from ..utils import loader
2
import numpy as np
3
4
5
class Layer(object):
6
7
    def __init__(self, *args):
8
        self._signature = list(args)
9
        self.type = list(args)[0]
10
        self.number = list(args)[1]
11
12
        self.w = dict()  # weights
13
        self.h = dict()  # placeholders
14
        self.wshape = dict()  # weight shape
15
        self.wsize = dict()  # weight size
16
        self.setup(*args[2:])  # set attr up
17
        self.present()
18
        for var in self.wshape:
19
            shp = self.wshape[var]
20
            size = np.prod(shp)
21
            self.wsize[var] = size
22
23
    def load(self, src_loader):
24
        var_lay = src_loader.VAR_LAYER
25
        if self.type not in var_lay: return
26
27
        src_type = type(src_loader)
28
        if src_type is loader.weights_loader:
29
            wdict = self.load_weights(src_loader)
30
        else:
31
            wdict = self.load_ckpt(src_loader)
32
        if wdict is not None:
33
            self.recollect(wdict)
34
35
    def load_weights(self, src_loader):
36
        val = src_loader([self.presenter])
37
        if val is None:
38
            return None
39
        else:
40
            return val.w
41
42
    def load_ckpt(self, src_loader):
43
        result = dict()
44
        presenter = self.presenter
45
        for var in presenter.wshape:
46
            name = presenter.varsig(var)
47
            shape = presenter.wshape[var]
48
            key = [name, shape]
49
            val = src_loader(key)
50
            result[var] = val
51
        return result
52
53
    @property
54
    def signature(self):
55
        return self._signature
56
57
    # For comparing two layers
58
    def __eq__(self, other):
59
        return self.signature == other.signature
60
61
    def __ne__(self, other):
62
        return not self.__eq__(other)
63
64
    def varsig(self, var):
65
        if var not in self.wshape:
66
            return None
67
        sig = str(self.number)
68
        sig += '-' + self.type
69
        sig += '/' + var
70
        return sig
71
72
    def recollect(self, w):
73
        self.w = w
74
75
    def present(self):
76
        self.presenter = self
77
78
    def setup(self, *args):
79
        pass
80
81
    def finalize(self):
82
        pass