|
a |
|
b/darkflow/dark/darknet.py |
|
|
1 |
from ..utils.process import cfg_yielder |
|
|
2 |
from .darkop import create_darkop |
|
|
3 |
from ..utils import loader |
|
|
4 |
import warnings |
|
|
5 |
import time |
|
|
6 |
import os |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
class Darknet(object):
    """Parse darknet ``.cfg`` model descriptions and load ``.weights`` into them.

    Construction resolves which config/binary pair the weights come from
    (see :meth:`get_weight_src`). It then parses the source config and, when
    it differs, ``FLAGS.model`` as the destination config. Finally it loads
    the weights into the destination layers.
    """

    # Extension appended to FLAGS.model to build the default weights path.
    _EXT = '.weights'

    def __init__(self, FLAGS):
        """Build the network description from ``FLAGS``.

        Args:
            FLAGS: option object. The attributes read by this class are
                ``model``, ``binary``, ``load`` and ``config``.
                ``FLAGS.load`` may be rewritten by :meth:`get_weight_src`.
        """
        self.get_weight_src(FLAGS)
        self.modify = False

        print('Parsing {}'.format(self.src_cfg))
        src_parsed = self.parse_cfg(self.src_cfg, FLAGS)
        self.src_meta, self.src_layers = src_parsed

        if self.src_cfg == FLAGS.model:
            # Source and destination configs are the same file: reuse the
            # parse (self.layers is then the same list as self.src_layers).
            self.meta, self.layers = src_parsed
        else:
            print('Parsing {}'.format(FLAGS.model))
            self.meta, self.layers = self.parse_cfg(FLAGS.model, FLAGS)

        self.load_weights()

    def get_weight_src(self, FLAGS):
        """Decide where the source weights come from and which config describes them.

        Sets ``self.src_bin`` (a weights path, or ``None`` when there is no
        darknet binary to read) and ``self.src_cfg`` (the config used to parse
        the source binary).

        ``FLAGS.load`` is interpreted as:
          * ``''``: treated as, and rewritten to, ``0``.
          * ``0``: use ``FLAGS.binary + FLAGS.model + '.weights'`` if that
            file exists; otherwise there is no binary.
          * any other ``int``: no darknet binary.
          * a path string: load that weights file. Its config is looked up as
            ``<FLAGS.config>/<loader.model_name(path)>.cfg``, falling back to
            ``FLAGS.model`` with a warning. ``FLAGS.load`` is then reset to ``0``.

        Raises:
            AssertionError: if ``FLAGS.load`` is a path that is not a file.
        """
        # Plain concatenation (not os.path.join) is kept deliberately:
        # FLAGS.binary is used as a prefix exactly as before.
        default_bin = os.path.abspath(FLAGS.binary + FLAGS.model + self._EXT)
        self.src_bin = default_bin

        if FLAGS.load == str():
            FLAGS.load = int()

        # Exact type check (not isinstance) so that bools are not treated as ints.
        if type(FLAGS.load) is int:
            self.src_cfg = FLAGS.model
            if FLAGS.load or not os.path.isfile(default_bin):
                self.src_bin = None
            return

        if not os.path.isfile(FLAGS.load):
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``. AssertionError is kept so callers see the same type.
            raise AssertionError('{} not found'.format(FLAGS.load))

        self.src_bin = FLAGS.load
        name = loader.model_name(FLAGS.load)
        cfg_path = os.path.join(FLAGS.config, name + '.cfg')
        if not os.path.isfile(cfg_path):
            warnings.warn(
                '{} not found, use {} instead'.format(cfg_path, FLAGS.model))
            cfg_path = FLAGS.model
        self.src_cfg = cfg_path
        FLAGS.load = int()

    def parse_cfg(self, model, FLAGS):
        """Parse ``model`` into ``(meta, layers)``.

        The items produced by ``cfg_yielder(model, FLAGS.binary)`` are
        consumed in order:
          * the first item becomes ``meta``;
          * every following item is unpacked into ``create_darkop`` to build
            one layer object (see darkop.py).

        Returns:
            tuple: ``(meta, layers)``. ``meta`` is ``{}`` and ``layers`` is
            ``[]`` when the yielder produces nothing.
        """
        items = iter(cfg_yielder(model, FLAGS.binary))
        meta = next(items, dict())
        layers = [create_darkop(*info) for info in items]
        return meta, layers

    def load_weights(self):
        """Load the weights at ``self.src_bin`` into ``self.layers``.

        The loader is built from the *source* layers (``self.src_layers``).
        The values are loaded into the *destination* layers (``self.layers``),
        which come from ``FLAGS.model`` when that differs from the source config.
        """
        print('Loading {} ...'.format(self.src_bin))
        # perf_counter is the clock intended for measuring elapsed time;
        # time.time() follows the wall clock and can jump.
        start = time.perf_counter()

        wgts_loader = loader.create_loader(self.src_bin, self.src_layers)
        for layer in self.layers:
            layer.load(wgts_loader)

        print('Finished in {}s'.format(time.perf_counter() - start))