# -*- coding: utf-8 -*-
""" Model runners. """
from typing import Any, Dict, Generator, List, Tuple, Union
import onnxruntime as rt
from overrides import overrides
from orkgnlp.common.service.base import ORKGNLPBaseRunner
class ORKGNLPONNXRunner(ORKGNLPBaseRunner):
    """
    Runner specialized for ONNX model formats.

    It therefore requires a model object of type ``onnx``.
    """

    def __init__(self, *args):
        super().__init__(*args)

    @overrides(check_signature=False)
    def run(
        self,
        inputs: Tuple[Any, ...],
        output_names: Union[List[str], None] = None,
        custom_input_dict: Union[Dict[str, List[Any]], None] = None,
        **kwargs: Any
    ) -> Tuple[Any, Dict[str, Any]]:
        """
        Runs the wrapped ONNX model in a fresh inference session and returns its output.

        :param inputs: Tuple of model arguments. They are mapped positionally onto the session's inputs: the i-th
            value is wrapped in a single-element list and fed under the name of the i-th session input.
            Only consulted when ``custom_input_dict`` is not given (or is empty).
        :param output_names: List of output names of the ONNX graph. Check your exporting code for further
            information! Defaults to None.
        :param custom_input_dict: When given (and non-empty), the argument ``inputs`` is ignored entirely. This
            argument must have the following schema:
            {input_name_0: [input_value_0], ..., input_name_n: [input_value_n]}. Check your exporting
            code for further information! Defaults to None.
        :return: The value returned by the session's ``run`` call, and the untouched ``kwargs``.
        :raises IndexError: If ``inputs`` holds more values than the session has inputs.
        """
        # A new session is built from the serialized model on every call.
        session = rt.InferenceSession(self._model.SerializeToString())

        if custom_input_dict:
            # The caller supplied the feed explicitly; do not touch ``inputs`` at all,
            # so that ``inputs`` may legitimately be None or of a different length.
            feed = custom_input_dict
        else:
            # Query the session inputs once instead of once per argument.
            session_inputs = session.get_inputs()
            feed = {session_inputs[i].name: [value] for i, value in enumerate(inputs)}

        output = session.run(output_names, feed)
        return output, kwargs
class ORKGNLPTorchRunner(ORKGNLPBaseRunner):
    """
    Runner specialized for Torch model formats.

    It therefore requires a model object of type ``torch``.
    """

    def __init__(self, *args):
        super().__init__(*args)

    @overrides(check_signature=False)
    def run(
        self,
        inputs: Union[Any, List[Tuple[Any]], Dict[str, Any], List[Dict[str, Any]]],
        multiple_batches: bool = False,
        **kwargs: Any
    ) -> Union[Tuple[Any, Dict[str, Any]], Tuple[Generator[Any, None, None], Dict[str, Any]]]:
        """
        Switches the wrapped model to evaluation mode (when it offers ``eval``) and runs it.

        :param inputs: Tuple of model arguments or dict of model named arguments.
            A list of tuples or a list of dicts in case of batches.
        :param multiple_batches: Whether the model is to be executed x times for each input instance or batch, where
            x is the length of ``inputs`` list. Note that in this case the model's outputs
            will be returned as a python generator, i.e. the model is only executed while the
            generator is being consumed. Defaults to False.
        :return: The model output (or a generator of model outputs when ``multiple_batches`` is set), and the
            untouched ``kwargs``.
        """
        if hasattr(self._model, "eval"):
            self._model.eval()

        def call_model(model_input):
            # Dicts are expanded as keyword arguments, anything else as positional arguments.
            if isinstance(model_input, dict):
                return self._model(**model_input)
            return self._model(*model_input)

        if not multiple_batches:
            return call_model(inputs), kwargs

        def multiple_batch_generator():
            # Lazily run the model once per batch, in the order given by ``inputs``.
            for batch in inputs:
                yield call_model(batch)

        return multiple_batch_generator(), kwargs