You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

91 lines
3.7 KiB
Python

# Ultralytics YOLOv5 🚀, AGPL-3.0 license
"""Utils to interact with the Triton Inference Server."""
import typing
from urllib.parse import urlparse
import torch
class TritonRemoteModel:
    """
    A wrapper over a model served by the Triton Inference Server.

    It can be configured to communicate over GRPC or HTTP. It accepts Torch Tensors as input and returns them as
    outputs. The model that gets wrapped is the first one listed in the server's model repository index.
    """

    def __init__(self, url: str):
        """
        Connect to the Triton server and load the metadata of the first model it lists.

        Keyword Arguments:
          url: Fully qualified address of the Triton server, e.g. grpc://localhost:8000. A ``grpc`` scheme selects
            the GRPC client; every other scheme falls back to the HTTP client.
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme == "grpc":
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)
        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]["name"]
            self.metadata = self.client.get_model_metadata(self.model_name)

        # Both clients expose an ``InferInput`` class with the same constructor, so a single factory closing over
        # whichever class was imported above is enough.
        def create_input_placeholders() -> typing.List[InferInput]:
            return [
                InferInput(i["name"], [int(s) for s in i["shape"]], i["datatype"]) for i in self.metadata["inputs"]
            ]

        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime: the metadata's ``backend`` entry, falling back to ``platform``."""
        return self.metadata.get("backend", self.metadata.get("platform"))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.List[torch.Tensor]]:
        """
        Invokes the model.

        Parameters can be provided via args or kwargs. args, if provided, are assumed to match the order of inputs of
        the model. kwargs are matched with the model input names.

        Returns:
          A single tensor when the model metadata declares exactly one output, otherwise a list of tensors in the
          order the outputs appear in the metadata.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = [torch.as_tensor(response.as_numpy(output["name"])) for output in self.metadata["outputs"]]
        return result[0] if len(result) == 1 else result

    @staticmethod
    def _placeholder_name(placeholder) -> str:
        """Returns a placeholder's input name, whether it is exposed as a ``name()`` method or a plain attribute."""
        name = placeholder.name
        return name() if callable(name) else name

    def _create_inputs(self, *args, **kwargs):
        """
        Creates input placeholders filled from args or kwargs, not both.

        Raises:
          RuntimeError: If no inputs are given, if args and kwargs are mixed, or if the number of positional inputs
            differs from the number of model inputs.
          KeyError: If kwargs are used and a model input name is missing from them.
        """
        if not args and not kwargs:
            raise RuntimeError("No inputs provided.")
        if args and kwargs:
            raise RuntimeError("Cannot specify args and kwargs at the same time")

        placeholders = self._create_input_placeholders_fn()
        if args:
            if len(args) != len(placeholders):
                raise RuntimeError(f"Expected {len(placeholders)} inputs, got {len(args)}.")
            values = args
        else:
            # The name has to be resolved to a string before it can be used as a kwargs key: looking up the bound
            # ``name`` method itself never matches.
            names = [self._placeholder_name(placeholder) for placeholder in placeholders]
            missing = [name for name in names if name not in kwargs]
            if missing:
                raise KeyError(f"Missing model inputs {missing}; expected inputs are {names}.")
            values = [kwargs[name] for name in names]

        for placeholder, value in zip(placeholders, values):
            # detach() so tensors that require grad can be converted to numpy as well.
            placeholder.set_data_from_numpy(value.detach().cpu().numpy())
        return placeholders