aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/common/network_executor_tflite.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/common/network_executor_tflite.py')
-rw-r--r--python/pyarmnn/examples/common/network_executor_tflite.py98
1 files changed, 98 insertions, 0 deletions
diff --git a/python/pyarmnn/examples/common/network_executor_tflite.py b/python/pyarmnn/examples/common/network_executor_tflite.py
new file mode 100644
index 0000000000..10f5e6e6fb
--- /dev/null
+++ b/python/pyarmnn/examples/common/network_executor_tflite.py
@@ -0,0 +1,98 @@
+# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+from typing import List, Tuple
+
+import numpy as np
+from tflite_runtime import interpreter as tflite
+
+class TFLiteNetworkExecutor:
+
+ def __init__(self, model_file: str, backends: list, delegate_path: str):
+ """
+ Creates an inference executor for a given network and a list of backends.
+
+ Args:
+ model_file: User-specified model file.
+ backends: List of backends to optimize network.
+ delegate_path: tflite delegate file path (.so).
+ """
+ self.model_file = model_file
+ self.backends = backends
+ self.delegate_path = delegate_path
+ self.interpreter, self.input_details, self.output_details = self.create_network()
+
+ def run(self, input_data_list: list) -> List[np.ndarray]:
+ """
+ Executes inference for the loaded network.
+
+ Args:
+ input_data_list: List of input frames.
+
+ Returns:
+ list: Inference results as a list of ndarrays.
+ """
+ output = []
+ for index, input_data in enumerate(input_data_list):
+ self.interpreter.set_tensor(self.input_details[index]['index'], input_data)
+ self.interpreter.invoke()
+ for curr_output in self.output_details:
+ output.append(self.interpreter.get_tensor(curr_output['index']))
+
+ return output
+
+ def create_network(self):
+ """
+ Creates a network based on the model file and a list of backends.
+
+ Returns:
+ interpreter: A TensorFlow Lite object for executing inference.
+ input_details: Contains essential information about the model input.
+ output_details: Used to map output tensor and its memory.
+ """
+
+ # Controls whether optimizations are used or not.
+ # Please note that optimizations can improve performance in some cases, but it can also
+ # degrade the performance in other cases. Accuracy might also be affected.
+
+ optimization_enable = "true"
+
+ if not os.path.exists(self.model_file):
+ raise FileNotFoundError(f'Model file not found for: {self.model_file}')
+
+ _, ext = os.path.splitext(self.model_file)
+ if ext == '.tflite':
+ armnn_delegate = tflite.load_delegate(library=self.delegate_path,
+ options={"backends": ','.join(self.backends), "logging-severity": "info",
+ "enable-fast-math": optimization_enable,
+ "reduce-fp32-to-fp16": optimization_enable})
+ interpreter = tflite.Interpreter(model_path=self.model_file,
+ experimental_delegates=[armnn_delegate])
+ interpreter.allocate_tensors()
+ else:
+ raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")
+
+ # Get input and output binding information
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ return interpreter, input_details, output_details
+
+ def get_data_type(self):
+ """
+ Get the input data type of the initiated network.
+
+ Returns:
+ numpy data type or None if doesn't exist in the if condition.
+ """
+ return self.input_details[0]['dtype']
+
+ def get_shape(self):
+ """
+ Get the input shape of the initiated network.
+
+ Returns:
+ tuple: The Shape of the network input.
+ """
+ return tuple(self.input_details[0]['shape'])