Diff of /utils/tf_serving.py [000000] .. [4df946]

Switch to side-by-side view

--- a
+++ b/utils/tf_serving.py
@@ -0,0 +1,56 @@
+"""
+Basic components for using TF Serving.
+"""
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+import io
+import numpy as np
+from PIL import Image
+import grpc
+import tensorflow as tf
+from grpc.beta import implementations
+from tensorflow_serving.apis import predict_pb2
+from tensorflow_serving.apis import prediction_service_pb2
+from grpc._cython import cygrpc
+import config
+
+
+def decode_tensor(contents, downsample_factor):
+    images = [Image.open(io.BytesIO(content)) for content in contents]
+    dsize = (images[0].size[0] * downsample_factor, images[0].size[1] * downsample_factor)
+    images = [image.resize(dsize, Image.BILINEAR) for image in images]
+    mtx = np.array([np.asarray(image, dtype=np.uint8) for image in images], dtype=np.uint8)
+    mtx = mtx.transpose((1, 2, 0))
+    return mtx
+
+
+class TFServing(object):
+    """
+    Tensorflow Serving client, send prediction request.
+    """
+
+    def __init__(self, host, port):
+        super(TFServing, self).__init__()
+
+        channel = self._insecure_channel(host, port)
+        self._stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+
+    def _insecure_channel(self, host, port):
+        channel = grpc.insecure_channel(
+            target=host if port is None else '{}:{}'.format(host, port),
+            options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
+                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)])
+        return grpc.beta.implementations.Channel(channel)
+
+    def predict(self, image_input, model_name):
+        request = predict_pb2.PredictRequest()
+        request.model_spec.name = model_name
+        request.inputs[config.INPUT_KEY].CopyFrom(tf.contrib.util.make_tensor_proto(
+            image_input.astype(np.uint8, copy=False)))
+        try:
+            result = self._stub.Predict(request, 1000.0)
+            image_prob = np.array(result.outputs[config.PREDICT_KEY].int_val)
+        except Exception as e:
+            raise e
+        else:
+            return image_prob.astype(np.uint8)