# Diff view of utils/triton.py (000000..190ca4), unified view
1
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
2
""" Utils to interact with the Triton Inference Server
3
"""
4
5
import typing
6
from urllib.parse import urlparse
7
8
import torch
9
10
11
class TritonRemoteModel:
    """ A wrapper over a model served by the Triton Inference Server. It can
    be configured to communicate over GRPC or HTTP. It accepts Torch Tensors
    as input and returns them as outputs.
    """

    def __init__(self, url: str):
        """
        Keyword arguments:
        url: Fully qualified address of the Triton server - for e.g. grpc://localhost:8000
        """

        parsed_url = urlparse(url)
        if parsed_url.scheme == 'grpc':
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            # GRPC returns a protobuf message; serve the first model in the repository
            self.model_name = model_repository.models[0].name
            # as_json=True so metadata has the same dict shape as the HTTP client's
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)
        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            # HTTP returns plain dicts; serve the first model in the repository
            self.model_name = model_repository[0]['name']
            self.metadata = self.client.get_model_metadata(self.model_name)

        # Defined once for both transports; `InferInput` is whichever class was
        # imported from the matching tritonclient backend above.
        def create_input_placeholders() -> typing.List[InferInput]:
            return [
                InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]

        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime (Triton 'backend', falling back to 'platform')."""
        return self.metadata.get('backend', self.metadata.get('platform'))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.List[torch.Tensor]]:
        """ Invokes the model. Parameters can be provided via args or kwargs.
        args, if provided, are assumed to match the order of inputs of the model.
        kwargs are matched with the model input names.

        Returns a single tensor if the model has one output, else a list of tensors.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = [torch.as_tensor(response.as_numpy(output['name'])) for output in self.metadata['outputs']]
        return result[0] if len(result) == 1 else result

    def _create_inputs(self, *args, **kwargs):
        """Binds args (positional, in model-input order) or kwargs (by input name)
        to freshly created input placeholders.

        Raises RuntimeError when no inputs are given, when both args and kwargs
        are given, or when the positional count does not match the model.
        """
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError('No inputs provided.')
        if args_len and kwargs_len:
            raise RuntimeError('Cannot specify args and kwargs at the same time')

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f'Expected {len(placeholders)} inputs, got {args_len}.')
            for placeholder, value in zip(placeholders, args):
                placeholder.set_data_from_numpy(value.cpu().numpy())
        else:
            for placeholder in placeholders:
                value = kwargs[placeholder.name]
                placeholder.set_data_from_numpy(value.cpu().numpy())
        return placeholders