aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaviv Shalev <raviv.shalev@arm.com>2021-12-07 15:18:09 +0200
committerTeresaARM <teresa.charlinreyes@arm.com>2022-04-13 15:33:31 +0000
commit97ddc06e52fbcabfd8ede7a00e9494c663186b92 (patch)
tree43c84d352c3a67aa45d89760fba6c79b81c8f8dc
parent2f0ddb67d8f9267ab600a8a26308cab32f9e16ac (diff)
downloadarmnn-97ddc06e52fbcabfd8ede7a00e9494c663186b92.tar.gz
MLECO-2493 Add python OD example with TFLite delegate
Signed-off-by: Raviv Shalev <raviv.shalev@arm.com> Change-Id: I25fcccbf912be0c5bd4fbfd2e97552341958af35
-rw-r--r--python/pyarmnn/examples/common/cv_utils.py58
-rw-r--r--python/pyarmnn/examples/common/network_executor.py213
-rw-r--r--python/pyarmnn/examples/common/network_executor_tflite.py98
-rw-r--r--python/pyarmnn/examples/common/utils.py73
-rw-r--r--python/pyarmnn/examples/keyword_spotting/README.MD2
-rw-r--r--python/pyarmnn/examples/keyword_spotting/run_audio_classification.py11
-rw-r--r--python/pyarmnn/examples/object_detection/README.md126
-rw-r--r--python/pyarmnn/examples/object_detection/requirements.txt2
-rw-r--r--python/pyarmnn/examples/object_detection/run_video_file.py84
-rw-r--r--python/pyarmnn/examples/object_detection/run_video_stream.py82
-rw-r--r--python/pyarmnn/examples/object_detection/style_transfer.py138
-rw-r--r--python/pyarmnn/examples/object_detection/yolo.py6
-rw-r--r--python/pyarmnn/examples/speech_recognition/README.md4
-rw-r--r--python/pyarmnn/examples/speech_recognition/run_audio_file.py7
-rw-r--r--python/pyarmnn/examples/tests/conftest.py20
-rw-r--r--python/pyarmnn/examples/tests/context.py6
-rw-r--r--python/pyarmnn/examples/tests/test_common_utils.py23
-rw-r--r--python/pyarmnn/examples/tests/test_network_executor.py24
-rw-r--r--python/pyarmnn/examples/tests/test_style_transfer.py70
19 files changed, 858 insertions, 189 deletions
diff --git a/python/pyarmnn/examples/common/cv_utils.py b/python/pyarmnn/examples/common/cv_utils.py
index e12ff50548..36d1039227 100644
--- a/python/pyarmnn/examples/common/cv_utils.py
+++ b/python/pyarmnn/examples/common/cv_utils.py
@@ -1,4 +1,4 @@
-# Copyright © 2020-2021 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2020-2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""
@@ -11,29 +11,35 @@ import os
import cv2
import numpy as np
-import pyarmnn as ann
-
-def preprocess(frame: np.ndarray, input_binding_info: tuple, is_normalised: bool):
+def preprocess(frame: np.ndarray, input_data_type, input_data_shape: tuple, is_normalised: bool,
+ keep_aspect_ratio: bool=True):
"""
Takes a frame, resizes, swaps channels and converts data type to match
- model input layer. The converted frame is wrapped in a const tensor
- and bound to the input tensor.
+ model input layer.
Args:
frame: Captured frame from video.
- input_binding_info: Contains shape and data type of model input layer.
+ input_data_type: Contains data type of model input layer.
+ input_data_shape: Contains shape of model input layer.
is_normalised: if the input layer expects normalised data
+ keep_aspect_ratio: Network executor's input data aspect ratio
Returns:
Input tensor.
"""
- # Swap channels and resize frame to model resolution
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- resized_frame = resize_with_aspect_ratio(frame, input_binding_info)
+ if keep_aspect_ratio:
+ # Swap channels and resize frame to model resolution
+ resized_frame = resize_with_aspect_ratio(frame, input_data_shape)
+ else:
+ # select the height and width from input_data_shape
+ frame_height = input_data_shape[1]
+ frame_width = input_data_shape[2]
+ resized_frame = cv2.resize(frame, (frame_width, frame_height))
# Expand dimensions and convert data type to match model input
- if input_binding_info[1].GetDataType() == ann.DataType_Float32:
+ if np.float32 == input_data_type:
data_type = np.float32
if is_normalised:
resized_frame = resized_frame.astype("float32")/255
@@ -41,26 +47,24 @@ def preprocess(frame: np.ndarray, input_binding_info: tuple, is_normalised: bool
data_type = np.uint8
resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0)
- assert resized_frame.shape == tuple(input_binding_info[1].GetShape())
-
- input_tensors = ann.make_input_tensors([input_binding_info], [resized_frame])
- return input_tensors
+ assert resized_frame.shape == input_data_shape
+ return resized_frame
-def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple):
+def resize_with_aspect_ratio(frame: np.ndarray, input_data_shape: tuple):
"""
Resizes frame while maintaining aspect ratio, padding any empty space.
Args:
frame: Captured frame.
- input_binding_info: Contains shape of model input layer.
+ input_data_shape: Contains shape of model input layer.
Returns:
Frame resized to the size of model input layer.
"""
aspect_ratio = frame.shape[1] / frame.shape[0]
- model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
+ _, model_height, model_width, _ = input_data_shape
if aspect_ratio >= 1.0:
new_height, new_width = int(model_width / aspect_ratio), model_width
@@ -173,14 +177,14 @@ def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labe
# Create label for detected object class
label = f'{label} {confidence * 100:.1f}%'
- label_color = (0, 0, 0) if sum(color)>200 else (255, 255, 255)
+ label_color = (0, 0, 0) if sum(color) > 200 else (255, 255, 255)
# Make sure label always stays on-screen
x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]
lbl_box_xy_min = (x_min, y_min if y_min<25 else y_min - y_text)
lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text if y_min<25 else y_min)
- lbl_text_pos = (x_min + 5, y_min + 16 if y_min<25 else y_min - 5)
+ lbl_text_pos = (x_min + 5, y_min + 16 if y_min < 25 else y_min - 5)
# Add label and confidence value
cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
@@ -190,3 +194,19 @@ def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labe
def get_source_encoding_int(video_capture):
return int(video_capture.get(cv2.CAP_PROP_FOURCC))
+
+
+def crop_bounding_box_object(input_frame: np.ndarray, x_min: float, y_min: float, x_max: float, y_max: float):
+ """
+ Creates a cropped image based on x and y coordinates.
+
+ Args:
+ input_frame: Image to crop
+ x_min, y_min, x_max, y_max: Coordinates of the bounding box
+
+ Returns:
+ Cropped image
+ """
+ # Adding +1 to exclude the bounding box pixels.
+ cropped_image = input_frame[int(y_min) + 1:int(y_max), int(x_min) + 1:int(x_max)]
+ return cropped_image
diff --git a/python/pyarmnn/examples/common/network_executor.py b/python/pyarmnn/examples/common/network_executor.py
index 6e2c53c43d..72262fc520 100644
--- a/python/pyarmnn/examples/common/network_executor.py
+++ b/python/pyarmnn/examples/common/network_executor.py
@@ -7,80 +7,6 @@ from typing import List, Tuple
import pyarmnn as ann
import numpy as np
-
-def create_network(model_file: str, backends: list, input_names: Tuple[str] = (), output_names: Tuple[str] = ()):
- """
- Creates a network based on the model file and a list of backends.
-
- Args:
- model_file: User-specified model file.
- backends: List of backends to optimize network.
- input_names:
- output_names:
-
- Returns:
- net_id: Unique ID of the network to run.
- runtime: Runtime context for executing inference.
- input_binding_info: Contains essential information about the model input.
- output_binding_info: Used to map output tensor and its memory.
- """
- if not os.path.exists(model_file):
- raise FileNotFoundError(f'Model file not found for: {model_file}')
-
- _, ext = os.path.splitext(model_file)
- if ext == '.tflite':
- parser = ann.ITfLiteParser()
- else:
- raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")
-
- network = parser.CreateNetworkFromBinaryFile(model_file)
-
- # Specify backends to optimize network
- preferred_backends = []
- for b in backends:
- preferred_backends.append(ann.BackendId(b))
-
- # Select appropriate device context and optimize the network for that device
- options = ann.CreationOptions()
- runtime = ann.IRuntime(options)
- opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
- ann.OptimizerOptions())
- print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n'
- f'Optimization warnings: {messages}')
-
- # Load the optimized network onto the Runtime device
- net_id, _ = runtime.LoadNetwork(opt_network)
-
- # Get input and output binding information
- graph_id = parser.GetSubgraphCount() - 1
- input_names = parser.GetSubgraphInputTensorNames(graph_id)
- input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
- output_names = parser.GetSubgraphOutputTensorNames(graph_id)
- output_binding_info = []
- for output_name in output_names:
- out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
- output_binding_info.append(out_bind_info)
- return net_id, runtime, input_binding_info, output_binding_info
-
-
-def execute_network(input_tensors: list, output_tensors: list, runtime, net_id: int) -> List[np.ndarray]:
- """
- Executes inference for the loaded network.
-
- Args:
- input_tensors: The input frame tensor.
- output_tensors: The output tensor from output node.
- runtime: Runtime context for executing inference.
- net_id: Unique ID of the network to run.
-
- Returns:
- list: Inference results as a list of ndarrays.
- """
- runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
- output = ann.workload_tensors_to_ndarray(output_tensors)
- return output
-
-
class ArmnnNetworkExecutor:
def __init__(self, model_file: str, backends: list):
@@ -91,18 +17,145 @@ class ArmnnNetworkExecutor:
model_file: User-specified model file.
backends: List of backends to optimize network.
"""
- self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = create_network(model_file,
- backends)
+ self.model_file = model_file
+ self.backends = backends
+ self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = self.create_network()
self.output_tensors = ann.make_output_tensors(self.output_binding_info)
- def run(self, input_tensors: list) -> List[np.ndarray]:
+ def run(self, input_data_list: list) -> List[np.ndarray]:
"""
- Executes inference for the loaded network.
+ Creates input tensors from input data and executes inference with the loaded network.
Args:
- input_tensors: The input frame tensor.
+ input_data_list: List of input frames.
Returns:
list: Inference results as a list of ndarrays.
"""
- return execute_network(input_tensors, self.output_tensors, self.runtime, self.network_id)
+ input_tensors = ann.make_input_tensors(self.input_binding_info, input_data_list)
+ self.runtime.EnqueueWorkload(self.network_id, input_tensors, self.output_tensors)
+ output = ann.workload_tensors_to_ndarray(self.output_tensors)
+
+ return output
+
+ def create_network(self):
+ """
+ Creates a network based on the model file and a list of backends.
+
+ Returns:
+ net_id: Unique ID of the network to run.
+ runtime: Runtime context for executing inference.
+ input_binding_info: Contains essential information about the model input.
+ output_binding_info: Used to map output tensor and its memory.
+ """
+ if not os.path.exists(self.model_file):
+ raise FileNotFoundError(f'Model file not found for: {self.model_file}')
+
+ _, ext = os.path.splitext(self.model_file)
+ if ext == '.tflite':
+ parser = ann.ITfLiteParser()
+ else:
+ raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")
+
+ network = parser.CreateNetworkFromBinaryFile(self.model_file)
+
+ # Specify backends to optimize network
+ preferred_backends = []
+ for b in self.backends:
+ preferred_backends.append(ann.BackendId(b))
+
+ # Select appropriate device context and optimize the network for that device
+ options = ann.CreationOptions()
+ runtime = ann.IRuntime(options)
+ opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
+ ann.OptimizerOptions())
+ print(f'Preferred backends: {self.backends}\n{runtime.GetDeviceSpec()}\n'
+ f'Optimization warnings: {messages}')
+
+ # Load the optimized network onto the Runtime device
+ net_id, _ = runtime.LoadNetwork(opt_network)
+
+ # Get input and output binding information
+ graph_id = parser.GetSubgraphCount() - 1
+ input_names = parser.GetSubgraphInputTensorNames(graph_id)
+ input_binding_info = []
+ for input_name in input_names:
+ in_bind_info = parser.GetNetworkInputBindingInfo(graph_id, input_name)
+ input_binding_info.append(in_bind_info)
+ output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+ output_binding_info = []
+ for output_name in output_names:
+ out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
+ output_binding_info.append(out_bind_info)
+ return net_id, runtime, input_binding_info, output_binding_info
+
+ def get_data_type(self):
+ """
+ Get the input data type of the initiated network.
+
+ Returns:
+ numpy data type or None if doesn't exist in the if condition.
+ """
+ if self.input_binding_info[0][1].GetDataType() == ann.DataType_Float32:
+ return np.float32
+ elif self.input_binding_info[0][1].GetDataType() == ann.DataType_QAsymmU8:
+ return np.uint8
+ elif self.input_binding_info[0][1].GetDataType() == ann.DataType_QAsymmS8:
+ return np.int8
+ else:
+ return None
+
+ def get_shape(self):
+ """
+ Get the input shape of the initiated network.
+
+ Returns:
+ tuple: The Shape of the network input.
+ """
+ return tuple(self.input_binding_info[0][1].GetShape())
+
+ def get_input_quantization_scale(self, idx):
+ """
+ Get the input quantization scale of the initiated network.
+
+ Returns:
+ The quantization scale of the network input.
+ """
+ return self.input_binding_info[idx][1].GetQuantizationScale()
+
+ def get_input_quantization_offset(self, idx):
+ """
+ Get the input quantization offset of the initiated network.
+
+ Returns:
+ The quantization offset of the network input.
+ """
+ return self.input_binding_info[idx][1].GetQuantizationOffset()
+
+ def is_output_quantized(self, idx):
+ """
+ Get True/False if output tensor is quantized or not respectively.
+
+ Returns:
+ True if output is quantized and False otherwise.
+ """
+ return self.output_binding_info[idx][1].IsQuantized()
+
+ def get_output_quantization_scale(self, idx):
+ """
+ Get the output quantization offset of the initiated network.
+
+ Returns:
+ The quantization offset of the network output.
+ """
+ return self.output_binding_info[idx][1].GetQuantizationScale()
+
+ def get_output_quantization_offset(self, idx):
+ """
+ Get the output quantization offset of the initiated network.
+
+ Returns:
+ The quantization offset of the network output.
+ """
+ return self.output_binding_info[idx][1].GetQuantizationOffset()
+
diff --git a/python/pyarmnn/examples/common/network_executor_tflite.py b/python/pyarmnn/examples/common/network_executor_tflite.py
new file mode 100644
index 0000000000..10f5e6e6fb
--- /dev/null
+++ b/python/pyarmnn/examples/common/network_executor_tflite.py
@@ -0,0 +1,98 @@
+# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+from typing import List, Tuple
+
+import numpy as np
+from tflite_runtime import interpreter as tflite
+
+class TFLiteNetworkExecutor:
+
+ def __init__(self, model_file: str, backends: list, delegate_path: str):
+ """
+ Creates an inference executor for a given network and a list of backends.
+
+ Args:
+ model_file: User-specified model file.
+ backends: List of backends to optimize network.
+ delegate_path: tflite delegate file path (.so).
+ """
+ self.model_file = model_file
+ self.backends = backends
+ self.delegate_path = delegate_path
+ self.interpreter, self.input_details, self.output_details = self.create_network()
+
+ def run(self, input_data_list: list) -> List[np.ndarray]:
+ """
+ Executes inference for the loaded network.
+
+ Args:
+ input_data_list: List of input frames.
+
+ Returns:
+ list: Inference results as a list of ndarrays.
+ """
+ output = []
+ for index, input_data in enumerate(input_data_list):
+ self.interpreter.set_tensor(self.input_details[index]['index'], input_data)
+ self.interpreter.invoke()
+ for curr_output in self.output_details:
+ output.append(self.interpreter.get_tensor(curr_output['index']))
+
+ return output
+
+ def create_network(self):
+ """
+ Creates a network based on the model file and a list of backends.
+
+ Returns:
+ interpreter: A TensorFlow Lite object for executing inference.
+ input_details: Contains essential information about the model input.
+ output_details: Used to map output tensor and its memory.
+ """
+
+ # Controls whether optimizations are used or not.
+ # Please note that optimizations can improve performance in some cases, but it can also
+ # degrade the performance in other cases. Accuracy might also be affected.
+
+ optimization_enable = "true"
+
+ if not os.path.exists(self.model_file):
+ raise FileNotFoundError(f'Model file not found for: {self.model_file}')
+
+ _, ext = os.path.splitext(self.model_file)
+ if ext == '.tflite':
+ armnn_delegate = tflite.load_delegate(library=self.delegate_path,
+ options={"backends": ','.join(self.backends), "logging-severity": "info",
+ "enable-fast-math": optimization_enable,
+ "reduce-fp32-to-fp16": optimization_enable})
+ interpreter = tflite.Interpreter(model_path=self.model_file,
+ experimental_delegates=[armnn_delegate])
+ interpreter.allocate_tensors()
+ else:
+ raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")
+
+ # Get input and output binding information
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ return interpreter, input_details, output_details
+
+ def get_data_type(self):
+ """
+ Get the input data type of the initiated network.
+
+ Returns:
+ numpy data type or None if doesn't exist in the if condition.
+ """
+ return self.input_details[0]['dtype']
+
+ def get_shape(self):
+ """
+ Get the input shape of the initiated network.
+
+ Returns:
+ tuple: The Shape of the network input.
+ """
+ return tuple(self.input_details[0]['shape'])
diff --git a/python/pyarmnn/examples/common/utils.py b/python/pyarmnn/examples/common/utils.py
index d4dadf80a4..beca0d37a0 100644
--- a/python/pyarmnn/examples/common/utils.py
+++ b/python/pyarmnn/examples/common/utils.py
@@ -8,7 +8,7 @@ import errno
from pathlib import Path
import numpy as np
-import pyarmnn as ann
+import datetime
def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
@@ -42,67 +42,76 @@ def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
return labels
-def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor):
+def prepare_input_data(audio_data, input_data_type, input_quant_scale, input_quant_offset, mfcc_preprocessor):
"""
Takes a block of audio data, extracts the MFCC features, quantizes the array, and uses ArmNN to create the
input tensors.
Args:
audio_data: The audio data to process
- mfcc_instance: the mfcc class instance
- input_binding_info: the model input binding info
- mfcc_preprocessor: the mfcc preprocessor instance
+ mfcc_instance: The mfcc class instance
+ input_data_type: The model's input data type
+ input_quant_scale: The model's quantization scale
+ input_quant_offset: The model's quantization offset
+ mfcc_preprocessor: The mfcc preprocessor instance
Returns:
- input_tensors: the prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor
+ input_data: The prepared input data
"""
- data_type = input_binding_info[1].GetDataType()
- input_tensor = mfcc_preprocessor.extract_features(audio_data)
- if data_type != ann.DataType_Float32:
- input_tensor = quantize_input(input_tensor, input_binding_info)
- input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
- return input_tensors
+ input_data = mfcc_preprocessor.extract_features(audio_data)
+ if input_data_type != np.float32:
+ input_data = quantize_input(input_data, input_data_type, input_quant_scale, input_quant_offset)
+ return input_data
-def quantize_input(data, input_binding_info):
+def quantize_input(data, input_data_type, input_quant_scale, input_quant_offset):
"""Quantize the float input to (u)int8 ready for inputting to model."""
if data.ndim != 2:
raise RuntimeError("Audio data must have 2 dimensions for quantization")
- quant_scale = input_binding_info[1].GetQuantizationScale()
- quant_offset = input_binding_info[1].GetQuantizationOffset()
- data_type = input_binding_info[1].GetDataType()
-
- if data_type == ann.DataType_QAsymmS8:
- data_type = np.int8
- elif data_type == ann.DataType_QAsymmU8:
- data_type = np.uint8
- else:
+ if (input_data_type != np.int8) and (input_data_type != np.uint8):
raise ValueError("Could not quantize data to required data type")
- d_min = np.iinfo(data_type).min
- d_max = np.iinfo(data_type).max
+ d_min = np.iinfo(input_data_type).min
+ d_max = np.iinfo(input_data_type).max
for row in range(data.shape[0]):
for col in range(data.shape[1]):
- data[row, col] = (data[row, col] / quant_scale) + quant_offset
+ data[row, col] = (data[row, col] / input_quant_scale) + input_quant_offset
data[row, col] = np.clip(data[row, col], d_min, d_max)
- data = data.astype(data_type)
+ data = data.astype(input_data_type)
return data
-def dequantize_output(data, output_binding_info):
+def dequantize_output(data, is_output_quantized, output_quant_scale, output_quant_offset):
"""Dequantize the (u)int8 output to float"""
- if output_binding_info[1].IsQuantized():
+ if is_output_quantized:
if data.ndim != 2:
raise RuntimeError("Data must have 2 dimensions for quantization")
- quant_scale = output_binding_info[1].GetQuantizationScale()
- quant_offset = output_binding_info[1].GetQuantizationOffset()
-
data = data.astype(float)
for row in range(data.shape[0]):
for col in range(data.shape[1]):
- data[row, col] = (data[row, col] - quant_offset)*quant_scale
+ data[row, col] = (data[row, col] - output_quant_offset)*output_quant_scale
return data
+
+
+class Profiling:
+ def __init__(self, enabled: bool):
+ self.m_start = 0
+ self.m_end = 0
+ self.m_enabled = enabled
+
+ def profiling_start(self):
+ if self.m_enabled:
+ self.m_start = datetime.datetime.now()
+
+ def profiling_stop_and_print_us(self, msg):
+ if self.m_enabled:
+ self.m_end = datetime.datetime.now()
+ period = self.m_end - self.m_start
+ period_us = period.seconds * 1_000_000 + period.microseconds
+ print(f'Profiling: {msg} : {period_us:,} microSeconds')
+ return period_us
+ return 0
diff --git a/python/pyarmnn/examples/keyword_spotting/README.MD b/python/pyarmnn/examples/keyword_spotting/README.MD
index d276c08f8e..dde8342e7f 100644
--- a/python/pyarmnn/examples/keyword_spotting/README.MD
+++ b/python/pyarmnn/examples/keyword_spotting/README.MD
@@ -166,7 +166,7 @@ mfcc_feats = np.dot(self._dct_matrix, log_mel_energy)
# audio_utils.py
# Quantize the input data and create input tensors with PyArmNN
input_tensor = quantize_input(input_tensor, input_binding_info)
-input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+input_tensors = ann.make_input_tensors([input_binding_info], [input_data])
```
Note: `ArmnnNetworkExecutor` has already created the output tensors for you.
diff --git a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
index 6dfa4cc806..50ad1a8a2e 100644
--- a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
+++ b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
@@ -14,7 +14,7 @@ script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
from network_executor import ArmnnNetworkExecutor
-from utils import prepare_input_tensors, dequantize_output
+from utils import prepare_input_data, dequantize_output
from mfcc import AudioPreprocessor, MFCC, MFCCParams
from audio_utils import decode, display_text
from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio
@@ -69,13 +69,16 @@ def parse_args():
def recognise_speech(audio_data, network, preprocessor, threshold):
# Prepare the input Tensors
- input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)
+ input_data = prepare_input_data(audio_data, network.get_data_type(), network.get_input_quantization_scale(0),
+ network.get_input_quantization_offset(0), preprocessor)
# Run inference
- output_result = network.run(input_tensors)
+ output_result = network.run([input_data])
dequantized_result = []
for index, ofm in enumerate(output_result):
- dequantized_result.append(dequantize_output(ofm, network.output_binding_info[index]))
+ dequantized_result.append(dequantize_output(ofm, network.is_output_quantized(index),
+ network.get_output_quantization_scale(index),
+ network.get_output_quantization_offset(index)))
# Decode the text and display result if above threshold
decoded_result = decode(dequantized_result, labels)
diff --git a/python/pyarmnn/examples/object_detection/README.md b/python/pyarmnn/examples/object_detection/README.md
index b63295cc89..7a946ad6f5 100644
--- a/python/pyarmnn/examples/object_detection/README.md
+++ b/python/pyarmnn/examples/object_detection/README.md
@@ -1,7 +1,27 @@
-# PyArmNN Object Detection Sample Application
+# Object Detection Sample Application
## Introduction
-This sample application guides the user and shows how to perform object detection using PyArmNN API. We assume the user has already built PyArmNN by following the instructions of the README in the main PyArmNN directory.
+This sample application guides the user and shows how to perform object detection using PyArmNN or Arm NN TensorFlow Lite Delegate API. We assume the user has already built PyArmNN by following the instructions of the README in the main PyArmNN directory.
+
+##### Running with Armn NN TensorFlow Lite Delegate
+There is an option to use the Arm NN TensorFlow Lite Delegate instead of Arm NN TensorFlow Lite Parser for the object detection inference.
+The Arm NN TensorFlow Lite Delegate is part of Arm NN library and its purpose is to accelerate certain TensorFlow Lite
+(TfLite) operators on Arm hardware. The main advantage of using the Arm NN TensorFlow Lite Delegate over the Arm NN TensorFlow
+Lite Parser is that the number of supported operations is far greater, which means Arm NN TfLite Delegate can execute
+all TfLite models, and accelerates any operations that Arm NN supports.
+In addition, in the delegate options there are some optimizations applied by default in order to improve the inference
+performance at the expanse of a slight accuracy reduction. In this example we enable fast math and reduce float32 to
+float16 optimizations.
+
+Using the **fast_math** flag can lead to performance improvements in fp32 and fp16 layers but may result in
+results with reduced or different precision. The fast_math flag will not have any effect on int8 performance.
+
+The **reduce-fp32-to-fp16** feature works best if all operators of the model are in Fp32. ArmNN will add conversion layers
+between layers that weren't in Fp32 in the first place or if the operator is not supported in Fp16.
+The overhead of these conversions can lead to a slower overall performance if too many conversions are required.
+
+One can turn off these optimizations in the `create_network` function found in the `network_executor_tflite.py`.
+Just change the `optimization_enable` flag to false.
We provide example scripts for performing object detection from video file and video stream with `run_video_file.py` and `run_video_stream.py`.
@@ -9,6 +29,17 @@ The application takes a model and video file or camera feed as input, runs infer
A similar implementation of this object detection application is also provided in C++ in the examples for ArmNN.
+##### Performing Object Detection with Style Transfer and TensorFlow Lite Delegate
+In addition to running Object Detection using TensorFlow Lite Delegate, instead of drawing bounding boxes on each frame, there is an option to run style transfer to create stylized detections.
+Style transfer is the ability to create a new image, known as a pastiche, based on two input images: one representing an artistic style and one representing the content frame containing class detections.
+The style transfer consists of two submodels:
+Style Prediction Model: A MobilenetV2-based neural network that takes an input style image to create a style bottleneck vector.
+Style Transform Model: A neural network that applies a style bottleneck vector to a content image and creates a stylized image.
+An image containing an art style is preprocessed to a correct size and dimension.
+The preprocessed style image is passed to a style predict network which calculates and returns a style bottleneck tensor.
+The style transfer network receives the style bottleneck, and a content frame that contains detections, which then transforms the requested class detected and returns a stylized frame.
+
+
## Prerequisites
##### PyArmNN
@@ -30,7 +61,19 @@ $ python -c "import pyarmnn as ann;print(ann.GetVersion())"
Install the following libraries on your system:
```bash
-$ sudo apt-get install python3-opencv libqtgui4 libqt4-test
+$ sudo apt-get install python3-opencv
+```
+
+
+<b>This section is needed only if running with Arm NN TensorFlow Lite Delegate is desired</b>\
+If there is no libarmnnDelegate.so file in your ARMNN_LIB path,
+download Arm NN artifacts with Arm NN delegate according to your platform and Arm NN latest version (for this example aarch64 and v21.11 respectively):
+```bash
+$ export $WORKSPACE=`pwd`
+$ mkdir ./armnn_artifacts ; cd armnn_artifacts
+$ wget https://github.com/ARM-software/armnn/releases/download/v21.11/ArmNN-linux-aarch64.tar.gz
+$ tar -xvzf ArmNN*.tar.gz
+$ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`pwd`
```
Create a virtual environment:
@@ -39,8 +82,11 @@ $ python3.7 -m venv devenv --system-site-packages
$ source devenv/bin/activate
```
-Install the dependencies:
+Install the dependencies from the object_detection example folder:
+* In case the python version is 3.8 or lower, tflite_runtime version 2.5.0 (without post1 suffix) should be installed.
+ (requirements.txt file should be amended)
```bash
+$ cd $WORKSPACE/armnn/python/pyarmnn/examples/object_detection
$ pip install -r requirements.txt
```
@@ -69,10 +115,27 @@ The user can specify these arguments at command line:
* `--preferred_backends` - You can specify one or more backend in order of preference. Accepted backends include `CpuAcc, GpuAcc, CpuRef`. Arm NN will decide which layers of the network are supported by the backend, falling back to the next if a layer is unsupported. Defaults to `['CpuAcc', 'CpuRef']`
+* `--tflite_delegate_path` - Optional. Path to the Arm NN TensorFlow Lite Delegate library (libarmnnDelegate.so). If provided, Arm NN TensorFlow Lite Delegate will be used instead of PyArmNN.
+
+* `--profiling_enabled` - Optional. Enabling this option will print important ML related milestones timing information in micro-seconds. By default, this option is disabled. Accepted options are `true/false`
+
+The `run_video_file.py` example can also perform style transfer on a selected class of detected objects, and stylize the detections based on a given style image.
+
+In addition, to run style transfer, the user needs to specify these arguments at command line:
+
+* `--style_predict_model_file_path` - Path to the style predict model that will be used to create a style bottleneck tensor
+
+* `--style_transfer_model_file_path` - Path to the style transfer model to use which will perform the style transfer
+
+* `--style_image_path` - Path to a .jpg/jpeg/png style image to create stylized frames
+
+* `--style_transfer_class` - A detected class name to transform its style
+
Run the sample script:
```bash
-$ python run_video_file.py --video_file_path <video_file_path> --model_file_path <model_file_path> --model_name <model_name>
+$ python run_video_file.py --video_file_path <video_file_path> --model_file_path <model_file_path> --model_name <model_name> --tflite_delegate_path <ARMNN delegate file path> --style_predict_model_file_path <style_predict_model_path>
+--style_transfer_model_file_path <style_transfer_model_path> --style_image_path <style_image_path> --style_transfer_class <style_transfer_class>
```
## Object Detection from Video Stream
@@ -94,16 +157,51 @@ The user can specify these arguments at command line:
* `--preferred_backends` - You can specify one or more backend in order of preference. Accepted backends include `CpuAcc, GpuAcc, CpuRef`. Arm NN will decide which layers of the network are supported by the backend, falling back to the next if a layer is unsupported. Defaults to `['CpuAcc', 'CpuRef']`
+* `--tflite_delegate_path` - Optional. Path to the Arm NN TensorFlow Lite Delegate library (libarmnnDelegate.so). If provided, Arm NN TensorFlow Lite Delegate will be used instead of PyArmNN.
+
+* `--profiling_enabled` - Optional. Enabling this option will print important ML related milestones timing information in micro-seconds. By default, this option is disabled. Accepted options are `true/false`
+
+Run the sample script:
+```bash
+$ python run_video_stream.py --model_file_path <model_file_path> --model_name <model_name> --tflite_delegate_path <ARMNN delegate file path> --label_path <Model label path> --video_file_path <Video file>
+
+In addition, to run style trasnfer, the user needs to specify these arguments at command line:
+
+* `--style_predict_model_file_path` - Path to .tflite style predict model that will be used to create a style bottleneck tensor
+
+* `--style_transfer_model_file_path` - Path to .tflite style transfer model to use which will perform the style transfer
+
+* `--style_image_path` - Path to a .jpg/jpeg/png style image to create stylized frames
+
+* `--style_transfer_class` - A detected class name to transform its style
Run the sample script:
```bash
-$ python run_video_stream.py --model_file_path <model_file_path> --model_name <model_name>
+$ python run_video_stream.py --model_file_path <model_file_path> --model_name <model_name> --tflite_delegate_path <ARMNN delegate file path> --style_predict_model_file_path <style_predict_model_path>
+--style_transfer_model_file_path <style_transfer_model_path> --style_image_path <style_image_path> --style_transfer_class <style_transfer_class>
```
-This application has been verified to work against the MobileNet SSD model, which can be downloaded along with it's label set from:
+This application has been verified to work against the MobileNet SSD model and YOLOv3, which can be downloaded along with it's label set from:
* https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip
+
+or from Arm Model Zoo on GitHub.
+```bash
+sudo apt-get install git git-lfs
+git lfs install
+git clone https://github.com/arm-software/ml-zoo.git
+cd ml-zoo/models/object_detection/yolo_v3_tiny/tflite_fp32/
+./get_class_labels.sh
+cp labelmappings.txt yolo_v3_tiny_darknet_fp32.tflite $WORKSPACE/armnn/python/pyarmnn/examples/object_detection/
+```
+
+The Style Transfer has been verified to work with the following models:
+
+* style prediction model: https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
+
+* style transfer model: https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
+
## Implementing Your Own Network
The examples provide support for `ssd_mobilenet_v1` and `yolo_v3_tiny` models. However, the user is able to add their own network to the object detection scripts by following the steps:
@@ -116,7 +214,7 @@ The examples provide support for `ssd_mobilenet_v1` and `yolo_v3_tiny` models. H
# Application Overview
-This section provides a walkthrough of the application, explaining in detail the steps:
+This section provides a walk-through of the application, explaining in detail the steps:
1. Initialisation
2. Creating a Network
@@ -130,7 +228,7 @@ This section provides a walkthrough of the application, explaining in detail the
##### Reading from Video Source
After parsing user arguments, the chosen video file or stream is loaded into an OpenCV `cv2.VideoCapture()` object. We use this object to capture frames from the source using the `read()` function.
-The `VideoCapture` object also tells us information about the source, such as the framerate and resolution of the input video. Using this information, we create a `cv2.VideoWriter()` object which will be used at the end of every loop to write the processed frame to an output video file of the same format as the input.
+The `VideoCapture` object also tells us information about the source, such as the frame-rate and resolution of the input video. Using this information, we create a `cv2.VideoWriter()` object which will be used at the end of every loop to write the processed frame to an output video file of the same format as the input.
##### Preparing Labels and Model Specific Functions
In order to interpret the result of running inference on the loaded network, it is required to load the labels associated with the model. In the provided example code, the `dict_labels()` function creates a dictionary that is keyed on the classification index at the output node of the model, with values of the dictionary corresponding to a label and a randomly generated RGB color. This ensures that each class has a unique color which will prove helpful when plotting the bounding boxes of various detected objects in a frame.
@@ -174,6 +272,10 @@ This preprocessing step consists of swapping channels (BGR to RGB in this exampl
##### Making Input and Output Tensors
To produce the workload tensors, calling the functions `make_input_tensors()` and `make_output_tensors()` will return the input and output tensors respectively.
+#### Creating a style bottleneck - Style prediction
+If the user decides to use style transfer, a style transfer constructor will be called to create a style bottleneck.
+To create a style bottleneck, the style transfer executor will call a style_predict function, which requires a style prediction executor, and an artistic style image.
+The style image must be preprocssed to (1, 256, 256, 3) to fit the style predict executor which will then perform inference to create a style bottleneck.
### Executing Inference
After making the workload tensors, a compute device performs inference for the loaded network using the `EnqueueWorkload()` function of the runtime context. By calling the `workload_tensors_to_ndarray()` function, we obtain the results from inference as a list of `ndarrays`.
@@ -194,3 +296,9 @@ The detection results are always returned as a list in the form `[class index, [
##### Drawing Bounding Boxes
With the obtained results and using `draw_bounding_boxes()`, we are able to draw bounding boxes around detected objects and add the associated label and confidence score. The labels dictionary created earlier uses the class index of the detected object as a key to return the associated label and color for that class. The resize factor defined at the beginning scales the bounding box coordinates to their correct positions in the original frame. The processed frames are written to file or displayed in a separate window.
+
+##### Creating Stylized Detections
+Using the detections, we are able to send them as an input to the style transfer executor to create stylized detections using the style bottleneck tensor that was calculated in the style prediction process.
+Each detection will be cropped from the frame, and then preprocessed to (1, 384, 384, 3) to fit the style transfer executor.
+The style transfer executor will use the style bottleneck and the preprocessed content frame to create an artistic stylized frame.
+The labels dictionary created earlier uses the class index of the detected object as a key to return the associated label, which is used to identify if it's equal to the style transfer class. The resize factor defined at the beginning scales the bounding box coordinates to their correct positions in the original frame. The processed frames are written to file or displayed in a separate window.
diff --git a/python/pyarmnn/examples/object_detection/requirements.txt b/python/pyarmnn/examples/object_detection/requirements.txt
index 717a536a0e..01f2d181da 100644
--- a/python/pyarmnn/examples/object_detection/requirements.txt
+++ b/python/pyarmnn/examples/object_detection/requirements.txt
@@ -1,2 +1,4 @@
+--extra-index-url https://google-coral.github.io/py-repo/
numpy>=1.19.2
tqdm>=4.47.0
+tflite_runtime==2.5.0.post1
diff --git a/python/pyarmnn/examples/object_detection/run_video_file.py b/python/pyarmnn/examples/object_detection/run_video_file.py
index 52f19d2c15..b5140d0489 100644
--- a/python/pyarmnn/examples/object_detection/run_video_file.py
+++ b/python/pyarmnn/examples/object_detection/run_video_file.py
@@ -1,4 +1,4 @@
-# Copyright © 2020-2021 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2020-2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""
@@ -8,6 +8,7 @@ bounding boxes and labels around detected objects, and saves the processed video
import os
import sys
+
script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
@@ -17,12 +18,12 @@ from argparse import ArgumentParser
from ssd import ssd_processing, ssd_resize_factor
from yolo import yolo_processing, yolo_resize_factor
-from utils import dict_labels
+from utils import dict_labels, Profiling
from cv_utils import init_video_file_capture, preprocess, draw_bounding_boxes
-from network_executor import ArmnnNetworkExecutor
+import style_transfer
-def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
+def get_model_processing(model_name: str, video: cv2.VideoCapture, input_data_shape: tuple):
"""
Gets model-specific information such as model labels and decoding and processing functions.
The user can include their own network and functions by adding another statement.
@@ -30,7 +31,7 @@ def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding
Args:
model_name: Name of type of supported model.
video: Video capture object, contains information about data source.
- input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.
+ input_data_shape: Contains shape of model input layer.
Returns:
Model labels, decoding and processing functions.
@@ -38,32 +39,75 @@ def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding
if model_name == 'ssd_mobilenet_v1':
return ssd_processing, ssd_resize_factor(video)
elif model_name == 'yolo_v3_tiny':
- return yolo_processing, yolo_resize_factor(video, input_binding_info)
+ return yolo_processing, yolo_resize_factor(video, input_data_shape)
else:
raise ValueError(f'{model_name} is not a valid model name')
def main(args):
+ enable_profile = args.profiling_enabled == "true"
+ action_profiler = Profiling(enable_profile)
+ overall_profiler = Profiling(enable_profile)
+ overall_profiler.profiling_start()
+ action_profiler.profiling_start()
+
+ if args.tflite_delegate_path is not None:
+ from network_executor_tflite import TFLiteNetworkExecutor as NetworkExecutor
+ exec_input_args = (args.model_file_path, args.preferred_backends, args.tflite_delegate_path)
+ else:
+ from network_executor import ArmnnNetworkExecutor as NetworkExecutor
+ exec_input_args = (args.model_file_path, args.preferred_backends)
+
+ executor = NetworkExecutor(*exec_input_args)
+ action_profiler.profiling_stop_and_print_us("Executor initialization")
+
+ action_profiler.profiling_start()
video, video_writer, frame_count = init_video_file_capture(args.video_file_path, args.output_video_file_path)
+ process_output, resize_factor = get_model_processing(args.model_name, video, executor.get_shape())
+ action_profiler.profiling_stop_and_print_us("Video initialization")
- executor = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
- process_output, resize_factor = get_model_processing(args.model_name, video, executor.input_binding_info)
labels = dict_labels(args.label_path, include_rgb=True)
+ if all(element is not None for element in [args.style_predict_model_file_path,
+ args.style_transfer_model_file_path,
+ args.style_image_path, args.style_transfer_class]):
+ style_image = cv2.imread(args.style_image_path)
+ action_profiler.profiling_start()
+ style_transfer_executor = style_transfer.StyleTransfer(args.style_predict_model_file_path,
+ args.style_transfer_model_file_path,
+ style_image, args.preferred_backends,
+ args.tflite_delegate_path)
+ action_profiler.profiling_stop_and_print_us("Style Transfer Executor initialization")
+
for _ in tqdm(frame_count, desc='Processing frames'):
frame_present, frame = video.read()
if not frame_present:
continue
model_name = args.model_name
if model_name == "ssd_mobilenet_v1":
- input_tensors = preprocess(frame, executor.input_binding_info, True)
+ input_data = preprocess(frame, executor.get_data_type(), executor.get_shape(), True)
else:
- input_tensors = preprocess(frame, executor.input_binding_info, False)
- output_result = executor.run(input_tensors)
+ input_data = preprocess(frame, executor.get_data_type(), executor.get_shape(), False)
+
+ action_profiler.profiling_start()
+ output_result = executor.run([input_data])
+ action_profiler.profiling_stop_and_print_us("Running inference")
+
detections = process_output(output_result)
- draw_bounding_boxes(frame, detections, resize_factor, labels)
+
+ if all(element is not None for element in [args.style_predict_model_file_path,
+ args.style_transfer_model_file_path,
+ args.style_image_path, args.style_transfer_class]):
+ action_profiler.profiling_start()
+ frame = style_transfer.create_stylized_detection(style_transfer_executor, args.style_transfer_class,
+ frame, detections, resize_factor, labels)
+ action_profiler.profiling_stop_and_print_us("Running Style Transfer")
+ else:
+ draw_bounding_boxes(frame, detections, resize_factor, labels)
+
video_writer.write(frame)
print('Finished processing frames')
+ overall_profiler.profiling_stop_and_print_us("Total compute time")
video.release(), video_writer.release()
@@ -83,5 +127,21 @@ if __name__ == '__main__':
help='Takes the preferred backends in preference order, separated by whitespace, '
'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
'Defaults to [CpuAcc, CpuRef]')
+ parser.add_argument('--tflite_delegate_path', type=str,
+ help='Enter TensorFlow Lite Delegate file path (.so file). If not entered,'
+ 'will use armnn executor')
+ parser.add_argument('--profiling_enabled', type=str,
+ help='[OPTIONAL] Enabling this option will print important ML related milestones timing'
+ 'information in micro-seconds. By default, this option is disabled.'
+ 'Accepted options are true/false.')
+ parser.add_argument('--style_predict_model_file_path', type=str,
+ help='Path to the style prediction model to use')
+ parser.add_argument('--style_transfer_model_file_path', type=str,
+ help='Path to the style transfer model to use')
+ parser.add_argument('--style_image_path', type=str,
+ help='Path to the style image to create stylized frames')
+ parser.add_argument('--style_transfer_class', type=str,
+ help='A class to transform its style')
+
args = parser.parse_args()
main(args)
diff --git a/python/pyarmnn/examples/object_detection/run_video_stream.py b/python/pyarmnn/examples/object_detection/run_video_stream.py
index dba615b97e..7b6ef253b2 100644
--- a/python/pyarmnn/examples/object_detection/run_video_stream.py
+++ b/python/pyarmnn/examples/object_detection/run_video_stream.py
@@ -9,20 +9,20 @@ and displays a window with the latest processed frame.
import os
import sys
+
script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
import cv2
from argparse import ArgumentParser
-
from ssd import ssd_processing, ssd_resize_factor
from yolo import yolo_processing, yolo_resize_factor
-from utils import dict_labels
+from utils import dict_labels, Profiling
from cv_utils import init_video_stream_capture, preprocess, draw_bounding_boxes
-from network_executor import ArmnnNetworkExecutor
+import style_transfer
-def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
+def get_model_processing(model_name: str, video: cv2.VideoCapture, input_data_shape: tuple):
"""
Gets model-specific information such as model labels and decoding and processing functions.
The user can include their own network and functions by adding another statement.
@@ -30,7 +30,7 @@ def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding
Args:
model_name: Name of type of supported model.
video: Video capture object, contains information about data source.
- input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.
+ input_data_shape: Contains shape of model input layer, used for scaling bounding boxes.
Returns:
Model labels, decoding and processing functions.
@@ -38,33 +38,71 @@ def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding
if model_name == 'ssd_mobilenet_v1':
return ssd_processing, ssd_resize_factor(video)
elif model_name == 'yolo_v3_tiny':
- return yolo_processing, yolo_resize_factor(video, input_binding_info)
+ return yolo_processing, yolo_resize_factor(video, input_data_shape)
else:
raise ValueError(f'{model_name} is not a valid model name')
def main(args):
- video = init_video_stream_capture(args.video_source)
- executor = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
+ enable_profile = args.profiling_enabled == "true"
+ action_profiler = Profiling(enable_profile)
+ action_profiler.profiling_start()
+
+ if args.tflite_delegate_path is not None:
+ from network_executor_tflite import TFLiteNetworkExecutor as NetworkExecutor
+ exec_input_args = (args.model_file_path, args.preferred_backends, args.tflite_delegate_path)
+ else:
+ from network_executor import ArmnnNetworkExecutor as NetworkExecutor
+ exec_input_args = (args.model_file_path, args.preferred_backends)
+
+ executor = NetworkExecutor(*exec_input_args)
+ action_profiler.profiling_stop_and_print_us("Executor initialization")
+
+ action_profiler.profiling_start()
+ video = init_video_stream_capture(args.video_source)
+ action_profiler.profiling_stop_and_print_us("Video initialization")
model_name = args.model_name
- process_output, resize_factor = get_model_processing(args.model_name, video, executor.input_binding_info)
+ process_output, resize_factor = get_model_processing(args.model_name, video, executor.get_shape())
labels = dict_labels(args.label_path, include_rgb=True)
+ if all(element is not None for element in [args.style_predict_model_file_path,
+ args.style_transfer_model_file_path,
+ args.style_image_path, args.style_transfer_class]):
+ style_image = cv2.imread(args.style_image_path)
+ action_profiler.profiling_start()
+ style_transfer_executor = style_transfer.StyleTransfer(args.style_predict_model_file_path,
+ args.style_transfer_model_file_path,
+ style_image, args.preferred_backends,
+ args.tflite_delegate_path)
+ action_profiler.profiling_stop_and_print_us("Style Transfer Executor initialization")
+
while True:
frame_present, frame = video.read()
frame = cv2.flip(frame, 1) # Horizontally flip the frame
if not frame_present:
raise RuntimeError('Error reading frame from video stream')
+ action_profiler.profiling_start()
if model_name == "ssd_mobilenet_v1":
- input_tensors = preprocess(frame, executor.input_binding_info, True)
+ input_data = preprocess(frame, executor.get_data_type(), executor.get_shape(), True)
else:
- input_tensors = preprocess(frame, executor.input_binding_info, False)
- print("Running inference...")
- output_result = executor.run(input_tensors)
+ input_data = preprocess(frame, executor.get_data_type(), executor.get_shape(), False)
+
+ output_result = executor.run([input_data])
+ if not enable_profile:
+ print("Running inference...")
+ action_profiler.profiling_stop_and_print_us("Running inference...")
detections = process_output(output_result)
- draw_bounding_boxes(frame, detections, resize_factor, labels)
+ if all(element is not None for element in [args.style_predict_model_file_path,
+ args.style_transfer_model_file_path,
+ args.style_image_path, args.style_transfer_class]):
+ action_profiler.profiling_start()
+ frame = style_transfer.create_stylized_detection(style_transfer_executor, args.style_transfer_class,
+ frame, detections, resize_factor, labels)
+ action_profiler.profiling_stop_and_print_us("Running Style Transfer")
+ else:
+ draw_bounding_boxes(frame, detections, resize_factor, labels)
cv2.imshow('PyArmNN Object Detection Demo', frame)
if cv2.waitKey(1) == 27:
print('\nExit key activated. Closing video...')
@@ -86,5 +124,21 @@ if __name__ == '__main__':
help='Takes the preferred backends in preference order, separated by whitespace, '
'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
'Defaults to [CpuAcc, CpuRef]')
+ parser.add_argument('--tflite_delegate_path', type=str,
+ help='Enter TensorFlow Lite Delegate file path (.so file). If not entered,'
+ 'will use armnn executor')
+ parser.add_argument('--profiling_enabled', type=str,
+ help='[OPTIONAL] Enabling this option will print important ML related milestones timing'
+ 'information in micro-seconds. By default, this option is disabled.'
+ 'Accepted options are true/false.')
+ parser.add_argument('--style_predict_model_file_path', type=str,
+ help='Path to the style prediction model to use')
+ parser.add_argument('--style_transfer_model_file_path', type=str,
+ help='Path to the style transfer model to use')
+ parser.add_argument('--style_image_path', type=str,
+ help='Path to the style image to create stylized frames')
+ parser.add_argument('--style_transfer_class', type=str,
+ help='A class to transform its style')
+
args = parser.parse_args()
main(args)
diff --git a/python/pyarmnn/examples/object_detection/style_transfer.py b/python/pyarmnn/examples/object_detection/style_transfer.py
new file mode 100644
index 0000000000..eda618e31a
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/style_transfer.py
@@ -0,0 +1,138 @@
+# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import numpy as np
+import urllib.request
+import cv2
+import network_executor_tflite
+import cv_utils
+
+
+def style_transfer_postprocess(preprocessed_frame: np.ndarray, image_shape: tuple):
+ """
+ Resizes the output frame of style transfer network and changes the color back to original configuration
+
+ Args:
+ preprocessed_frame: A preprocessed frame after style transfer.
+ image_shape: Contains shape of the original frame before preprocessing.
+
+ Returns:
+ Resizing factor to scale coordinates according to image_shape.
+ """
+
+ postprocessed_frame = np.squeeze(preprocessed_frame, axis=0)
+ # select original height and width from image_shape
+ frame_height = image_shape[0]
+ frame_width = image_shape[1]
+ postprocessed_frame = cv2.resize(postprocessed_frame, (frame_width, frame_height)).astype("float32") * 255
+ postprocessed_frame = cv2.cvtColor(postprocessed_frame, cv2.COLOR_RGB2BGR)
+
+ return postprocessed_frame
+
+
+def create_stylized_detection(style_transfer_executor, style_transfer_class, frame: np.ndarray,
+ detections: list, resize_factor, labels: dict):
+ """
+ Perform style transfer on a detected class in a frame
+
+ Args:
+ style_transfer_executor: The style transfer executor
+ style_transfer_class: The class detected to change its style
+ frame: The original captured frame from video source.
+ detections: A list of detected objects in the form [class, [box positions], confidence].
+ resize_factor: Resizing factor to scale box coordinates to output frame size.
+ labels: Dictionary of labels and colors keyed on the classification index.
+ """
+ for detection in detections:
+ class_idx, box, confidence = [d for d in detection]
+ label = labels[class_idx][0]
+ if label.lower() == style_transfer_class.lower():
+ # Obtain frame size and resized bounding box positions
+ frame_height, frame_width = frame.shape[:2]
+ x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]
+
+ # Ensure box stays within the frame
+ x_min, y_min = max(0, x_min), max(0, y_min)
+ x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)
+
+ # Crop only the detected object
+ cropped_frame = cv_utils.crop_bounding_box_object(frame, x_min, y_min, x_max, y_max)
+
+ # Run style_transfer on preprocessed_frame
+ stylized_frame = style_transfer_executor.run_style_transfer(cropped_frame)
+
+ # Paste stylized_frame on the original frame in the correct place
+ frame[int(y_min)+1:int(y_max), int(x_min)+1:int(x_max)] = stylized_frame
+
+ return frame
+
+
+class StyleTransfer:
+
+ def __init__(self, style_predict_model_path: str, style_transfer_model_path: str,
+ style_image: np.ndarray, backends: list, delegate_path: str):
+ """
+ Creates an inference executor for style predict network, style transfer network,
+ list of backends and a style image.
+
+ Args:
+ style_predict_model_path: model which is used to create a style bottleneck
+ style_transfer_model_path: model which is used to create stylized frames
+ style_image: an image to create the style bottleneck
+ backends: List of backends to optimize network.
+ delegate_path: tflite delegate file path (.so).
+ """
+
+ self.style_predict_executor = network_executor_tflite.TFLiteNetworkExecutor(style_predict_model_path, backends,
+ delegate_path)
+ self.style_transfer_executor = network_executor_tflite.TFLiteNetworkExecutor(style_transfer_model_path,
+ backends,
+ delegate_path)
+ self.style_bottleneck = self.run_style_predict(style_image)
+
+ def get_style_predict_executor_shape(self):
+ """
+ Get the input shape of the initiated network.
+
+ Returns:
+ tuple: The Shape of the network input.
+ """
+ return self.style_predict_executor.get_shape()
+
+ # Function to run create a style_bottleneck using preprocessed style image.
+ def run_style_predict(self, style_image):
+ """
+ Creates bottleneck tensor for a given style image.
+
+ Args:
+ style_image: an image to create the style bottleneck
+
+ Returns:
+ style bottleneck tensor
+ """
+ # The style image has to be preprocessed to (1, 256, 256, 3)
+ preprocessed_style_image = cv_utils.preprocess(style_image, self.style_predict_executor.get_data_type(),
+ self.style_predict_executor.get_shape(), True, keep_aspect_ratio=False)
+ # output[0] is the style bottleneck tensor
+ style_bottleneck = self.style_predict_executor.run([preprocessed_style_image])[0]
+
+ return style_bottleneck
+
+ # Run style transform on preprocessed style image
+ def run_style_transfer(self, content_image):
+ """
+ Runs inference for given content_image and style bottleneck to create a stylized image.
+
+ Args:
+ content_image:a content image to stylize
+ """
+ # The content image has to be preprocessed to (1, 384, 384, 3)
+ preprocessed_style_image = cv_utils.preprocess(content_image, np.float32,
+ self.style_transfer_executor.get_shape(), True, keep_aspect_ratio=False)
+
+ # Transform content image. output[0] is the stylized image
+ stylized_image = self.style_transfer_executor.run([preprocessed_style_image, self.style_bottleneck])[0]
+
+ post_stylized_image = style_transfer_postprocess(stylized_image, content_image.shape)
+
+ return post_stylized_image
diff --git a/python/pyarmnn/examples/object_detection/yolo.py b/python/pyarmnn/examples/object_detection/yolo.py
index 1748d158a2..e76ed7b2f4 100644
--- a/python/pyarmnn/examples/object_detection/yolo.py
+++ b/python/pyarmnn/examples/object_detection/yolo.py
@@ -80,19 +80,19 @@ def yolo_processing(output: np.ndarray, confidence_threshold=0.40, iou_threshold
return nms_det
-def yolo_resize_factor(video: cv2.VideoCapture, input_binding_info: tuple):
+def yolo_resize_factor(video: cv2.VideoCapture, input_data_shape: tuple):
"""
Gets a multiplier to scale the bounding box positions to
their correct position in the frame.
Args:
video: Video capture object, contains information about data source.
- input_binding_info: Contains shape of model input layer.
+ input_data_shape: Contains shape of model input layer.
Returns:
Resizing factor to scale box coordinates to output frame size.
"""
frame_height = video.get(cv2.CAP_PROP_FRAME_HEIGHT)
frame_width = video.get(cv2.CAP_PROP_FRAME_WIDTH)
- model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
+ _, model_height, model_width, _= input_data_shape
return max(frame_height, frame_width) / max(model_height, model_width)
diff --git a/python/pyarmnn/examples/speech_recognition/README.md b/python/pyarmnn/examples/speech_recognition/README.md
index 2cdc8691d2..854cdaf03b 100644
--- a/python/pyarmnn/examples/speech_recognition/README.md
+++ b/python/pyarmnn/examples/speech_recognition/README.md
@@ -151,7 +151,7 @@ for i in range(features.shape[1]):
# audio_utils.py
# Quantize the input data and create input tensors with PyArmNN
input_tensor = quantize_input(input_tensor, input_binding_info)
-input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+input_tensors = ann.make_input_tensors([input_binding_info], [input_data])
```
Note: `ArmnnNetworkExecutor` has already created the output tensors for you.
@@ -172,4 +172,4 @@ Having now gained a solid understanding of performing automatic speech recogniti
An important step to improving accuracy of the generated output sentences is by providing cleaner data to the network. This can be done by including additional preprocessing steps such as noise reduction of your audio data.
-In this application, we had used a greedy decoder to decode the integer-encoded output however, better results can be achieved by implementing a beam search decoder. You may even try adding a language model at the end to aim to correct any spelling mistakes the model may produce. \ No newline at end of file
+In this application, we had used a greedy decoder to decode the integer-encoded output however, better results can be achieved by implementing a beam search decoder. You may even try adding a language model at the end to aim to correct any spelling mistakes the model may produce.
diff --git a/python/pyarmnn/examples/speech_recognition/run_audio_file.py b/python/pyarmnn/examples/speech_recognition/run_audio_file.py
index 0430f68c16..ddf6cb704c 100644
--- a/python/pyarmnn/examples/speech_recognition/run_audio_file.py
+++ b/python/pyarmnn/examples/speech_recognition/run_audio_file.py
@@ -12,7 +12,7 @@ sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
from argparse import ArgumentParser
from network_executor import ArmnnNetworkExecutor
-from utils import prepare_input_tensors
+from utils import prepare_input_data
from audio_capture import AudioCaptureParams, capture_audio
from audio_utils import decode_text, display_text
from wav2letter_mfcc import Wav2LetterMFCC, W2LAudioPreprocessor
@@ -78,10 +78,11 @@ def main(args):
print("Processing Audio Frames...")
for audio_data in buffer:
# Prepare the input Tensors
- input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)
+ input_data = prepare_input_data(audio_data, network.get_data_type(), network.get_input_quantization_scale(0),
+ network.get_input_quantization_offset(0), preprocessor)
# Run inference
- output_result = network.run(input_tensors)
+ output_result = network.run([input_data])
# Slice and Decode the text, and store the right context
current_r_context, text = decode_text(is_first_window, labels, output_result)
diff --git a/python/pyarmnn/examples/tests/conftest.py b/python/pyarmnn/examples/tests/conftest.py
index b7fa73b852..4f1ac5f379 100644
--- a/python/pyarmnn/examples/tests/conftest.py
+++ b/python/pyarmnn/examples/tests/conftest.py
@@ -20,20 +20,38 @@ def test_data_folder():
data_dir = os.path.join(script_dir, "testdata")
if not os.path.exists(data_dir):
os.mkdir(data_dir)
+
+ sys_arch = os.uname().machine
+ if sys_arch == "x86_64":
+ libarmnn_url = "https://github.com/ARM-software/armnn/releases/download/v21.11/ArmNN-linux-x86_64.tar.gz"
+ else:
+ libarmnn_url = "https://github.com/ARM-software/armnn/releases/download/v21.11/ArmNN-linux-aarch64.tar.gz"
+
+
files_to_download = ["https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/messi5.jpg",
"https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/basketball1.png",
"https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/Megamind.avi",
"https://github.com/ARM-software/ML-zoo/raw/master/models/object_detection/ssd_mobilenet_v1/tflite_uint8/ssd_mobilenet_v1.tflite",
"https://git.mlplatform.org/ml/ethos-u/ml-embedded-evaluation-kit.git/plain/resources/kws/samples/yes.wav",
- "https://raw.githubusercontent.com/Azure-Samples/cognitive-services-speech-sdk/master/sampledata/audiofiles/myVoiceIsMyPassportVerifyMe04.wav"
+ "https://raw.githubusercontent.com/Azure-Samples/cognitive-services-speech-sdk/master/sampledata/audiofiles/myVoiceIsMyPassportVerifyMe04.wav",
+ "https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite",
+ "https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite",
+ libarmnn_url
]
for file in files_to_download:
path, filename = ntpath.split(file)
+ if filename == '1?lite-format=tflite' and 'prediction' in file:
+ filename = 'style_predict.tflite'
+ elif filename == '1?lite-format=tflite' and 'transfer' in file:
+ filename = 'style_transfer.tflite'
file_path = os.path.join(data_dir, filename)
if not os.path.exists(file_path):
print("\nDownloading test file: " + file_path + "\n")
urllib.request.urlretrieve(file, file_path)
+ path, filename = ntpath.split(libarmnn_url)
+ file_path = os.path.join(data_dir, filename)
+ os.system(f"tar -xvzf {file_path} -C {data_dir} ")
return data_dir
diff --git a/python/pyarmnn/examples/tests/context.py b/python/pyarmnn/examples/tests/context.py
index a678f94178..6b9439d353 100644
--- a/python/pyarmnn/examples/tests/context.py
+++ b/python/pyarmnn/examples/tests/context.py
@@ -10,6 +10,8 @@ sys.path.insert(0, os.path.join(script_dir, '..'))
import common.cv_utils as cv_utils
import common.network_executor as network_executor
+import common.network_executor_tflite as network_executor_tflite
+
import common.utils as utils
import common.audio_capture as audio_capture
import common.mfcc as mfcc
@@ -17,6 +19,4 @@ import common.mfcc as mfcc
import speech_recognition.wav2letter_mfcc as wav2letter_mfcc
import speech_recognition.audio_utils as audio_utils
-
-
-
+import object_detection.style_transfer as style_transfer
diff --git a/python/pyarmnn/examples/tests/test_common_utils.py b/python/pyarmnn/examples/tests/test_common_utils.py
index 28d68ea235..254eba63f8 100644
--- a/python/pyarmnn/examples/tests/test_common_utils.py
+++ b/python/pyarmnn/examples/tests/test_common_utils.py
@@ -2,9 +2,13 @@
# SPDX-License-Identifier: MIT
import os
+import time
+import cv2
+import numpy as np
from context import cv_utils
from context import utils
+from utils import Profiling
def test_get_source_encoding(test_data_folder):
@@ -17,3 +21,22 @@ def test_read_existing_labels_file(test_data_folder):
label_file = os.path.join(test_data_folder, "labelmap.txt")
labels_map = utils.dict_labels(label_file)
assert labels_map is not None
+
+
+def test_preprocess(test_data_folder):
+ content_image = "messi5.jpg"
+ target_shape = (1, 256, 256, 3)
+ padding = True
+ image = cv2.imread(os.path.join(test_data_folder, content_image))
+ image = cv_utils.preprocess(image, np.float32, target_shape, True, padding)
+
+ assert image.shape == target_shape
+
+
+def test_profiling():
+ profiler = Profiling(True)
+ profiler.profiling_start()
+ time.sleep(1)
+ period = profiler.profiling_stop_and_print_us("Sleep for 1 second")
+ assert (1_000_000 < period < 1_002_000)
+
diff --git a/python/pyarmnn/examples/tests/test_network_executor.py b/python/pyarmnn/examples/tests/test_network_executor.py
index c124b11382..f266c16537 100644
--- a/python/pyarmnn/examples/tests/test_network_executor.py
+++ b/python/pyarmnn/examples/tests/test_network_executor.py
@@ -2,23 +2,35 @@
# SPDX-License-Identifier: MIT
import os
-
+import pytest
import cv2
+import numpy as np
from context import network_executor
+from context import network_executor_tflite
from context import cv_utils
-
-def test_execute_network(test_data_folder):
+@pytest.mark.parametrize("executor_name", ["armnn", "tflite"])
+def test_execute_network(test_data_folder, executor_name):
model_path = os.path.join(test_data_folder, "ssd_mobilenet_v1.tflite")
backends = ["CpuAcc", "CpuRef"]
+ if executor_name == "armnn":
+ executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
+ elif executor_name == "tflite":
+ delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
+ executor = network_executor_tflite.TFLiteNetworkExecutor(model_path, backends, delegate_path)
+ else:
+ raise f"unsupported executor_name: {executor_name}"
- executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
img = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
- input_tensors = cv_utils.preprocess(img, executor.input_binding_info, True)
+ resized_img = cv_utils.preprocess(img, executor.get_data_type(), executor.get_shape(), True)
- output_result = executor.run(input_tensors)
+ output_result = executor.run([resized_img])
# Ensure it detects a person
classes = output_result[1]
assert classes[0][0] == 0
+
+ # Unit tests for network executor class functions - specifically for ssd_mobilenet_v1.tflite network
+ assert executor.get_data_type() == np.uint8
+ assert executor.get_shape() == (1, 300, 300, 3)
diff --git a/python/pyarmnn/examples/tests/test_style_transfer.py b/python/pyarmnn/examples/tests/test_style_transfer.py
new file mode 100644
index 0000000000..fae4a427f0
--- /dev/null
+++ b/python/pyarmnn/examples/tests/test_style_transfer.py
@@ -0,0 +1,70 @@
+# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+import cv2
+import numpy as np
+
+from context import style_transfer
+from context import cv_utils
+
+
+def test_style_transfer_postprocess(test_data_folder):
+ content_image = "messi5.jpg"
+ target_shape = (1,256,256,3)
+ keep_aspect_ratio = False
+ image = cv2.imread(os.path.join(test_data_folder, content_image))
+ original_shape = image.shape
+ preprocessed_image = cv_utils.preprocess(image, np.float32, target_shape, False, keep_aspect_ratio)
+ assert preprocessed_image.shape == target_shape
+
+ postprocess_image = style_transfer.style_transfer_postprocess(preprocessed_image, original_shape)
+ assert postprocess_image.shape == original_shape
+
+
+def test_style_transfer(test_data_folder):
+ style_predict_model_path = os.path.join(test_data_folder, "style_predict.tflite")
+ style_transfer_model_path = os.path.join(test_data_folder, "style_transfer.tflite")
+ backends = ["CpuAcc", "CpuRef"]
+ delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
+ image = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
+
+ style_transfer_executor = style_transfer.StyleTransfer(style_predict_model_path, style_transfer_model_path,
+ image, backends, delegate_path)
+
+ assert style_transfer_executor.get_style_predict_executor_shape() == (1, 256, 256, 3)
+
+def test_run_style_transfer(test_data_folder):
+ style_predict_model_path = os.path.join(test_data_folder, "style_predict.tflite")
+ style_transfer_model_path = os.path.join(test_data_folder, "style_transfer.tflite")
+ backends = ["CpuAcc", "CpuRef"]
+ delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
+ style_image = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
+ content_image = cv2.imread(os.path.join(test_data_folder, "basketball1.png"))
+
+ style_transfer_executor = style_transfer.StyleTransfer(style_predict_model_path, style_transfer_model_path,
+ style_image, backends, delegate_path)
+
+ stylized_image = style_transfer_executor.run_style_transfer(content_image)
+ assert stylized_image.shape == content_image.shape
+
+
+def test_create_stylized_detection(test_data_folder):
+ style_predict_model_path = os.path.join(test_data_folder, "style_predict.tflite")
+ style_transfer_model_path = os.path.join(test_data_folder, "style_transfer.tflite")
+ backends = ["CpuAcc", "CpuRef"]
+ delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
+
+ style_image = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
+ content_image = cv2.imread(os.path.join(test_data_folder, "basketball1.png"))
+ detections = [(0.0, [0.16745174, 0.15101701, 0.5371381, 0.74165875], 0.87597656)]
+ labels = {0: ('person', (50.888902345757494, 129.61878417939724, 207.2891028294508)),
+ 1: ('bicycle', (55.055339686943654, 55.828708219750574, 43.550389695374676)),
+ 2: ('car', (95.33096265662336, 194.872841553212, 218.58516479057758))}
+ style_transfer_executor = style_transfer.StyleTransfer(style_predict_model_path, style_transfer_model_path,
+ style_image, backends, delegate_path)
+
+ stylized_image = style_transfer.create_stylized_detection(style_transfer_executor, 'person', content_image,
+ detections, 720, labels)
+
+ assert stylized_image.shape == content_image.shape