diff options
Diffstat (limited to 'python/pyarmnn/examples/common/network_executor.py')
-rw-r--r-- | python/pyarmnn/examples/common/network_executor.py | 213 |
1 file changed, 133 insertions, 80 deletions
diff --git a/python/pyarmnn/examples/common/network_executor.py b/python/pyarmnn/examples/common/network_executor.py index 6e2c53c43d..72262fc520 100644 --- a/python/pyarmnn/examples/common/network_executor.py +++ b/python/pyarmnn/examples/common/network_executor.py @@ -7,80 +7,6 @@ from typing import List, Tuple import pyarmnn as ann import numpy as np - -def create_network(model_file: str, backends: list, input_names: Tuple[str] = (), output_names: Tuple[str] = ()): - """ - Creates a network based on the model file and a list of backends. - - Args: - model_file: User-specified model file. - backends: List of backends to optimize network. - input_names: - output_names: - - Returns: - net_id: Unique ID of the network to run. - runtime: Runtime context for executing inference. - input_binding_info: Contains essential information about the model input. - output_binding_info: Used to map output tensor and its memory. - """ - if not os.path.exists(model_file): - raise FileNotFoundError(f'Model file not found for: {model_file}') - - _, ext = os.path.splitext(model_file) - if ext == '.tflite': - parser = ann.ITfLiteParser() - else: - raise ValueError("Supplied model file type is not supported. 
Supported types are [ tflite ]") - - network = parser.CreateNetworkFromBinaryFile(model_file) - - # Specify backends to optimize network - preferred_backends = [] - for b in backends: - preferred_backends.append(ann.BackendId(b)) - - # Select appropriate device context and optimize the network for that device - options = ann.CreationOptions() - runtime = ann.IRuntime(options) - opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), - ann.OptimizerOptions()) - print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n' - f'Optimization warnings: {messages}') - - # Load the optimized network onto the Runtime device - net_id, _ = runtime.LoadNetwork(opt_network) - - # Get input and output binding information - graph_id = parser.GetSubgraphCount() - 1 - input_names = parser.GetSubgraphInputTensorNames(graph_id) - input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0]) - output_names = parser.GetSubgraphOutputTensorNames(graph_id) - output_binding_info = [] - for output_name in output_names: - out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name) - output_binding_info.append(out_bind_info) - return net_id, runtime, input_binding_info, output_binding_info - - -def execute_network(input_tensors: list, output_tensors: list, runtime, net_id: int) -> List[np.ndarray]: - """ - Executes inference for the loaded network. - - Args: - input_tensors: The input frame tensor. - output_tensors: The output tensor from output node. - runtime: Runtime context for executing inference. - net_id: Unique ID of the network to run. - - Returns: - list: Inference results as a list of ndarrays. 
- """ - runtime.EnqueueWorkload(net_id, input_tensors, output_tensors) - output = ann.workload_tensors_to_ndarray(output_tensors) - return output - - class ArmnnNetworkExecutor: def __init__(self, model_file: str, backends: list): @@ -91,18 +17,145 @@ class ArmnnNetworkExecutor: model_file: User-specified model file. backends: List of backends to optimize network. """ - self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = create_network(model_file, - backends) + self.model_file = model_file + self.backends = backends + self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = self.create_network() self.output_tensors = ann.make_output_tensors(self.output_binding_info) - def run(self, input_tensors: list) -> List[np.ndarray]: + def run(self, input_data_list: list) -> List[np.ndarray]: """ - Executes inference for the loaded network. + Creates input tensors from input data and executes inference with the loaded network. Args: - input_tensors: The input frame tensor. + input_data_list: List of input frames. Returns: list: Inference results as a list of ndarrays. """ - return execute_network(input_tensors, self.output_tensors, self.runtime, self.network_id) + input_tensors = ann.make_input_tensors(self.input_binding_info, input_data_list) + self.runtime.EnqueueWorkload(self.network_id, input_tensors, self.output_tensors) + output = ann.workload_tensors_to_ndarray(self.output_tensors) + + return output + + def create_network(self): + """ + Creates a network based on the model file and a list of backends. + + Returns: + net_id: Unique ID of the network to run. + runtime: Runtime context for executing inference. + input_binding_info: Contains essential information about the model input. + output_binding_info: Used to map output tensor and its memory. 
+ """ + if not os.path.exists(self.model_file): + raise FileNotFoundError(f'Model file not found for: {self.model_file}') + + _, ext = os.path.splitext(self.model_file) + if ext == '.tflite': + parser = ann.ITfLiteParser() + else: + raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]") + + network = parser.CreateNetworkFromBinaryFile(self.model_file) + + # Specify backends to optimize network + preferred_backends = [] + for b in self.backends: + preferred_backends.append(ann.BackendId(b)) + + # Select appropriate device context and optimize the network for that device + options = ann.CreationOptions() + runtime = ann.IRuntime(options) + opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), + ann.OptimizerOptions()) + print(f'Preferred backends: {self.backends}\n{runtime.GetDeviceSpec()}\n' + f'Optimization warnings: {messages}') + + # Load the optimized network onto the Runtime device + net_id, _ = runtime.LoadNetwork(opt_network) + + # Get input and output binding information + graph_id = parser.GetSubgraphCount() - 1 + input_names = parser.GetSubgraphInputTensorNames(graph_id) + input_binding_info = [] + for input_name in input_names: + in_bind_info = parser.GetNetworkInputBindingInfo(graph_id, input_name) + input_binding_info.append(in_bind_info) + output_names = parser.GetSubgraphOutputTensorNames(graph_id) + output_binding_info = [] + for output_name in output_names: + out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name) + output_binding_info.append(out_bind_info) + return net_id, runtime, input_binding_info, output_binding_info + + def get_data_type(self): + """ + Get the input data type of the initiated network. + + Returns: + numpy data type or None if doesn't exist in the if condition. 
+ """ + if self.input_binding_info[0][1].GetDataType() == ann.DataType_Float32: + return np.float32 + elif self.input_binding_info[0][1].GetDataType() == ann.DataType_QAsymmU8: + return np.uint8 + elif self.input_binding_info[0][1].GetDataType() == ann.DataType_QAsymmS8: + return np.int8 + else: + return None + + def get_shape(self): + """ + Get the input shape of the initiated network. + + Returns: + tuple: The Shape of the network input. + """ + return tuple(self.input_binding_info[0][1].GetShape()) + + def get_input_quantization_scale(self, idx): + """ + Get the input quantization scale of the initiated network. + + Returns: + The quantization scale of the network input. + """ + return self.input_binding_info[idx][1].GetQuantizationScale() + + def get_input_quantization_offset(self, idx): + """ + Get the input quantization offset of the initiated network. + + Returns: + The quantization offset of the network input. + """ + return self.input_binding_info[idx][1].GetQuantizationOffset() + + def is_output_quantized(self, idx): + """ + Get True/False if output tensor is quantized or not respectively. + + Returns: + True if output is quantized and False otherwise. + """ + return self.output_binding_info[idx][1].IsQuantized() + + def get_output_quantization_scale(self, idx): + """ + Get the output quantization offset of the initiated network. + + Returns: + The quantization offset of the network output. + """ + return self.output_binding_info[idx][1].GetQuantizationScale() + + def get_output_quantization_offset(self, idx): + """ + Get the output quantization offset of the initiated network. + + Returns: + The quantization offset of the network output. + """ + return self.output_binding_info[idx][1].GetQuantizationOffset() + |