author     Raviv Shalev <raviv.shalev@arm.com>    2021-12-07 15:18:09 +0200
committer  TeresaARM <teresa.charlinreyes@arm.com>    2022-04-13 15:33:31 +0000
commit     97ddc06e52fbcabfd8ede7a00e9494c663186b92 (patch)
tree       43c84d352c3a67aa45d89760fba6c79b81c8f8dc /python/pyarmnn/examples/common
parent     2f0ddb67d8f9267ab600a8a26308cab32f9e16ac (diff)
download   armnn-97ddc06e52fbcabfd8ede7a00e9494c663186b92.tar.gz
MLECO-2493 Add python OD example with TFLite delegate
Signed-off-by: Raviv Shalev <raviv.shalev@arm.com>
Change-Id: I25fcccbf912be0c5bd4fbfd2e97552341958af35
Diffstat (limited to 'python/pyarmnn/examples/common')
-rw-r--r--  python/pyarmnn/examples/common/cv_utils.py                  58
-rw-r--r--  python/pyarmnn/examples/common/network_executor.py         213
-rw-r--r--  python/pyarmnn/examples/common/network_executor_tflite.py   98
-rw-r--r--  python/pyarmnn/examples/common/utils.py                     73
4 files changed, 311 insertions, 131 deletions
diff --git a/python/pyarmnn/examples/common/cv_utils.py b/python/pyarmnn/examples/common/cv_utils.py
index e12ff50548..36d1039227 100644
--- a/python/pyarmnn/examples/common/cv_utils.py
+++ b/python/pyarmnn/examples/common/cv_utils.py
@@ -1,4 +1,4 @@
-# Copyright © 2020-2021 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2020-2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""
@@ -11,29 +11,35 @@ import os
import cv2
import numpy as np
-import pyarmnn as ann
-
-def preprocess(frame: np.ndarray, input_binding_info: tuple, is_normalised: bool):
+def preprocess(frame: np.ndarray, input_data_type, input_data_shape: tuple, is_normalised: bool,
+ keep_aspect_ratio: bool=True):
"""
Takes a frame, resizes, swaps channels and converts data type to match
- model input layer. The converted frame is wrapped in a const tensor
- and bound to the input tensor.
+ model input layer.
Args:
frame: Captured frame from video.
- input_binding_info: Contains shape and data type of model input layer.
+ input_data_type: Contains data type of model input layer.
+ input_data_shape: Contains shape of model input layer.
is_normalised: if the input layer expects normalised data
+        keep_aspect_ratio: Whether to resize while preserving the frame's aspect ratio, padding any empty space; if False, the frame is stretched to the model resolution.
Returns:
-        Input tensor.
+        Preprocessed frame.
"""
- # Swap channels and resize frame to model resolution
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- resized_frame = resize_with_aspect_ratio(frame, input_binding_info)
+ if keep_aspect_ratio:
+ # Swap channels and resize frame to model resolution
+ resized_frame = resize_with_aspect_ratio(frame, input_data_shape)
+ else:
+        # Select the height and width from input_data_shape (NHWC layout)
+ frame_height = input_data_shape[1]
+ frame_width = input_data_shape[2]
+ resized_frame = cv2.resize(frame, (frame_width, frame_height))
# Expand dimensions and convert data type to match model input
- if input_binding_info[1].GetDataType() == ann.DataType_Float32:
+    if input_data_type == np.float32:
data_type = np.float32
if is_normalised:
resized_frame = resized_frame.astype("float32")/255
@@ -41,26 +47,24 @@ def preprocess(frame: np.ndarray, input_binding_info: tuple, is_normalised: bool
data_type = np.uint8
resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0)
- assert resized_frame.shape == tuple(input_binding_info[1].GetShape())
-
- input_tensors = ann.make_input_tensors([input_binding_info], [resized_frame])
- return input_tensors
+ assert resized_frame.shape == input_data_shape
+ return resized_frame
-def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple):
+def resize_with_aspect_ratio(frame: np.ndarray, input_data_shape: tuple):
"""
Resizes frame while maintaining aspect ratio, padding any empty space.
Args:
frame: Captured frame.
- input_binding_info: Contains shape of model input layer.
+ input_data_shape: Contains shape of model input layer.
Returns:
Frame resized to the size of model input layer.
"""
aspect_ratio = frame.shape[1] / frame.shape[0]
- model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
+ _, model_height, model_width, _ = input_data_shape
if aspect_ratio >= 1.0:
new_height, new_width = int(model_width / aspect_ratio), model_width
@@ -173,14 +177,14 @@ def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labe
# Create label for detected object class
label = f'{label} {confidence * 100:.1f}%'
- label_color = (0, 0, 0) if sum(color)>200 else (255, 255, 255)
+ label_color = (0, 0, 0) if sum(color) > 200 else (255, 255, 255)
# Make sure label always stays on-screen
x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]
lbl_box_xy_min = (x_min, y_min if y_min<25 else y_min - y_text)
lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text if y_min<25 else y_min)
- lbl_text_pos = (x_min + 5, y_min + 16 if y_min<25 else y_min - 5)
+ lbl_text_pos = (x_min + 5, y_min + 16 if y_min < 25 else y_min - 5)
# Add label and confidence value
cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
@@ -190,3 +194,19 @@ def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labe
def get_source_encoding_int(video_capture):
return int(video_capture.get(cv2.CAP_PROP_FOURCC))
+
+
+def crop_bounding_box_object(input_frame: np.ndarray, x_min: float, y_min: float, x_max: float, y_max: float):
+ """
+ Creates a cropped image based on x and y coordinates.
+
+ Args:
+ input_frame: Image to crop
+ x_min, y_min, x_max, y_max: Coordinates of the bounding box
+
+ Returns:
+ Cropped image
+ """
+ # Adding +1 to exclude the bounding box pixels.
+ cropped_image = input_frame[int(y_min) + 1:int(y_max), int(x_min) + 1:int(x_max)]
+ return cropped_image
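With input_binding_info removed, preprocess() now depends only on a plain data type and shape, so the same call works against either executor. A minimal sketch of the new calling convention, assuming a hypothetical ssd_mobilenet_v1.tflite model and sample.png test image:

    import cv2
    from cv_utils import preprocess
    from network_executor import ArmnnNetworkExecutor

    executor = ArmnnNetworkExecutor('ssd_mobilenet_v1.tflite', ['CpuAcc', 'CpuRef'])
    frame = cv2.imread('sample.png')  # BGR, as preprocess() expects

    # Data type and shape now come from executor accessors rather than binding info.
    input_data = preprocess(frame, executor.get_data_type(), executor.get_shape(),
                            is_normalised=True,       # only used for float32 inputs
                            keep_aspect_ratio=True)   # pad rather than stretch
    results = executor.run([input_data])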
diff --git a/python/pyarmnn/examples/common/network_executor.py b/python/pyarmnn/examples/common/network_executor.py
index 6e2c53c43d..72262fc520 100644
--- a/python/pyarmnn/examples/common/network_executor.py
+++ b/python/pyarmnn/examples/common/network_executor.py
@@ -7,80 +7,6 @@ from typing import List, Tuple
import pyarmnn as ann
import numpy as np
-
-def create_network(model_file: str, backends: list, input_names: Tuple[str] = (), output_names: Tuple[str] = ()):
- """
- Creates a network based on the model file and a list of backends.
-
- Args:
- model_file: User-specified model file.
- backends: List of backends to optimize network.
- input_names:
- output_names:
-
- Returns:
- net_id: Unique ID of the network to run.
- runtime: Runtime context for executing inference.
- input_binding_info: Contains essential information about the model input.
- output_binding_info: Used to map output tensor and its memory.
- """
- if not os.path.exists(model_file):
- raise FileNotFoundError(f'Model file not found for: {model_file}')
-
- _, ext = os.path.splitext(model_file)
- if ext == '.tflite':
- parser = ann.ITfLiteParser()
- else:
- raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")
-
- network = parser.CreateNetworkFromBinaryFile(model_file)
-
- # Specify backends to optimize network
- preferred_backends = []
- for b in backends:
- preferred_backends.append(ann.BackendId(b))
-
- # Select appropriate device context and optimize the network for that device
- options = ann.CreationOptions()
- runtime = ann.IRuntime(options)
- opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
- ann.OptimizerOptions())
- print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n'
- f'Optimization warnings: {messages}')
-
- # Load the optimized network onto the Runtime device
- net_id, _ = runtime.LoadNetwork(opt_network)
-
- # Get input and output binding information
- graph_id = parser.GetSubgraphCount() - 1
- input_names = parser.GetSubgraphInputTensorNames(graph_id)
- input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
- output_names = parser.GetSubgraphOutputTensorNames(graph_id)
- output_binding_info = []
- for output_name in output_names:
- out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
- output_binding_info.append(out_bind_info)
- return net_id, runtime, input_binding_info, output_binding_info
-
-
-def execute_network(input_tensors: list, output_tensors: list, runtime, net_id: int) -> List[np.ndarray]:
- """
- Executes inference for the loaded network.
-
- Args:
- input_tensors: The input frame tensor.
- output_tensors: The output tensor from output node.
- runtime: Runtime context for executing inference.
- net_id: Unique ID of the network to run.
-
- Returns:
- list: Inference results as a list of ndarrays.
- """
- runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
- output = ann.workload_tensors_to_ndarray(output_tensors)
- return output
-
-
class ArmnnNetworkExecutor:
def __init__(self, model_file: str, backends: list):
@@ -91,18 +17,145 @@ class ArmnnNetworkExecutor:
model_file: User-specified model file.
backends: List of backends to optimize network.
"""
- self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = create_network(model_file,
- backends)
+ self.model_file = model_file
+ self.backends = backends
+ self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = self.create_network()
self.output_tensors = ann.make_output_tensors(self.output_binding_info)
- def run(self, input_tensors: list) -> List[np.ndarray]:
+ def run(self, input_data_list: list) -> List[np.ndarray]:
"""
- Executes inference for the loaded network.
+ Creates input tensors from input data and executes inference with the loaded network.
Args:
- input_tensors: The input frame tensor.
+ input_data_list: List of input frames.
Returns:
list: Inference results as a list of ndarrays.
"""
- return execute_network(input_tensors, self.output_tensors, self.runtime, self.network_id)
+ input_tensors = ann.make_input_tensors(self.input_binding_info, input_data_list)
+ self.runtime.EnqueueWorkload(self.network_id, input_tensors, self.output_tensors)
+ output = ann.workload_tensors_to_ndarray(self.output_tensors)
+
+ return output
+
+ def create_network(self):
+ """
+ Creates a network based on the model file and a list of backends.
+
+ Returns:
+ net_id: Unique ID of the network to run.
+ runtime: Runtime context for executing inference.
+ input_binding_info: Contains essential information about the model input.
+ output_binding_info: Used to map output tensor and its memory.
+ """
+ if not os.path.exists(self.model_file):
+ raise FileNotFoundError(f'Model file not found for: {self.model_file}')
+
+ _, ext = os.path.splitext(self.model_file)
+ if ext == '.tflite':
+ parser = ann.ITfLiteParser()
+ else:
+ raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")
+
+ network = parser.CreateNetworkFromBinaryFile(self.model_file)
+
+ # Specify backends to optimize network
+ preferred_backends = []
+ for b in self.backends:
+ preferred_backends.append(ann.BackendId(b))
+
+ # Select appropriate device context and optimize the network for that device
+ options = ann.CreationOptions()
+ runtime = ann.IRuntime(options)
+ opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
+ ann.OptimizerOptions())
+ print(f'Preferred backends: {self.backends}\n{runtime.GetDeviceSpec()}\n'
+ f'Optimization warnings: {messages}')
+
+ # Load the optimized network onto the Runtime device
+ net_id, _ = runtime.LoadNetwork(opt_network)
+
+ # Get input and output binding information
+ graph_id = parser.GetSubgraphCount() - 1
+ input_names = parser.GetSubgraphInputTensorNames(graph_id)
+ input_binding_info = []
+ for input_name in input_names:
+ in_bind_info = parser.GetNetworkInputBindingInfo(graph_id, input_name)
+ input_binding_info.append(in_bind_info)
+ output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+ output_binding_info = []
+ for output_name in output_names:
+ out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
+ output_binding_info.append(out_bind_info)
+ return net_id, runtime, input_binding_info, output_binding_info
+
+ def get_data_type(self):
+ """
+ Get the input data type of the initiated network.
+
+ Returns:
+            numpy data type, or None if the input data type is not supported.
+ """
+ if self.input_binding_info[0][1].GetDataType() == ann.DataType_Float32:
+ return np.float32
+ elif self.input_binding_info[0][1].GetDataType() == ann.DataType_QAsymmU8:
+ return np.uint8
+ elif self.input_binding_info[0][1].GetDataType() == ann.DataType_QAsymmS8:
+ return np.int8
+ else:
+ return None
+
+ def get_shape(self):
+ """
+ Get the input shape of the initiated network.
+
+ Returns:
+            tuple: The shape of the network input.
+ """
+ return tuple(self.input_binding_info[0][1].GetShape())
+
+ def get_input_quantization_scale(self, idx):
+ """
+ Get the input quantization scale of the initiated network.
+
+ Returns:
+ The quantization scale of the network input.
+ """
+ return self.input_binding_info[idx][1].GetQuantizationScale()
+
+ def get_input_quantization_offset(self, idx):
+ """
+ Get the input quantization offset of the initiated network.
+
+ Returns:
+ The quantization offset of the network input.
+ """
+ return self.input_binding_info[idx][1].GetQuantizationOffset()
+
+ def is_output_quantized(self, idx):
+ """
+        Check whether the output tensor at the given index is quantized.
+
+ Returns:
+ True if output is quantized and False otherwise.
+ """
+ return self.output_binding_info[idx][1].IsQuantized()
+
+ def get_output_quantization_scale(self, idx):
+ """
+        Get the output quantization scale of the initiated network.
+
+        Returns:
+            The quantization scale of the network output.
+ """
+ return self.output_binding_info[idx][1].GetQuantizationScale()
+
+ def get_output_quantization_offset(self, idx):
+ """
+ Get the output quantization offset of the initiated network.
+
+ Returns:
+ The quantization offset of the network output.
+ """
+ return self.output_binding_info[idx][1].GetQuantizationOffset()
+
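The quantization accessors exist so that quantize/dequantize logic can live in utils.py without reaching into binding info. A sketch of the intended flow, assuming a hypothetical quantized keyword-spotting model whose output is 2-D:

    import numpy as np
    from network_executor import ArmnnNetworkExecutor
    from utils import dequantize_output

    executor = ArmnnNetworkExecutor('kws_quantized.tflite', ['CpuAcc', 'CpuRef'])

    # A zero-filled input just to exercise the API; real callers pass MFCC features.
    input_data = np.zeros(executor.get_shape(), dtype=executor.get_data_type())
    raw_output = executor.run([input_data])[0]

    scores = dequantize_output(raw_output,
                               executor.is_output_quantized(0),
                               executor.get_output_quantization_scale(0),
                               executor.get_output_quantization_offset(0))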
diff --git a/python/pyarmnn/examples/common/network_executor_tflite.py b/python/pyarmnn/examples/common/network_executor_tflite.py
new file mode 100644
index 0000000000..10f5e6e6fb
--- /dev/null
+++ b/python/pyarmnn/examples/common/network_executor_tflite.py
@@ -0,0 +1,98 @@
+# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+from typing import List, Tuple
+
+import numpy as np
+from tflite_runtime import interpreter as tflite
+
+class TFLiteNetworkExecutor:
+
+ def __init__(self, model_file: str, backends: list, delegate_path: str):
+ """
+ Creates an inference executor for a given network and a list of backends.
+
+ Args:
+ model_file: User-specified model file.
+ backends: List of backends to optimize network.
+ delegate_path: tflite delegate file path (.so).
+ """
+ self.model_file = model_file
+ self.backends = backends
+ self.delegate_path = delegate_path
+ self.interpreter, self.input_details, self.output_details = self.create_network()
+
+ def run(self, input_data_list: list) -> List[np.ndarray]:
+ """
+ Executes inference for the loaded network.
+
+ Args:
+ input_data_list: List of input frames.
+
+ Returns:
+ list: Inference results as a list of ndarrays.
+ """
+ output = []
+        # Set every input tensor first, invoke once, then collect all outputs;
+        # invoking inside the loop would run inference with partially set inputs
+        # on multi-input models.
+        for index, input_data in enumerate(input_data_list):
+            self.interpreter.set_tensor(self.input_details[index]['index'], input_data)
+        self.interpreter.invoke()
+        for curr_output in self.output_details:
+            output.append(self.interpreter.get_tensor(curr_output['index']))
+
+ return output
+
+ def create_network(self):
+ """
+ Creates a network based on the model file and a list of backends.
+
+ Returns:
+ interpreter: A TensorFlow Lite object for executing inference.
+ input_details: Contains essential information about the model input.
+ output_details: Used to map output tensor and its memory.
+ """
+
+ # Controls whether optimizations are used or not.
+ # Please note that optimizations can improve performance in some cases, but it can also
+ # degrade the performance in other cases. Accuracy might also be affected.
+
+ optimization_enable = "true"
+
+ if not os.path.exists(self.model_file):
+ raise FileNotFoundError(f'Model file not found for: {self.model_file}')
+
+ _, ext = os.path.splitext(self.model_file)
+ if ext == '.tflite':
+ armnn_delegate = tflite.load_delegate(library=self.delegate_path,
+ options={"backends": ','.join(self.backends), "logging-severity": "info",
+ "enable-fast-math": optimization_enable,
+ "reduce-fp32-to-fp16": optimization_enable})
+ interpreter = tflite.Interpreter(model_path=self.model_file,
+ experimental_delegates=[armnn_delegate])
+ interpreter.allocate_tensors()
+ else:
+ raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")
+
+ # Get input and output binding information
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ return interpreter, input_details, output_details
+
+ def get_data_type(self):
+ """
+ Get the input data type of the initiated network.
+
+ Returns:
+            numpy data type of the network input.
+ """
+ return self.input_details[0]['dtype']
+
+ def get_shape(self):
+ """
+ Get the input shape of the initiated network.
+
+ Returns:
+ tuple: The Shape of the network input.
+ """
+ return tuple(self.input_details[0]['shape'])
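Because TFLiteNetworkExecutor mirrors ArmnnNetworkExecutor's get_data_type()/get_shape()/run() surface, calling code can swap executors without other changes. A minimal sketch; the model path and delegate .so location are assumptions (the delegate comes from an Arm NN build with the TFLite delegate enabled):

    import numpy as np
    from network_executor_tflite import TFLiteNetworkExecutor

    executor = TFLiteNetworkExecutor('ssd_mobilenet_v1.tflite',
                                     ['CpuAcc', 'CpuRef'],
                                     '/path/to/libarmnnDelegate.so')

    # set_tensor() requires the exact dtype and shape reported by the interpreter.
    dummy = np.zeros(executor.get_shape(), dtype=executor.get_data_type())
    outputs = executor.run([dummy])
    print([o.shape for o in outputs])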
diff --git a/python/pyarmnn/examples/common/utils.py b/python/pyarmnn/examples/common/utils.py
index d4dadf80a4..beca0d37a0 100644
--- a/python/pyarmnn/examples/common/utils.py
+++ b/python/pyarmnn/examples/common/utils.py
@@ -8,7 +8,7 @@ import errno
from pathlib import Path
import numpy as np
-import pyarmnn as ann
+import datetime
def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
@@ -42,67 +42,76 @@ def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
return labels
-def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor):
+def prepare_input_data(audio_data, input_data_type, input_quant_scale, input_quant_offset, mfcc_preprocessor):
"""
-    Takes a block of audio data, extracts the MFCC features, quantizes the array, and uses ArmNN to create the
-    input tensors.
+    Takes a block of audio data, extracts the MFCC features and quantizes the array so it is
+    ready to be consumed as network input data.
Args:
audio_data: The audio data to process
- mfcc_instance: the mfcc class instance
- input_binding_info: the model input binding info
- mfcc_preprocessor: the mfcc preprocessor instance
+        input_data_type: The model's input data type
+        input_quant_scale: The model's input quantization scale
+        input_quant_offset: The model's input quantization offset
+        mfcc_preprocessor: The MFCC preprocessor instance
Returns:
- input_tensors: the prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor
+ input_data: The prepared input data
"""
- data_type = input_binding_info[1].GetDataType()
- input_tensor = mfcc_preprocessor.extract_features(audio_data)
- if data_type != ann.DataType_Float32:
- input_tensor = quantize_input(input_tensor, input_binding_info)
- input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
- return input_tensors
+ input_data = mfcc_preprocessor.extract_features(audio_data)
+ if input_data_type != np.float32:
+ input_data = quantize_input(input_data, input_data_type, input_quant_scale, input_quant_offset)
+ return input_data
-def quantize_input(data, input_binding_info):
+def quantize_input(data, input_data_type, input_quant_scale, input_quant_offset):
"""Quantize the float input to (u)int8 ready for inputting to model."""
if data.ndim != 2:
raise RuntimeError("Audio data must have 2 dimensions for quantization")
- quant_scale = input_binding_info[1].GetQuantizationScale()
- quant_offset = input_binding_info[1].GetQuantizationOffset()
- data_type = input_binding_info[1].GetDataType()
-
- if data_type == ann.DataType_QAsymmS8:
- data_type = np.int8
- elif data_type == ann.DataType_QAsymmU8:
- data_type = np.uint8
- else:
+ if (input_data_type != np.int8) and (input_data_type != np.uint8):
raise ValueError("Could not quantize data to required data type")
- d_min = np.iinfo(data_type).min
- d_max = np.iinfo(data_type).max
+ d_min = np.iinfo(input_data_type).min
+ d_max = np.iinfo(input_data_type).max
for row in range(data.shape[0]):
for col in range(data.shape[1]):
- data[row, col] = (data[row, col] / quant_scale) + quant_offset
+ data[row, col] = (data[row, col] / input_quant_scale) + input_quant_offset
data[row, col] = np.clip(data[row, col], d_min, d_max)
- data = data.astype(data_type)
+ data = data.astype(input_data_type)
return data
-def dequantize_output(data, output_binding_info):
+def dequantize_output(data, is_output_quantized, output_quant_scale, output_quant_offset):
"""Dequantize the (u)int8 output to float"""
- if output_binding_info[1].IsQuantized():
+ if is_output_quantized:
if data.ndim != 2:
raise RuntimeError("Data must have 2 dimensions for quantization")
- quant_scale = output_binding_info[1].GetQuantizationScale()
- quant_offset = output_binding_info[1].GetQuantizationOffset()
-
data = data.astype(float)
for row in range(data.shape[0]):
for col in range(data.shape[1]):
- data[row, col] = (data[row, col] - quant_offset)*quant_scale
+            data[row, col] = (data[row, col] - output_quant_offset) * output_quant_scale
return data
+
+
+class Profiling:
+ def __init__(self, enabled: bool):
+ self.m_start = 0
+ self.m_end = 0
+ self.m_enabled = enabled
+
+ def profiling_start(self):
+ if self.m_enabled:
+ self.m_start = datetime.datetime.now()
+
+ def profiling_stop_and_print_us(self, msg):
+ if self.m_enabled:
+ self.m_end = datetime.datetime.now()
+ period = self.m_end - self.m_start
+ period_us = period.seconds * 1_000_000 + period.microseconds
+            print(f'Profiling: {msg} : {period_us:,} microseconds')
+ return period_us
+ return 0
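Profiling brackets a region with wall-clock timestamps and is a no-op when disabled, so it can stay in production paths. A short sketch of the call pattern (the executor and input_data are the hypothetical ones from the sketches above):

    from utils import Profiling

    profiler = Profiling(enabled=True)
    profiler.profiling_start()
    results = executor.run([input_data])
    elapsed_us = profiler.profiling_stop_and_print_us('Inference')
    # Prints e.g. "Profiling: Inference : 12,345 microseconds"; with
    # enabled=False both calls do nothing and 0 is returned.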