# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
from typing import List, Tuple

import numpy as np
from tflite_runtime import interpreter as tflite


class TFLiteNetworkExecutor:
    """Runs TFLite inference through the Arm NN TfLite delegate on a chosen set of backends."""

    def __init__(self, model_file: str, backends: list, delegate_path: str):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backends to optimize network.
            delegate_path: tflite delegate file path (.so).
        """
        self.model_file = model_file
        self.backends = backends
        self.delegate_path = delegate_path
        self.interpreter, self.input_details, self.output_details = self.create_network()

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Executes inference for the loaded network.

        Args:
            input_data_list: List of input frames, one ndarray per model input,
                in the same order as the model's input details.

        Returns:
            list: Inference results as a list of ndarrays, one per model output.
        """
        # Bind every input tensor first, then run the graph exactly once.
        # Invoking inside the per-input loop would execute inference before
        # all inputs are set and append a duplicate set of outputs per frame.
        for index, input_data in enumerate(input_data_list):
            self.interpreter.set_tensor(self.input_details[index]['index'], input_data)
        self.interpreter.invoke()
        return [self.interpreter.get_tensor(detail['index']) for detail in self.output_details]

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            interpreter: A TensorFlow Lite object for executing inference.
            input_details: Contains essential information about the model input.
            output_details: Used to map output tensor and its memory.

        Raises:
            FileNotFoundError: If the model file does not exist.
            ValueError: If the model file extension is not '.tflite'.
        """
        # Controls whether optimizations are used or not.
        # Please note that optimizations can improve performance in some cases, but it can also
        # degrade the performance in other cases. Accuracy might also be affected.
        # NOTE: the delegate expects these option values as strings, hence "true".
        optimization_enable = "true"

        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        _, ext = os.path.splitext(self.model_file)
        if ext == '.tflite':
            armnn_delegate = tflite.load_delegate(library=self.delegate_path,
                                                  options={"backends": ','.join(self.backends),
                                                           "logging-severity": "info",
                                                           "enable-fast-math": optimization_enable,
                                                           "reduce-fp32-to-fp16": optimization_enable})
            interpreter = tflite.Interpreter(model_path=self.model_file,
                                             experimental_delegates=[armnn_delegate])
            interpreter.allocate_tensors()
        else:
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        # Get input and output binding information
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        return interpreter, input_details, output_details

    def get_data_type(self):
        """
        Get the input data type of the initiated network.

        Returns:
            numpy data type of the first model input tensor.
        """
        return self.input_details[0]['dtype']

    def get_shape(self):
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The Shape of the network input.
        """
        return tuple(self.input_details[0]['shape'])