From 97ddc06e52fbcabfd8ede7a00e9494c663186b92 Mon Sep 17 00:00:00 2001 From: Raviv Shalev Date: Tue, 7 Dec 2021 15:18:09 +0200 Subject: MLECO-2493 Add python OD example with TFLite delegate Signed-off-by: Raviv Shalev Change-Id: I25fcccbf912be0c5bd4fbfd2e97552341958af35 --- python/pyarmnn/examples/common/cv_utils.py | 58 ++++-- python/pyarmnn/examples/common/network_executor.py | 213 +++++++++++++-------- .../examples/common/network_executor_tflite.py | 98 ++++++++++ python/pyarmnn/examples/common/utils.py | 73 +++---- python/pyarmnn/examples/keyword_spotting/README.MD | 2 +- .../keyword_spotting/run_audio_classification.py | 11 +- python/pyarmnn/examples/object_detection/README.md | 126 +++++++++++- .../examples/object_detection/requirements.txt | 2 + .../examples/object_detection/run_video_file.py | 84 ++++++-- .../examples/object_detection/run_video_stream.py | 82 ++++++-- .../examples/object_detection/style_transfer.py | 138 +++++++++++++ python/pyarmnn/examples/object_detection/yolo.py | 6 +- .../pyarmnn/examples/speech_recognition/README.md | 4 +- .../examples/speech_recognition/run_audio_file.py | 7 +- python/pyarmnn/examples/tests/conftest.py | 20 +- python/pyarmnn/examples/tests/context.py | 6 +- python/pyarmnn/examples/tests/test_common_utils.py | 23 +++ .../examples/tests/test_network_executor.py | 24 ++- .../pyarmnn/examples/tests/test_style_transfer.py | 70 +++++++ 19 files changed, 858 insertions(+), 189 deletions(-) create mode 100644 python/pyarmnn/examples/common/network_executor_tflite.py create mode 100644 python/pyarmnn/examples/object_detection/style_transfer.py create mode 100644 python/pyarmnn/examples/tests/test_style_transfer.py diff --git a/python/pyarmnn/examples/common/cv_utils.py b/python/pyarmnn/examples/common/cv_utils.py index e12ff50548..36d1039227 100644 --- a/python/pyarmnn/examples/common/cv_utils.py +++ b/python/pyarmnn/examples/common/cv_utils.py @@ -1,4 +1,4 @@ -# Copyright © 2020-2021 Arm Ltd and Contributors. All rights reserved. +# Copyright © 2020-2022 Arm Ltd and Contributors. All rights reserved. # SPDX-License-Identifier: MIT """ @@ -11,29 +11,35 @@ import os import cv2 import numpy as np -import pyarmnn as ann - -def preprocess(frame: np.ndarray, input_binding_info: tuple, is_normalised: bool): +def preprocess(frame: np.ndarray, input_data_type, input_data_shape: tuple, is_normalised: bool, + keep_aspect_ratio: bool=True): """ Takes a frame, resizes, swaps channels and converts data type to match - model input layer. The converted frame is wrapped in a const tensor - and bound to the input tensor. + model input layer. Args: frame: Captured frame from video. - input_binding_info: Contains shape and data type of model input layer. + input_data_type: Contains data type of model input layer. + input_data_shape: Contains shape of model input layer. is_normalised: if the input layer expects normalised data + keep_aspect_ratio: Network executor's input data aspect ratio Returns: Input tensor. 
""" - # Swap channels and resize frame to model resolution frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - resized_frame = resize_with_aspect_ratio(frame, input_binding_info) + if keep_aspect_ratio: + # Swap channels and resize frame to model resolution + resized_frame = resize_with_aspect_ratio(frame, input_data_shape) + else: + # select the height and width from input_data_shape + frame_height = input_data_shape[1] + frame_width = input_data_shape[2] + resized_frame = cv2.resize(frame, (frame_width, frame_height)) # Expand dimensions and convert data type to match model input - if input_binding_info[1].GetDataType() == ann.DataType_Float32: + if np.float32 == input_data_type: data_type = np.float32 if is_normalised: resized_frame = resized_frame.astype("float32")/255 @@ -41,26 +47,24 @@ def preprocess(frame: np.ndarray, input_binding_info: tuple, is_normalised: bool data_type = np.uint8 resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0) - assert resized_frame.shape == tuple(input_binding_info[1].GetShape()) - - input_tensors = ann.make_input_tensors([input_binding_info], [resized_frame]) - return input_tensors + assert resized_frame.shape == input_data_shape + return resized_frame -def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple): +def resize_with_aspect_ratio(frame: np.ndarray, input_data_shape: tuple): """ Resizes frame while maintaining aspect ratio, padding any empty space. Args: frame: Captured frame. - input_binding_info: Contains shape of model input layer. + input_data_shape: Contains shape of model input layer. Returns: Frame resized to the size of model input layer. """ aspect_ratio = frame.shape[1] / frame.shape[0] - model_height, model_width = list(input_binding_info[1].GetShape())[1:3] + _, model_height, model_width, _ = input_data_shape if aspect_ratio >= 1.0: new_height, new_width = int(model_width / aspect_ratio), model_width @@ -173,14 +177,14 @@ def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labe # Create label for detected object class label = f'{label} {confidence * 100:.1f}%' - label_color = (0, 0, 0) if sum(color)>200 else (255, 255, 255) + label_color = (0, 0, 0) if sum(color) > 200 else (255, 255, 255) # Make sure label always stays on-screen x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2] lbl_box_xy_min = (x_min, y_min if y_min<25 else y_min - y_text) lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text if y_min<25 else y_min) - lbl_text_pos = (x_min + 5, y_min + 16 if y_min<25 else y_min - 5) + lbl_text_pos = (x_min + 5, y_min + 16 if y_min < 25 else y_min - 5) # Add label and confidence value cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1) @@ -190,3 +194,19 @@ def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labe def get_source_encoding_int(video_capture): return int(video_capture.get(cv2.CAP_PROP_FOURCC)) + + +def crop_bounding_box_object(input_frame: np.ndarray, x_min: float, y_min: float, x_max: float, y_max: float): + """ + Creates a cropped image based on x and y coordinates. + + Args: + input_frame: Image to crop + x_min, y_min, x_max, y_max: Coordinates of the bounding box + + Returns: + Cropped image + """ + # Adding +1 to exclude the bounding box pixels. 
+ cropped_image = input_frame[int(y_min) + 1:int(y_max), int(x_min) + 1:int(x_max)] + return cropped_image diff --git a/python/pyarmnn/examples/common/network_executor.py b/python/pyarmnn/examples/common/network_executor.py index 6e2c53c43d..72262fc520 100644 --- a/python/pyarmnn/examples/common/network_executor.py +++ b/python/pyarmnn/examples/common/network_executor.py @@ -7,80 +7,6 @@ from typing import List, Tuple import pyarmnn as ann import numpy as np - -def create_network(model_file: str, backends: list, input_names: Tuple[str] = (), output_names: Tuple[str] = ()): - """ - Creates a network based on the model file and a list of backends. - - Args: - model_file: User-specified model file. - backends: List of backends to optimize network. - input_names: - output_names: - - Returns: - net_id: Unique ID of the network to run. - runtime: Runtime context for executing inference. - input_binding_info: Contains essential information about the model input. - output_binding_info: Used to map output tensor and its memory. - """ - if not os.path.exists(model_file): - raise FileNotFoundError(f'Model file not found for: {model_file}') - - _, ext = os.path.splitext(model_file) - if ext == '.tflite': - parser = ann.ITfLiteParser() - else: - raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]") - - network = parser.CreateNetworkFromBinaryFile(model_file) - - # Specify backends to optimize network - preferred_backends = [] - for b in backends: - preferred_backends.append(ann.BackendId(b)) - - # Select appropriate device context and optimize the network for that device - options = ann.CreationOptions() - runtime = ann.IRuntime(options) - opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), - ann.OptimizerOptions()) - print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n' - f'Optimization warnings: {messages}') - - # Load the optimized network onto the Runtime device - net_id, _ = runtime.LoadNetwork(opt_network) - - # Get input and output binding information - graph_id = parser.GetSubgraphCount() - 1 - input_names = parser.GetSubgraphInputTensorNames(graph_id) - input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0]) - output_names = parser.GetSubgraphOutputTensorNames(graph_id) - output_binding_info = [] - for output_name in output_names: - out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name) - output_binding_info.append(out_bind_info) - return net_id, runtime, input_binding_info, output_binding_info - - -def execute_network(input_tensors: list, output_tensors: list, runtime, net_id: int) -> List[np.ndarray]: - """ - Executes inference for the loaded network. - - Args: - input_tensors: The input frame tensor. - output_tensors: The output tensor from output node. - runtime: Runtime context for executing inference. - net_id: Unique ID of the network to run. - - Returns: - list: Inference results as a list of ndarrays. - """ - runtime.EnqueueWorkload(net_id, input_tensors, output_tensors) - output = ann.workload_tensors_to_ndarray(output_tensors) - return output - - class ArmnnNetworkExecutor: def __init__(self, model_file: str, backends: list): @@ -91,18 +17,145 @@ class ArmnnNetworkExecutor: model_file: User-specified model file. backends: List of backends to optimize network. 
""" - self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = create_network(model_file, - backends) + self.model_file = model_file + self.backends = backends + self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = self.create_network() self.output_tensors = ann.make_output_tensors(self.output_binding_info) - def run(self, input_tensors: list) -> List[np.ndarray]: + def run(self, input_data_list: list) -> List[np.ndarray]: """ - Executes inference for the loaded network. + Creates input tensors from input data and executes inference with the loaded network. Args: - input_tensors: The input frame tensor. + input_data_list: List of input frames. Returns: list: Inference results as a list of ndarrays. """ - return execute_network(input_tensors, self.output_tensors, self.runtime, self.network_id) + input_tensors = ann.make_input_tensors(self.input_binding_info, input_data_list) + self.runtime.EnqueueWorkload(self.network_id, input_tensors, self.output_tensors) + output = ann.workload_tensors_to_ndarray(self.output_tensors) + + return output + + def create_network(self): + """ + Creates a network based on the model file and a list of backends. + + Returns: + net_id: Unique ID of the network to run. + runtime: Runtime context for executing inference. + input_binding_info: Contains essential information about the model input. + output_binding_info: Used to map output tensor and its memory. + """ + if not os.path.exists(self.model_file): + raise FileNotFoundError(f'Model file not found for: {self.model_file}') + + _, ext = os.path.splitext(self.model_file) + if ext == '.tflite': + parser = ann.ITfLiteParser() + else: + raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]") + + network = parser.CreateNetworkFromBinaryFile(self.model_file) + + # Specify backends to optimize network + preferred_backends = [] + for b in self.backends: + preferred_backends.append(ann.BackendId(b)) + + # Select appropriate device context and optimize the network for that device + options = ann.CreationOptions() + runtime = ann.IRuntime(options) + opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), + ann.OptimizerOptions()) + print(f'Preferred backends: {self.backends}\n{runtime.GetDeviceSpec()}\n' + f'Optimization warnings: {messages}') + + # Load the optimized network onto the Runtime device + net_id, _ = runtime.LoadNetwork(opt_network) + + # Get input and output binding information + graph_id = parser.GetSubgraphCount() - 1 + input_names = parser.GetSubgraphInputTensorNames(graph_id) + input_binding_info = [] + for input_name in input_names: + in_bind_info = parser.GetNetworkInputBindingInfo(graph_id, input_name) + input_binding_info.append(in_bind_info) + output_names = parser.GetSubgraphOutputTensorNames(graph_id) + output_binding_info = [] + for output_name in output_names: + out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name) + output_binding_info.append(out_bind_info) + return net_id, runtime, input_binding_info, output_binding_info + + def get_data_type(self): + """ + Get the input data type of the initiated network. + + Returns: + numpy data type or None if doesn't exist in the if condition. 
+ """ + if self.input_binding_info[0][1].GetDataType() == ann.DataType_Float32: + return np.float32 + elif self.input_binding_info[0][1].GetDataType() == ann.DataType_QAsymmU8: + return np.uint8 + elif self.input_binding_info[0][1].GetDataType() == ann.DataType_QAsymmS8: + return np.int8 + else: + return None + + def get_shape(self): + """ + Get the input shape of the initiated network. + + Returns: + tuple: The Shape of the network input. + """ + return tuple(self.input_binding_info[0][1].GetShape()) + + def get_input_quantization_scale(self, idx): + """ + Get the input quantization scale of the initiated network. + + Returns: + The quantization scale of the network input. + """ + return self.input_binding_info[idx][1].GetQuantizationScale() + + def get_input_quantization_offset(self, idx): + """ + Get the input quantization offset of the initiated network. + + Returns: + The quantization offset of the network input. + """ + return self.input_binding_info[idx][1].GetQuantizationOffset() + + def is_output_quantized(self, idx): + """ + Get True/False if output tensor is quantized or not respectively. + + Returns: + True if output is quantized and False otherwise. + """ + return self.output_binding_info[idx][1].IsQuantized() + + def get_output_quantization_scale(self, idx): + """ + Get the output quantization offset of the initiated network. + + Returns: + The quantization offset of the network output. + """ + return self.output_binding_info[idx][1].GetQuantizationScale() + + def get_output_quantization_offset(self, idx): + """ + Get the output quantization offset of the initiated network. + + Returns: + The quantization offset of the network output. + """ + return self.output_binding_info[idx][1].GetQuantizationOffset() + diff --git a/python/pyarmnn/examples/common/network_executor_tflite.py b/python/pyarmnn/examples/common/network_executor_tflite.py new file mode 100644 index 0000000000..10f5e6e6fb --- /dev/null +++ b/python/pyarmnn/examples/common/network_executor_tflite.py @@ -0,0 +1,98 @@ +# Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +# SPDX-License-Identifier: MIT + +import os +from typing import List, Tuple + +import numpy as np +from tflite_runtime import interpreter as tflite + +class TFLiteNetworkExecutor: + + def __init__(self, model_file: str, backends: list, delegate_path: str): + """ + Creates an inference executor for a given network and a list of backends. + + Args: + model_file: User-specified model file. + backends: List of backends to optimize network. + delegate_path: tflite delegate file path (.so). + """ + self.model_file = model_file + self.backends = backends + self.delegate_path = delegate_path + self.interpreter, self.input_details, self.output_details = self.create_network() + + def run(self, input_data_list: list) -> List[np.ndarray]: + """ + Executes inference for the loaded network. + + Args: + input_data_list: List of input frames. + + Returns: + list: Inference results as a list of ndarrays. + """ + output = [] + for index, input_data in enumerate(input_data_list): + self.interpreter.set_tensor(self.input_details[index]['index'], input_data) + self.interpreter.invoke() + for curr_output in self.output_details: + output.append(self.interpreter.get_tensor(curr_output['index'])) + + return output + + def create_network(self): + """ + Creates a network based on the model file and a list of backends. + + Returns: + interpreter: A TensorFlow Lite object for executing inference. 
+ input_details: Contains essential information about the model input. + output_details: Used to map output tensor and its memory. + """ + + # Controls whether optimizations are used or not. + # Please note that optimizations can improve performance in some cases, but it can also + # degrade the performance in other cases. Accuracy might also be affected. + + optimization_enable = "true" + + if not os.path.exists(self.model_file): + raise FileNotFoundError(f'Model file not found for: {self.model_file}') + + _, ext = os.path.splitext(self.model_file) + if ext == '.tflite': + armnn_delegate = tflite.load_delegate(library=self.delegate_path, + options={"backends": ','.join(self.backends), "logging-severity": "info", + "enable-fast-math": optimization_enable, + "reduce-fp32-to-fp16": optimization_enable}) + interpreter = tflite.Interpreter(model_path=self.model_file, + experimental_delegates=[armnn_delegate]) + interpreter.allocate_tensors() + else: + raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]") + + # Get input and output binding information + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + return interpreter, input_details, output_details + + def get_data_type(self): + """ + Get the input data type of the initiated network. + + Returns: + numpy data type or None if doesn't exist in the if condition. + """ + return self.input_details[0]['dtype'] + + def get_shape(self): + """ + Get the input shape of the initiated network. + + Returns: + tuple: The Shape of the network input. + """ + return tuple(self.input_details[0]['shape']) diff --git a/python/pyarmnn/examples/common/utils.py b/python/pyarmnn/examples/common/utils.py index d4dadf80a4..beca0d37a0 100644 --- a/python/pyarmnn/examples/common/utils.py +++ b/python/pyarmnn/examples/common/utils.py @@ -8,7 +8,7 @@ import errno from pathlib import Path import numpy as np -import pyarmnn as ann +import datetime def dict_labels(labels_file_path: str, include_rgb=False) -> dict: @@ -42,67 +42,76 @@ def dict_labels(labels_file_path: str, include_rgb=False) -> dict: return labels -def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor): +def prepare_input_data(audio_data, input_data_type, input_quant_scale, input_quant_offset, mfcc_preprocessor): """ Takes a block of audio data, extracts the MFCC features, quantizes the array, and uses ArmNN to create the input tensors. 
Args: audio_data: The audio data to process - mfcc_instance: the mfcc class instance - input_binding_info: the model input binding info - mfcc_preprocessor: the mfcc preprocessor instance + mfcc_instance: The mfcc class instance + input_data_type: The model's input data type + input_quant_scale: The model's quantization scale + input_quant_offset: The model's quantization offset + mfcc_preprocessor: The mfcc preprocessor instance Returns: - input_tensors: the prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor + input_data: The prepared input data """ - data_type = input_binding_info[1].GetDataType() - input_tensor = mfcc_preprocessor.extract_features(audio_data) - if data_type != ann.DataType_Float32: - input_tensor = quantize_input(input_tensor, input_binding_info) - input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor]) - return input_tensors + input_data = mfcc_preprocessor.extract_features(audio_data) + if input_data_type != np.float32: + input_data = quantize_input(input_data, input_data_type, input_quant_scale, input_quant_offset) + return input_data -def quantize_input(data, input_binding_info): +def quantize_input(data, input_data_type, input_quant_scale, input_quant_offset): """Quantize the float input to (u)int8 ready for inputting to model.""" if data.ndim != 2: raise RuntimeError("Audio data must have 2 dimensions for quantization") - quant_scale = input_binding_info[1].GetQuantizationScale() - quant_offset = input_binding_info[1].GetQuantizationOffset() - data_type = input_binding_info[1].GetDataType() - - if data_type == ann.DataType_QAsymmS8: - data_type = np.int8 - elif data_type == ann.DataType_QAsymmU8: - data_type = np.uint8 - else: + if (input_data_type != np.int8) and (input_data_type != np.uint8): raise ValueError("Could not quantize data to required data type") - d_min = np.iinfo(data_type).min - d_max = np.iinfo(data_type).max + d_min = np.iinfo(input_data_type).min + d_max = np.iinfo(input_data_type).max for row in range(data.shape[0]): for col in range(data.shape[1]): - data[row, col] = (data[row, col] / quant_scale) + quant_offset + data[row, col] = (data[row, col] / input_quant_scale) + input_quant_offset data[row, col] = np.clip(data[row, col], d_min, d_max) - data = data.astype(data_type) + data = data.astype(input_data_type) return data -def dequantize_output(data, output_binding_info): +def dequantize_output(data, is_output_quantized, output_quant_scale, output_quant_offset): """Dequantize the (u)int8 output to float""" - if output_binding_info[1].IsQuantized(): + if is_output_quantized: if data.ndim != 2: raise RuntimeError("Data must have 2 dimensions for quantization") - quant_scale = output_binding_info[1].GetQuantizationScale() - quant_offset = output_binding_info[1].GetQuantizationOffset() - data = data.astype(float) for row in range(data.shape[0]): for col in range(data.shape[1]): - data[row, col] = (data[row, col] - quant_offset)*quant_scale + data[row, col] = (data[row, col] - output_quant_offset)*output_quant_scale return data + + +class Profiling: + def __init__(self, enabled: bool): + self.m_start = 0 + self.m_end = 0 + self.m_enabled = enabled + + def profiling_start(self): + if self.m_enabled: + self.m_start = datetime.datetime.now() + + def profiling_stop_and_print_us(self, msg): + if self.m_enabled: + self.m_end = datetime.datetime.now() + period = self.m_end - self.m_start + period_us = period.seconds * 1_000_000 + period.microseconds + print(f'Profiling: {msg} : {period_us:,} microSeconds') + 
return period_us + return 0 diff --git a/python/pyarmnn/examples/keyword_spotting/README.MD b/python/pyarmnn/examples/keyword_spotting/README.MD index d276c08f8e..dde8342e7f 100644 --- a/python/pyarmnn/examples/keyword_spotting/README.MD +++ b/python/pyarmnn/examples/keyword_spotting/README.MD @@ -166,7 +166,7 @@ mfcc_feats = np.dot(self._dct_matrix, log_mel_energy) # audio_utils.py # Quantize the input data and create input tensors with PyArmNN input_tensor = quantize_input(input_tensor, input_binding_info) -input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor]) +input_tensors = ann.make_input_tensors([input_binding_info], [input_data]) ``` Note: `ArmnnNetworkExecutor` has already created the output tensors for you. diff --git a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py index 6dfa4cc806..50ad1a8a2e 100644 --- a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py +++ b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py @@ -14,7 +14,7 @@ script_dir = os.path.dirname(__file__) sys.path.insert(1, os.path.join(script_dir, '..', 'common')) from network_executor import ArmnnNetworkExecutor -from utils import prepare_input_tensors, dequantize_output +from utils import prepare_input_data, dequantize_output from mfcc import AudioPreprocessor, MFCC, MFCCParams from audio_utils import decode, display_text from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio @@ -69,13 +69,16 @@ def parse_args(): def recognise_speech(audio_data, network, preprocessor, threshold): # Prepare the input Tensors - input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor) + input_data = prepare_input_data(audio_data, network.get_data_type(), network.get_input_quantization_scale(0), + network.get_input_quantization_offset(0), preprocessor) # Run inference - output_result = network.run(input_tensors) + output_result = network.run([input_data]) dequantized_result = [] for index, ofm in enumerate(output_result): - dequantized_result.append(dequantize_output(ofm, network.output_binding_info[index])) + dequantized_result.append(dequantize_output(ofm, network.is_output_quantized(index), + network.get_output_quantization_scale(index), + network.get_output_quantization_offset(index))) # Decode the text and display result if above threshold decoded_result = decode(dequantized_result, labels) diff --git a/python/pyarmnn/examples/object_detection/README.md b/python/pyarmnn/examples/object_detection/README.md index b63295cc89..7a946ad6f5 100644 --- a/python/pyarmnn/examples/object_detection/README.md +++ b/python/pyarmnn/examples/object_detection/README.md @@ -1,7 +1,27 @@ -# PyArmNN Object Detection Sample Application +# Object Detection Sample Application ## Introduction -This sample application guides the user and shows how to perform object detection using PyArmNN API. We assume the user has already built PyArmNN by following the instructions of the README in the main PyArmNN directory. +This sample application guides the user and shows how to perform object detection using PyArmNN or Arm NN TensorFlow Lite Delegate API. We assume the user has already built PyArmNN by following the instructions of the README in the main PyArmNN directory. 
+
+
+##### Running with the Arm NN TensorFlow Lite Delegate
+There is an option to use the Arm NN TensorFlow Lite Delegate instead of the Arm NN TensorFlow Lite Parser for the object detection inference.
+The Arm NN TensorFlow Lite Delegate is part of the Arm NN library and its purpose is to accelerate certain TensorFlow Lite
+(TfLite) operators on Arm hardware. The main advantage of the Arm NN TensorFlow Lite Delegate over the Arm NN TensorFlow
+Lite Parser is that it supports far more operations: the delegate can execute any TfLite model, accelerating the
+operations that Arm NN supports and leaving the rest to the default TfLite runtime.
+In addition, the delegate options apply some optimizations by default to improve inference
+performance at the expense of a slight accuracy reduction. In this example we enable the fast math and
+reduce-fp32-to-fp16 optimizations.
+
+Using the **fast_math** flag can lead to performance improvements in fp32 and fp16 layers, but may produce
+results with reduced or different precision. The fast_math flag has no effect on int8 performance.
+
+The **reduce-fp32-to-fp16** feature works best if all operators of the model are in Fp32. Arm NN adds conversion layers
+around any operator that is not in Fp32 to begin with or that is not supported in Fp16.
+The overhead of these conversions can lead to slower overall performance if too many conversions are required.
+
+These optimizations can be turned off in the `create_network` function found in `network_executor_tflite.py`:
+just change the `optimization_enable` flag to false. A minimal sketch of how the delegate is loaded with these
+options is shown below, after the style transfer description.

We provide example scripts for performing object detection from video file and video stream with `run_video_file.py` and `run_video_stream.py`.

The application takes a model and video file or camera feed as input, runs inference on each frame, and draws bounding boxes around detected objects, with the corresponding labels and confidence scores overlaid.

A similar implementation of this object detection application is also provided in C++ in the examples for ArmNN.

+##### Performing Object Detection with Style Transfer and the TensorFlow Lite Delegate
+When running object detection with the TensorFlow Lite Delegate, there is also an option to run style transfer and create stylized detections instead of drawing bounding boxes on each frame.
+Style transfer is the ability to create a new image, known as a pastiche, based on two input images: one representing an artistic style and one representing the content frame containing class detections.
+The style transfer pipeline consists of two submodels:
+* Style Prediction Model: A MobileNetV2-based neural network that takes an input style image and creates a style bottleneck vector.
+* Style Transform Model: A neural network that applies a style bottleneck vector to a content image and creates a stylized image.
+
+An image containing an art style is first preprocessed to the correct size and dimensions.
+The preprocessed style image is passed to the style predict network, which calculates and returns a style bottleneck tensor.
+The style transfer network then receives the style bottleneck together with a content frame containing detections, transforms the requested detected class, and returns a stylized frame, as illustrated in the sketch below.
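The two-stage flow above can be illustrated with a minimal sketch that drives the two models directly through the TfLite interpreter. This is not the example's actual implementation (that lives in `style_transfer.py` and also handles cropping the detected objects); the model file names, the input resolutions, and the input ordering of the transform model are assumptions made purely for illustration.

```python
# Illustrative sketch only: style prediction followed by style transform.
# Model names, input sizes and input ordering below are assumptions.
import numpy as np
from tflite_runtime import interpreter as tflite

def run_tflite(model_path, input_arrays):
    """Run a TfLite model on a list of input arrays and return all of its outputs."""
    interp = tflite.Interpreter(model_path=model_path)
    interp.allocate_tensors()
    for detail, data in zip(interp.get_input_details(), input_arrays):
        interp.set_tensor(detail['index'], data)
    interp.invoke()
    return [interp.get_tensor(d['index']) for d in interp.get_output_details()]

# 1) Style prediction: a preprocessed style image produces a style bottleneck vector.
style_image = np.zeros((1, 256, 256, 3), dtype=np.float32)      # placeholder preprocessed style image
style_bottleneck = run_tflite("style_predict.tflite", [style_image])[0]

# 2) Style transform: the bottleneck is applied to a content frame (a detection crop)
#    to produce the stylized frame.
content_frame = np.zeros((1, 384, 384, 3), dtype=np.float32)    # placeholder preprocessed content frame
stylized_frame = run_tflite("style_transfer.tflite", [content_frame, style_bottleneck])[0]
```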
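For completeness, here is a minimal sketch of how the Arm NN delegate is loaded with the optimization options described at the start of this section. It mirrors the `create_network` function in `network_executor_tflite.py`; the library and model paths are placeholders to be replaced for your setup.

```python
# Minimal sketch: loading the Arm NN TfLite delegate with the default
# optimization options used by this example. Paths are placeholders.
from tflite_runtime import interpreter as tflite

armnn_delegate = tflite.load_delegate(
    library="<path to libarmnnDelegate.so>",
    options={"backends": "CpuAcc,CpuRef",      # backends in order of preference
             "logging-severity": "info",
             "enable-fast-math": "true",       # set to "false" to disable
             "reduce-fp32-to-fp16": "true"})   # set to "false" to disable

interpreter = tflite.Interpreter(model_path="<path to the detection .tflite model>",
                                 experimental_delegates=[armnn_delegate])
interpreter.allocate_tensors()
```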
+
+
## Prerequisites

##### PyArmNN
@@ -30,7 +61,19 @@ $ python -c "import pyarmnn as ann;print(ann.GetVersion())"

Install the following libraries on your system:
```bash
-$ sudo apt-get install python3-opencv libqtgui4 libqt4-test
+$ sudo apt-get install python3-opencv
+```
+
+
+This section is needed only if you want to run with the Arm NN TensorFlow Lite Delegate.\
+If there is no libarmnnDelegate.so file in your ARMNN_LIB path,
+download the Arm NN artifacts, which include the Arm NN delegate, matching your platform and the latest Arm NN version (aarch64 and v21.11 respectively for this example):
+```bash
+$ export WORKSPACE=`pwd`
+$ mkdir ./armnn_artifacts ; cd armnn_artifacts
+$ wget https://github.com/ARM-software/armnn/releases/download/v21.11/ArmNN-linux-aarch64.tar.gz
+$ tar -xvzf ArmNN*.tar.gz
+$ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`pwd`
```

Create a virtual environment:
@@ -39,8 +82,11 @@ $ python3.7 -m venv devenv --system-site-packages
$ source devenv/bin/activate
```

-Install the dependencies:
+Install the dependencies from the object_detection example folder:
+* If your Python version is 3.8 or lower, install tflite_runtime version 2.5.0 (without the post1 suffix)
+  and amend the requirements.txt file accordingly.
```bash
+$ cd $WORKSPACE/armnn/python/pyarmnn/examples/object_detection
$ pip install -r requirements.txt
```

@@ -69,10 +115,27 @@ The user can specify these arguments at command line:

* `--preferred_backends` - You can specify one or more backend in order of preference. Accepted backends include `CpuAcc, GpuAcc, CpuRef`. Arm NN will decide which layers of the network are supported by the backend, falling back to the next if a layer is unsupported.
Defaults to `['CpuAcc', 'CpuRef']`

+* `--tflite_delegate_path` - Optional. Path to the Arm NN TensorFlow Lite Delegate library (libarmnnDelegate.so). If provided, the Arm NN TensorFlow Lite Delegate will be used instead of PyArmNN.
+
+* `--profiling_enabled` - Optional. Prints timing information, in microseconds, for important ML-related milestones. Disabled by default. Accepted options are `true/false`.
+
+The `run_video_file.py` example can also perform style transfer on a selected class of detected objects, stylizing the detections based on a given style image.
+
+To run style transfer, the user additionally needs to specify these arguments at command line:
+
+* `--style_predict_model_file_path` - Path to the style predict model that will be used to create a style bottleneck tensor
+
+* `--style_transfer_model_file_path` - Path to the style transfer model that will perform the style transfer
+
+* `--style_image_path` - Path to a .jpg/.jpeg/.png style image used to create the stylized frames
+
+* `--style_transfer_class` - The name of the detected class whose style should be transformed
+
Run the sample script:
```bash
-$ python run_video_file.py --video_file_path <video_file_path> --model_file_path <model_file_path> --model_name <model_name>
+$ python run_video_file.py --video_file_path <video_file_path> --model_file_path <model_file_path> --model_name <model_name> --tflite_delegate_path <delegate_path> --style_predict_model_file_path <style_predict_model_path> \
+--style_transfer_model_file_path <style_transfer_model_path> --style_image_path <style_image_path> --style_transfer_class <style_transfer_class>
```

## Object Detection from Video Stream
@@ -94,16 +157,51 @@ The user can specify these arguments at command line:

* `--preferred_backends` - You can specify one or more backend in order of preference. Accepted backends include `CpuAcc, GpuAcc, CpuRef`. Arm NN will decide which layers of the network are supported by the backend, falling back to the next if a layer is unsupported.
Defaults to `['CpuAcc', 'CpuRef']`

+* `--tflite_delegate_path` - Optional. Path to the Arm NN TensorFlow Lite Delegate library (libarmnnDelegate.so). If provided, the Arm NN TensorFlow Lite Delegate will be used instead of PyArmNN.
+
+* `--profiling_enabled` - Optional. Prints timing information, in microseconds, for important ML-related milestones. Disabled by default. Accepted options are `true/false`.
+
Run the sample script:
```bash
$ python run_video_stream.py --model_file_path --model_name --tflite_delegate_path --label_path --video_file_path