src/pipeline/inference.py

"""Tensorflow inference engine wrapper."""
import logging
import os
import numpy as np
from tflite_runtime.interpreter import Interpreter
from tflite_runtime.interpreter import load_delegate
log = logging.getLogger(__name__)
def _get_edgetpu_interpreter(model=None): # pragma: no cover
# Note: Looking for ideas how to test Coral EdgeTPU dependent code
# in a cloud CI environment such as Travis CI and Github
tf_interpreter = None
if model:
try:
edgetpu_delegate = load_delegate('libedgetpu.so.1.0')
assert edgetpu_delegate
tf_interpreter = Interpreter(
model_path=model,
experimental_delegates=[edgetpu_delegate]
)
log.debug('EdgeTPU available. Will use EdgeTPU model.')
except Exception as e:
log.debug('EdgeTPU init error: %r', e)
# log.debug(stacktrace())
return tf_interpreter


class TFInferenceEngine:
    """Thin wrapper around TFLite Interpreter.

    The official TFLite API is moving fast and still changes frequently.
    This class intends to abstract out underlying TF changes to some extent.

    It dynamically detects if EdgeTPU is available and uses it.
    Otherwise falls back to TFLite Runtime.
    """

    def __init__(self,
                 model=None,
                 labels=None,
                 confidence_threshold=0.8,
                 **kwargs
                 ):
        """Create an instance of Tensorflow inference engine.

        :Parameters:
        ----------
        model : dict
            {
                'tflite': path,
                'edgetpu': path,
            }
            where each path is a string pointing to the location
            of a frozen graph file (AI model).
        labels : string
            Location of file with model labels.
        confidence_threshold : float
            Inference confidence threshold.
        """
        assert model
        assert model['tflite'], 'TFLite AI model path required.'
        model_tflite = model['tflite']
        assert os.path.isfile(model_tflite), \
            'TFLite AI model file does not exist: {}' \
            .format(model_tflite)
        self._model_tflite_path = model_tflite
        model_edgetpu = model.get('edgetpu', None)
        if model_edgetpu:
            assert os.path.isfile(model_edgetpu), \
                'EdgeTPU AI model file does not exist: {}' \
                .format(model_edgetpu)
        self._model_edgetpu_path = model_edgetpu
        assert labels, 'AI model labels path required.'
        assert os.path.isfile(labels), \
            'AI model labels file does not exist: {}' \
            .format(labels)
        self._model_labels_path = labels
        self._confidence_threshold = confidence_threshold
        # log.info('Loading AI model:\n'
        #          'TFLite graph: %r\n'
        #          'EdgeTPU graph: %r\n'
        #          'Labels: %r\n'
        #          'Confidence threshold: %.0f%%',
        #          model_tflite,
        #          model_edgetpu,
        #          labels,
        #          confidence_threshold*100)
        # EdgeTPU is not available in testing and other environments.
        # Load it dynamically as needed.
        # edgetpu_class = 'DetectionEngine'
        # module_object = import_module('edgetpu.detection.engine',
        #                               package=edgetpu_class)
        # target_class = getattr(module_object, edgetpu_class)
        self._tf_interpreter = _get_edgetpu_interpreter(model=model_edgetpu)
        if not self._tf_interpreter:
            log.debug('EdgeTPU not available. Will use TFLite CPU runtime.')
            self._tf_interpreter = Interpreter(model_path=model_tflite)
        assert self._tf_interpreter
        self._tf_interpreter.allocate_tensors()
        # Check the type of the input tensor to detect quantized models.
        self._tf_input_details = self._tf_interpreter.get_input_details()
        self._tf_output_details = self._tf_interpreter.get_output_details()
        self._tf_is_quantized_model = \
            self.input_details[0]['dtype'] != np.float32

    @property
    def input_details(self):
        """Return input tensor details as reported by the TFLite interpreter."""
        return self._tf_input_details

    @property
    def output_details(self):
        """Return output tensor details as reported by the TFLite interpreter."""
        return self._tf_output_details

    @property
    def is_quantized(self):
        """Return True if the loaded model is quantized (non-float32 input)."""
        return self._tf_is_quantized_model
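
    # Illustrative sketch (assumption, not part of this class): callers can
    # use is_quantized to decide how to prepare input data before set_tensor,
    # e.g. raw uint8 pixels for quantized models vs. normalized float32:
    #
    #   if engine.is_quantized:
    #       input_tensor = np.uint8(image)
    #   else:
    #       input_tensor = (np.float32(image) - 127.5) / 127.5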

    @property
    def confidence_threshold(self):
        """
        Inference confidence threshold.

        :Returns:
        -------
        float
            Confidence threshold for inference results.
            Only results at or above this threshold
            should be returned by each engine inference.
        """
        return self._confidence_threshold
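
    # Illustrative sketch (assumption, not part of this class): a caller might
    # apply the threshold to raw model scores, whose layout is model specific:
    #
    #   keep = [i for i, score in enumerate(scores)
    #           if score >= engine.confidence_threshold]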

    def infer(self):
        """Invoke model inference on current input tensor."""
        return self._tf_interpreter.invoke()

    def set_tensor(self, index=None, tensor_data=None):
        """Set tensor data at given reference index."""
        assert isinstance(index, int)
        self._tf_interpreter.set_tensor(index, tensor_data)

    def get_tensor(self, index=None):
        """Return tensor data at given reference index."""
        assert isinstance(index, int)
        return self._tf_interpreter.get_tensor(index)
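

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The model and label file paths below are
# hypothetical placeholders; substitute paths to real TFLite/EdgeTPU graphs.
#
#   engine = TFInferenceEngine(
#       model={
#           'tflite': 'ai_models/detect.tflite',
#           'edgetpu': 'ai_models/detect_edgetpu.tflite',
#       },
#       labels='ai_models/labels.txt',
#       confidence_threshold=0.8,
#   )
#   input_index = engine.input_details[0]['index']
#   engine.set_tensor(input_index, input_tensor)
#   engine.infer()
#   results = engine.get_tensor(engine.output_details[0]['index'])
# ---------------------------------------------------------------------------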