|
|
b/utils/tf_serving.py
|
|
"""
Basic components for using TF Serving.
"""
import os

# Inference happens remotely in TF Serving, so hide every GPU from the local
# TensorFlow runtime. This must be set before TensorFlow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import io

import numpy as np
from PIL import Image
import grpc
import tensorflow as tf
from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
from grpc._cython import cygrpc

import config


def decode_tensor(contents, downsample_factor):
    """Decode a list of encoded images into a single (H, W, N) uint8 array."""
    images = [Image.open(io.BytesIO(content)) for content in contents]
    # PIL sizes are (width, height); every image is resized to the size of the
    # first one scaled by downsample_factor.
    dsize = (images[0].size[0] * downsample_factor,
             images[0].size[1] * downsample_factor)
    images = [image.resize(dsize, Image.BILINEAR) for image in images]
    # Stack the decoded images along a new leading axis, giving shape (N, H, W);
    # this assumes single-channel (grayscale) images.
    mtx = np.array([np.asarray(image, dtype=np.uint8) for image in images],
                   dtype=np.uint8)
    # Move the stacking axis last so the result is (H, W, N).
    mtx = mtx.transpose((1, 2, 0))
    return mtx
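
# A minimal usage sketch for decode_tensor (the file names and factor of 1 are
# hypothetical, assuming two same-sized grayscale images on disk):
#
#     with open("frame_0.png", "rb") as f0, open("frame_1.png", "rb") as f1:
#         stacked = decode_tensor([f0.read(), f1.read()], downsample_factor=1)
#     # stacked.shape == (height, width, 2), stacked.dtype == np.uint8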
|
|
class TFServing(object):
    """
    TensorFlow Serving client that sends prediction requests over gRPC.
    """

    def __init__(self, host, port):
        super(TFServing, self).__init__()

        channel = self._insecure_channel(host, port)
        self._stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    def _insecure_channel(self, host, port):
        # Plaintext gRPC channel with unlimited send/receive message sizes,
        # wrapped as a beta-API channel so it works with the beta stub above.
        channel = grpc.insecure_channel(
            target=host if port is None else '{}:{}'.format(host, port),
            options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)])
        return implementations.Channel(channel)

    def predict(self, image_input, model_name):
        """Send a uint8 image tensor to the named model and return its prediction."""
        request = predict_pb2.PredictRequest()
        request.model_spec.name = model_name
        request.inputs[config.INPUT_KEY].CopyFrom(
            tf.contrib.util.make_tensor_proto(image_input.astype(np.uint8, copy=False)))
        # The second argument is the request timeout in seconds.
        result = self._stub.Predict(request, 1000.0)
        image_prob = np.array(result.outputs[config.PREDICT_KEY].int_val)
        return image_prob.astype(np.uint8)
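

# A minimal sketch of how this module might be used end to end. The host, port,
# model name, and file paths below are hypothetical, and config.INPUT_KEY /
# config.PREDICT_KEY are assumed to match the model's serving signature.
if __name__ == '__main__':
    client = TFServing(host='localhost', port=8500)
    with open('frame_0.png', 'rb') as f0, open('frame_1.png', 'rb') as f1:
        image_input = decode_tensor([f0.read(), f1.read()], downsample_factor=1)
    prediction = client.predict(image_input, model_name='my_model')
    print(prediction.shape, prediction.dtype)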