Diff of /utils/tf_serving.py [000000] .. [4df946]

Switch to unified view

a b/utils/tf_serving.py
1
"""
2
Basic components for using TF Serving.
3
"""
4
import os
5
os.environ["CUDA_VISIBLE_DEVICES"] = ""
6
import io
7
import numpy as np
8
from PIL import Image
9
import grpc
10
import tensorflow as tf
11
from grpc.beta import implementations
12
from tensorflow_serving.apis import predict_pb2
13
from tensorflow_serving.apis import prediction_service_pb2
14
from grpc._cython import cygrpc
15
import config
16
17
18
def decode_tensor(contents, downsample_factor):
19
    images = [Image.open(io.BytesIO(content)) for content in contents]
20
    dsize = (images[0].size[0] * downsample_factor, images[0].size[1] * downsample_factor)
21
    images = [image.resize(dsize, Image.BILINEAR) for image in images]
22
    mtx = np.array([np.asarray(image, dtype=np.uint8) for image in images], dtype=np.uint8)
23
    mtx = mtx.transpose((1, 2, 0))
24
    return mtx
25
26
27
class TFServing(object):
28
    """
29
    Tensorflow Serving client, send prediction request.
30
    """
31
32
    def __init__(self, host, port):
33
        super(TFServing, self).__init__()
34
35
        channel = self._insecure_channel(host, port)
36
        self._stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
37
38
    def _insecure_channel(self, host, port):
39
        channel = grpc.insecure_channel(
40
            target=host if port is None else '{}:{}'.format(host, port),
41
            options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
42
                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)])
43
        return grpc.beta.implementations.Channel(channel)
44
45
    def predict(self, image_input, model_name):
46
        request = predict_pb2.PredictRequest()
47
        request.model_spec.name = model_name
48
        request.inputs[config.INPUT_KEY].CopyFrom(tf.contrib.util.make_tensor_proto(
49
            image_input.astype(np.uint8, copy=False)))
50
        try:
51
            result = self._stub.Predict(request, 1000.0)
52
            image_prob = np.array(result.outputs[config.PREDICT_KEY].int_val)
53
        except Exception as e:
54
            raise e
55
        else:
56
            return image_prob.astype(np.uint8)