Diff of /utils/triton.py [000000] .. [190ca4]

Switch to side-by-side view

--- a
+++ b/utils/triton.py
@@ -0,0 +1,85 @@
+# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
+""" Utils to interact with the Triton Inference Server
+"""
+
+import typing
+from urllib.parse import urlparse
+
+import torch
+
+
+class TritonRemoteModel:
+    """ A wrapper over a model served by the Triton Inference Server. It can
+    be configured to communicate over GRPC or HTTP. It accepts Torch Tensors
+    as input and returns them as outputs.
+    """
+
+    def __init__(self, url: str):
+        """
+        Keyword arguments:
+        url: Fully qualified address of the Triton server - for e.g. grpc://localhost:8000
+        """
+
+        parsed_url = urlparse(url)
+        if parsed_url.scheme == 'grpc':
+            from tritonclient.grpc import InferenceServerClient, InferInput
+
+            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
+            model_repository = self.client.get_model_repository_index()
+            self.model_name = model_repository.models[0].name
+            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)
+
+            def create_input_placeholders() -> typing.List[InferInput]:
+                return [
+                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]
+
+        else:
+            from tritonclient.http import InferenceServerClient, InferInput
+
+            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
+            model_repository = self.client.get_model_repository_index()
+            self.model_name = model_repository[0]['name']
+            self.metadata = self.client.get_model_metadata(self.model_name)
+
+            def create_input_placeholders() -> typing.List[InferInput]:
+                return [
+                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]
+
+        self._create_input_placeholders_fn = create_input_placeholders
+
+    @property
+    def runtime(self):
+        """Returns the model runtime"""
+        return self.metadata.get('backend', self.metadata.get('platform'))
+
+    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, ...]]:
+        """ Invokes the model. Parameters can be provided via args or kwargs.
+        args, if provided, are assumed to match the order of inputs of the model.
+        kwargs are matched with the model input names.
+        """
+        inputs = self._create_inputs(*args, **kwargs)
+        response = self.client.infer(model_name=self.model_name, inputs=inputs)
+        result = []
+        for output in self.metadata['outputs']:
+            tensor = torch.as_tensor(response.as_numpy(output['name']))
+            result.append(tensor)
+        return result[0] if len(result) == 1 else result
+
+    def _create_inputs(self, *args, **kwargs):
+        args_len, kwargs_len = len(args), len(kwargs)
+        if not args_len and not kwargs_len:
+            raise RuntimeError('No inputs provided.')
+        if args_len and kwargs_len:
+            raise RuntimeError('Cannot specify args and kwargs at the same time')
+
+        placeholders = self._create_input_placeholders_fn()
+        if args_len:
+            if args_len != len(placeholders):
+                raise RuntimeError(f'Expected {len(placeholders)} inputs, got {args_len}.')
+            for input, value in zip(placeholders, args):
+                input.set_data_from_numpy(value.cpu().numpy())
+        else:
+            for input in placeholders:
+                value = kwargs[input.name]
+                input.set_data_from_numpy(value.cpu().numpy())
+        return placeholders