diff options
Diffstat (limited to 'python/pyarmnn/examples/common/network_executor_tflite.py')
-rw-r--r-- | python/pyarmnn/examples/common/network_executor_tflite.py | 98 |
1 file changed, 98 insertions, 0 deletions
# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
from typing import List, Tuple

import numpy as np
from tflite_runtime import interpreter as tflite


class TFLiteNetworkExecutor:
    """Runs inference on a .tflite model through the Arm NN TfLite delegate.

    The delegate shared object (.so) is loaded with the requested Arm NN
    backends; the TensorFlow Lite interpreter then dispatches supported
    operators to Arm NN and falls back to the reference kernels otherwise.
    """

    def __init__(self, model_file: str, backends: list, delegate_path: str):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backends to optimize network.
            delegate_path: tflite delegate file path (.so).
        """
        self.model_file = model_file
        self.backends = backends
        self.delegate_path = delegate_path
        self.interpreter, self.input_details, self.output_details = self.create_network()

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Executes inference for the loaded network.

        Args:
            input_data_list: List of input frames, one ndarray per model input
                tensor, in the same order as ``self.input_details``.

        Returns:
            list: Inference results as a list of ndarrays, one per model
                output tensor.
        """
        # Bind every input tensor BEFORE invoking. Invoking inside the loop
        # (once per input) would execute the graph with the remaining inputs
        # still unset and would append a duplicate set of output tensors for
        # every input of a multi-input model. For the common single-input
        # case the behavior is unchanged.
        for details, input_data in zip(self.input_details, input_data_list):
            self.interpreter.set_tensor(details['index'], input_data)
        self.interpreter.invoke()

        return [self.interpreter.get_tensor(details['index'])
                for details in self.output_details]

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            interpreter: A TensorFlow Lite object for executing inference.
            input_details: Contains essential information about the model input.
            output_details: Used to map output tensor and its memory.

        Raises:
            FileNotFoundError: If the model file does not exist.
            ValueError: If the model file extension is not '.tflite'.
        """

        # Controls whether optimizations are used or not.
        # Please note that optimizations can improve performance in some cases, but it can also
        # degrade the performance in other cases. Accuracy might also be affected.
        # The delegate expects its options as strings, hence "true" not True.
        optimization_enable = "true"

        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        _, ext = os.path.splitext(self.model_file)
        if ext != '.tflite':
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        armnn_delegate = tflite.load_delegate(library=self.delegate_path,
                                              options={"backends": ','.join(self.backends),
                                                       "logging-severity": "info",
                                                       "enable-fast-math": optimization_enable,
                                                       "reduce-fp32-to-fp16": optimization_enable})
        interpreter = tflite.Interpreter(model_path=self.model_file,
                                         experimental_delegates=[armnn_delegate])
        interpreter.allocate_tensors()

        # Get input and output binding information
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        return interpreter, input_details, output_details

    def get_data_type(self):
        """
        Get the input data type of the initiated network.

        Returns:
            numpy dtype of the network's first input tensor.
        """
        return self.input_details[0]['dtype']

    def get_shape(self):
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The Shape of the network input.
        """
        return tuple(self.input_details[0]['shape'])