You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
91 lines
3.7 KiB
Python
91 lines
3.7 KiB
Python
3 weeks ago
|
# Ultralytics YOLOv5 🚀, AGPL-3.0 license
|
||
|
"""Utils to interact with the Triton Inference Server."""
|
||
|
|
||
|
import typing
|
||
|
from urllib.parse import urlparse
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
class TritonRemoteModel:
    """
    A wrapper over a model served by the Triton Inference Server.

    It can be configured to communicate over GRPC or HTTP. It accepts Torch Tensors as input and returns them as
    outputs.
    """

    def __init__(self, url: str):
        """
        Connects to the Triton server and reads the metadata of the first model listed in its repository index.

        Keyword Arguments:
            url: Fully qualified address of the Triton server - for e.g. grpc://localhost:8000. A "grpc" scheme
                selects the GRPC client; any other scheme uses the HTTP client.
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme == "grpc":
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name
            # as_json=True so the metadata can be indexed with string keys like the HTTP metadata below
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)
        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]["name"]
            self.metadata = self.client.get_model_metadata(self.model_name)

        # Defined once after the branch: InferInput is bound to whichever protocol-specific class was imported above.
        def create_input_placeholders() -> typing.List[InferInput]:
            """Builds one fresh InferInput per model input described in the metadata."""
            return [
                InferInput(spec["name"], [int(dim) for dim in spec["shape"]], spec["datatype"])
                for spec in self.metadata["inputs"]
            ]

        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime: the metadata "backend" entry, falling back to "platform" (None if neither)."""
        return self.metadata.get("backend", self.metadata.get("platform"))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.List[torch.Tensor]]:
        """
        Invokes the model.

        Parameters can be provided via args or kwargs. args, if provided, are assumed to match the order of inputs of
        the model. kwargs are matched with the model input names.

        Returns:
            A single tensor when the model metadata lists exactly one output, otherwise a list of tensors in the
            order the outputs appear in the metadata.

        Raises:
            RuntimeError: If the inputs are missing, mixed (args and kwargs), or do not match the model inputs.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        outputs = [torch.as_tensor(response.as_numpy(spec["name"])) for spec in self.metadata["outputs"]]
        return outputs[0] if len(outputs) == 1 else outputs

    @staticmethod
    def _placeholder_name(placeholder) -> str:
        """Returns the input name of a placeholder, whether `name` is exposed as a method or a plain attribute."""
        name = placeholder.name
        # tritonclient's InferInput exposes name() as a method; using the bound method itself as a dict key
        # can never match a keyword argument.
        return name() if callable(name) else name

    def _create_inputs(self, *args, **kwargs):
        """
        Creates input tensors from args or kwargs, not both; raises error if none or both are provided.

        Positional values are paired with the model inputs in metadata order; keyword values are looked up by
        input name (keywords that match no model input are ignored).

        Returns:
            The list of input placeholders with their data set from the given tensors.

        Raises:
            RuntimeError: If no inputs are given, if args and kwargs are mixed, if the number of positional inputs
                differs from the number of model inputs, or if a model input is missing from kwargs.
        """
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError("No inputs provided.")
        if args_len and kwargs_len:
            raise RuntimeError("Cannot specify args and kwargs at the same time")

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f"Expected {len(placeholders)} inputs, got {args_len}.")
            values = args
        else:
            names = [self._placeholder_name(placeholder) for placeholder in placeholders]
            missing = [name for name in names if name not in kwargs]
            if missing:
                raise RuntimeError(f"Missing inputs: {missing}.")
            values = [kwargs[name] for name in names]

        for placeholder, value in zip(placeholders, values):
            array = value.cpu().numpy()
            # The metadata shape may contain -1 for dynamic dimensions; declare the concrete shape of the data
            # so the client-side shape validation in set_data_from_numpy accepts it.
            placeholder.set_shape(list(array.shape))
            placeholder.set_data_from_numpy(array)
        return placeholders
|