diff options
author | Éanna Ó Catháin <eanna.ocathain@arm.com> | 2020-11-16 14:12:11 +0000 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-11-17 12:23:56 +0000 |
commit | 145c88f851d12d2cadc2f080d232c1d5963d6e47 (patch) | |
tree | 6ae197d74782cd2c7ef8965f4b36acabc65ce453 /python/pyarmnn/examples/common/network_executor.py | |
parent | aa41d5d2f43790938f3a32586626be5ef55b6ca9 (diff) | |
download | armnn-145c88f851d12d2cadc2f080d232c1d5963d6e47.tar.gz |
MLECO-1253 Adding ASR sample application using the PyArmNN api
Change-Id: I450b23800ca316a5bfd4608c8559cf4f11271c21
Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
Diffstat (limited to 'python/pyarmnn/examples/common/network_executor.py')
-rw-r--r-- | python/pyarmnn/examples/common/network_executor.py | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/common/network_executor.py b/python/pyarmnn/examples/common/network_executor.py new file mode 100644 index 0000000000..6e2c53c43d --- /dev/null +++ b/python/pyarmnn/examples/common/network_executor.py @@ -0,0 +1,108 @@ +# Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +# SPDX-License-Identifier: MIT + +import os +from typing import List, Tuple + +import pyarmnn as ann +import numpy as np + + +def create_network(model_file: str, backends: list, input_names: Tuple[str] = (), output_names: Tuple[str] = ()): + """ + Creates a network based on the model file and a list of backends. + + Args: + model_file: User-specified model file. + backends: List of backends to optimize network. + input_names: + output_names: + + Returns: + net_id: Unique ID of the network to run. + runtime: Runtime context for executing inference. + input_binding_info: Contains essential information about the model input. + output_binding_info: Used to map output tensor and its memory. + """ + if not os.path.exists(model_file): + raise FileNotFoundError(f'Model file not found for: {model_file}') + + _, ext = os.path.splitext(model_file) + if ext == '.tflite': + parser = ann.ITfLiteParser() + else: + raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]") + + network = parser.CreateNetworkFromBinaryFile(model_file) + + # Specify backends to optimize network + preferred_backends = [] + for b in backends: + preferred_backends.append(ann.BackendId(b)) + + # Select appropriate device context and optimize the network for that device + options = ann.CreationOptions() + runtime = ann.IRuntime(options) + opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), + ann.OptimizerOptions()) + print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n' + f'Optimization warnings: {messages}') + + # Load the optimized network onto the Runtime device + net_id, _ = runtime.LoadNetwork(opt_network) + + # Get input and output binding information + graph_id = parser.GetSubgraphCount() - 1 + input_names = parser.GetSubgraphInputTensorNames(graph_id) + input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0]) + output_names = parser.GetSubgraphOutputTensorNames(graph_id) + output_binding_info = [] + for output_name in output_names: + out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name) + output_binding_info.append(out_bind_info) + return net_id, runtime, input_binding_info, output_binding_info + + +def execute_network(input_tensors: list, output_tensors: list, runtime, net_id: int) -> List[np.ndarray]: + """ + Executes inference for the loaded network. + + Args: + input_tensors: The input frame tensor. + output_tensors: The output tensor from output node. + runtime: Runtime context for executing inference. + net_id: Unique ID of the network to run. + + Returns: + list: Inference results as a list of ndarrays. + """ + runtime.EnqueueWorkload(net_id, input_tensors, output_tensors) + output = ann.workload_tensors_to_ndarray(output_tensors) + return output + + +class ArmnnNetworkExecutor: + + def __init__(self, model_file: str, backends: list): + """ + Creates an inference executor for a given network and a list of backends. + + Args: + model_file: User-specified model file. + backends: List of backends to optimize network. + """ + self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = create_network(model_file, + backends) + self.output_tensors = ann.make_output_tensors(self.output_binding_info) + + def run(self, input_tensors: list) -> List[np.ndarray]: + """ + Executes inference for the loaded network. + + Args: + input_tensors: The input frame tensor. + + Returns: + list: Inference results as a list of ndarrays. + """ + return execute_network(input_tensors, self.output_tensors, self.runtime, self.network_id) |