# darkflow/dark/darknet.py
1
from ..utils.process import cfg_yielder
2
from .darkop import create_darkop
3
from ..utils import loader
4
import warnings
5
import time
6
import os
7
8
9
class Darknet(object):
    """Parse Darknet .cfg files and load a .weights binary into layer objects.

    After construction the instance exposes:

    - ``src_cfg`` / ``src_bin``: the config and binary the weights are read
      from (``src_bin`` is None when no .weights binary is used);
    - ``src_meta`` / ``src_layers``: the parse result of ``src_cfg``;
    - ``meta`` / ``layers``: the parse result of ``FLAGS.model``; these are
      the layers that receive the loaded weights.
    """

    _EXT = '.weights'

    def __init__(self, FLAGS):
        """Resolve the weight source, parse the config(s) and load weights.

        Args:
            FLAGS: options object. This class reads ``model``, ``binary``,
                ``load`` and ``config`` from it, and may rewrite ``load``
                (see ``get_weight_src``).
        """
        self.get_weight_src(FLAGS)
        self.modify = False  # not read anywhere in this class

        print('Parsing {}'.format(self.src_cfg))
        src_parsed = self.parse_cfg(self.src_cfg, FLAGS)
        self.src_meta, self.src_layers = src_parsed

        if self.src_cfg == FLAGS.model:
            # Source and destination configs are the same file: reuse the parse.
            self.meta, self.layers = src_parsed
        else:
            print('Parsing {}'.format(FLAGS.model))
            des_parsed = self.parse_cfg(FLAGS.model, FLAGS)
            self.meta, self.layers = des_parsed

        self.load_weights()

    def get_weight_src(self, FLAGS):
        """Decide which config and binary the weights are taken from.

        Sets ``self.src_cfg`` and ``self.src_bin`` based on ``FLAGS.load``:

        - ``''`` is normalised to ``0`` (written back to ``FLAGS.load``).
        - An ``int``: the config is ``FLAGS.model``. The binary is the
          absolute path of ``FLAGS.binary + FLAGS.model + '.weights'`` when
          ``load`` is 0 and that file exists; otherwise it is None.
        - Any other value is treated as a path to a weights file, which
          becomes the binary. Its config is ``<FLAGS.config>/<name>.cfg``,
          where ``name`` comes from ``loader.model_name``. If that config is
          missing, it falls back to ``FLAGS.model`` with a warning.
          ``FLAGS.load`` is then reset to 0.

        Raises:
            AssertionError: if ``FLAGS.load`` is a path that is not a file.
        """
        # Plain string concatenation, as before: FLAGS.binary is expected
        # to carry its own trailing separator.
        self.src_bin = os.path.abspath(FLAGS.binary + FLAGS.model + self._EXT)
        exist = os.path.isfile(self.src_bin)

        if FLAGS.load == str():
            FLAGS.load = int()

        # Exact type test (not isinstance), so a bool load is not taken as int.
        if type(FLAGS.load) is int:
            self.src_cfg = FLAGS.model
            # A non-zero int load, or a missing default binary, means no
            # .weights file is read here.
            if FLAGS.load or not exist:
                self.src_bin = None
        else:
            # Explicit check rather than `assert`, so validation of this
            # user-supplied path is not stripped under `python -O`.
            # AssertionError is kept for callers relying on the old assert.
            if not os.path.isfile(FLAGS.load):
                raise AssertionError('{} not found'.format(FLAGS.load))
            self.src_bin = FLAGS.load
            name = loader.model_name(FLAGS.load)
            cfg_path = os.path.join(FLAGS.config, name + '.cfg')
            if not os.path.isfile(cfg_path):
                warnings.warn(
                    '{} not found, use {} instead'.format(
                        cfg_path, FLAGS.model))
                cfg_path = FLAGS.model
            self.src_cfg = cfg_path
            FLAGS.load = int()

    def parse_cfg(self, model, FLAGS):
        """Parse a .cfg file into its metadata and layer objects.

        ``cfg_yielder(model, FLAGS.binary)`` yields the metadata first.
        Every following item is an argument tuple for ``create_darkop``.

        Args:
            model: path of the config to parse.
            FLAGS: options object; only ``FLAGS.binary`` is read.

        Returns:
            tuple: ``(meta, layers)``. ``meta`` is the first yielded item,
            or an empty dict if nothing is yielded. ``layers`` is the list
            of layer objects built by ``create_darkop`` (see darkop.py).
        """
        items = iter(cfg_yielder(model, FLAGS.binary))
        meta = next(items, dict())
        layers = [create_darkop(*info) for info in items]
        return meta, layers

    def load_weights(self):
        """Load the weights binary into ``self.layers`` and report the time.

        The loader is created from ``self.src_bin`` and the source layers
        (``self.src_layers``). Each destination layer in ``self.layers``
        then loads its weights from it.
        """
        print('Loading {} ...'.format(self.src_bin))
        start = time.time()

        wgts_loader = loader.create_loader(self.src_bin, self.src_layers)
        for layer in self.layers:
            layer.load(wgts_loader)

        stop = time.time()
        print('Finished in {}s'.format(stop - start))