|
utils/triton.py
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""Utils to interact with the Triton Inference Server."""

import typing
from urllib.parse import urlparse

import torch


class TritonRemoteModel:
    """A wrapper over a model served by the Triton Inference Server. It can
    be configured to communicate over GRPC or HTTP. It accepts Torch Tensors
    as input and returns them as outputs.
    """

    def __init__(self, url: str):
        """
        Arguments:
        url: Fully qualified address of the Triton server, e.g. grpc://localhost:8001 or http://localhost:8000
        """

        parsed_url = urlparse(url)
        if parsed_url.scheme == 'grpc':
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name  # assumes the first registered model is the target
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)

            # Factory for fresh input placeholders; closes over the GRPC InferInput class
            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]

        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]['name']  # assumes the first registered model is the target
            self.metadata = self.client.get_model_metadata(self.model_name)

            # Identical factory, but closes over the HTTP InferInput class
            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]

        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime (the Triton backend, falling back to the platform name)."""
        return self.metadata.get('backend', self.metadata.get('platform'))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, ...]]:
        """Invokes the model. Parameters can be provided via args or kwargs.
        args, if provided, are assumed to match the order of inputs of the model.
        kwargs are matched with the model input names.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = []
        for output in self.metadata['outputs']:
            tensor = torch.as_tensor(response.as_numpy(output['name']))
            result.append(tensor)
        return result[0] if len(result) == 1 else result

    def _create_inputs(self, *args, **kwargs):
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError('No inputs provided.')
        if args_len and kwargs_len:
            raise RuntimeError('Cannot specify args and kwargs at the same time')

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f'Expected {len(placeholders)} inputs, got {args_len}.')
            for placeholder, value in zip(placeholders, args):  # positional args map to inputs in declaration order
                placeholder.set_data_from_numpy(value.cpu().numpy())
        else:
            for placeholder in placeholders:  # keyword args are matched to inputs by name
                value = kwargs[placeholder.name()]  # InferInput.name() is a method in both HTTP and GRPC clients
                placeholder.set_data_from_numpy(value.cpu().numpy())
        return placeholders
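

# Minimal usage sketch, assuming a Triton server is already running with a single
# model whose sole input is an FP32 image batch. The URL, the default GRPC port
# 8001, and the (1, 3, 640, 640) input shape below are illustrative assumptions,
# not values taken from this module.
if __name__ == '__main__':
    model = TritonRemoteModel('grpc://localhost:8001')
    print(f'model={model.model_name} runtime={model.runtime}')

    batch = torch.zeros(1, 3, 640, 640, dtype=torch.float32)  # hypothetical NCHW input
    outputs = model(batch)  # positional: tensors bind to the model inputs in declaration order
    # Keyword form binds by input name; inspect model.metadata['inputs'] for the real names:
    # outputs = model(images=batch)
    print(outputs.shape if isinstance(outputs, torch.Tensor) else [o.shape for o in outputs])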