aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/tests
diff options
context:
space:
mode:
authorRaviv Shalev <raviv.shalev@arm.com>2021-12-07 15:18:09 +0200
committerTeresaARM <teresa.charlinreyes@arm.com>2022-04-13 15:33:31 +0000
commit97ddc06e52fbcabfd8ede7a00e9494c663186b92 (patch)
tree43c84d352c3a67aa45d89760fba6c79b81c8f8dc /python/pyarmnn/examples/tests
parent2f0ddb67d8f9267ab600a8a26308cab32f9e16ac (diff)
downloadarmnn-97ddc06e52fbcabfd8ede7a00e9494c663186b92.tar.gz
MLECO-2493 Add python OD example with TFLite delegate
Signed-off-by: Raviv Shalev <raviv.shalev@arm.com> Change-Id: I25fcccbf912be0c5bd4fbfd2e97552341958af35
Diffstat (limited to 'python/pyarmnn/examples/tests')
-rw-r--r--python/pyarmnn/examples/tests/conftest.py20
-rw-r--r--python/pyarmnn/examples/tests/context.py6
-rw-r--r--python/pyarmnn/examples/tests/test_common_utils.py23
-rw-r--r--python/pyarmnn/examples/tests/test_network_executor.py24
-rw-r--r--python/pyarmnn/examples/tests/test_style_transfer.py70
5 files changed, 133 insertions, 10 deletions
diff --git a/python/pyarmnn/examples/tests/conftest.py b/python/pyarmnn/examples/tests/conftest.py
index b7fa73b852..4f1ac5f379 100644
--- a/python/pyarmnn/examples/tests/conftest.py
+++ b/python/pyarmnn/examples/tests/conftest.py
@@ -20,20 +20,38 @@ def test_data_folder():
data_dir = os.path.join(script_dir, "testdata")
if not os.path.exists(data_dir):
os.mkdir(data_dir)
+
+ sys_arch = os.uname().machine
+ if sys_arch == "x86_64":
+ libarmnn_url = "https://github.com/ARM-software/armnn/releases/download/v21.11/ArmNN-linux-x86_64.tar.gz"
+ else:
+ libarmnn_url = "https://github.com/ARM-software/armnn/releases/download/v21.11/ArmNN-linux-aarch64.tar.gz"
+
+
files_to_download = ["https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/messi5.jpg",
"https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/basketball1.png",
"https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/Megamind.avi",
"https://github.com/ARM-software/ML-zoo/raw/master/models/object_detection/ssd_mobilenet_v1/tflite_uint8/ssd_mobilenet_v1.tflite",
"https://git.mlplatform.org/ml/ethos-u/ml-embedded-evaluation-kit.git/plain/resources/kws/samples/yes.wav",
- "https://raw.githubusercontent.com/Azure-Samples/cognitive-services-speech-sdk/master/sampledata/audiofiles/myVoiceIsMyPassportVerifyMe04.wav"
+ "https://raw.githubusercontent.com/Azure-Samples/cognitive-services-speech-sdk/master/sampledata/audiofiles/myVoiceIsMyPassportVerifyMe04.wav",
+ "https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite",
+ "https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite",
+ libarmnn_url
]
for file in files_to_download:
path, filename = ntpath.split(file)
+ if filename == '1?lite-format=tflite' and 'prediction' in file:
+ filename = 'style_predict.tflite'
+ elif filename == '1?lite-format=tflite' and 'transfer' in file:
+ filename = 'style_transfer.tflite'
file_path = os.path.join(data_dir, filename)
if not os.path.exists(file_path):
print("\nDownloading test file: " + file_path + "\n")
urllib.request.urlretrieve(file, file_path)
+ path, filename = ntpath.split(libarmnn_url)
+ file_path = os.path.join(data_dir, filename)
+ os.system(f"tar -xvzf {file_path} -C {data_dir} ")
return data_dir
diff --git a/python/pyarmnn/examples/tests/context.py b/python/pyarmnn/examples/tests/context.py
index a678f94178..6b9439d353 100644
--- a/python/pyarmnn/examples/tests/context.py
+++ b/python/pyarmnn/examples/tests/context.py
@@ -10,6 +10,8 @@ sys.path.insert(0, os.path.join(script_dir, '..'))
import common.cv_utils as cv_utils
import common.network_executor as network_executor
+import common.network_executor_tflite as network_executor_tflite
+
import common.utils as utils
import common.audio_capture as audio_capture
import common.mfcc as mfcc
@@ -17,6 +19,4 @@ import common.mfcc as mfcc
import speech_recognition.wav2letter_mfcc as wav2letter_mfcc
import speech_recognition.audio_utils as audio_utils
-
-
-
+import object_detection.style_transfer as style_transfer
diff --git a/python/pyarmnn/examples/tests/test_common_utils.py b/python/pyarmnn/examples/tests/test_common_utils.py
index 28d68ea235..254eba63f8 100644
--- a/python/pyarmnn/examples/tests/test_common_utils.py
+++ b/python/pyarmnn/examples/tests/test_common_utils.py
@@ -2,9 +2,13 @@
# SPDX-License-Identifier: MIT
import os
+import time
+import cv2
+import numpy as np
from context import cv_utils
from context import utils
+from utils import Profiling
def test_get_source_encoding(test_data_folder):
@@ -17,3 +21,22 @@ def test_read_existing_labels_file(test_data_folder):
label_file = os.path.join(test_data_folder, "labelmap.txt")
labels_map = utils.dict_labels(label_file)
assert labels_map is not None
+
+
+def test_preprocess(test_data_folder):
+ content_image = "messi5.jpg"
+ target_shape = (1, 256, 256, 3)
+ padding = True
+ image = cv2.imread(os.path.join(test_data_folder, content_image))
+ image = cv_utils.preprocess(image, np.float32, target_shape, True, padding)
+
+ assert image.shape == target_shape
+
+
+def test_profiling():
+ profiler = Profiling(True)
+ profiler.profiling_start()
+ time.sleep(1)
+ period = profiler.profiling_stop_and_print_us("Sleep for 1 second")
+ assert (1_000_000 < period < 1_002_000)
+
diff --git a/python/pyarmnn/examples/tests/test_network_executor.py b/python/pyarmnn/examples/tests/test_network_executor.py
index c124b11382..f266c16537 100644
--- a/python/pyarmnn/examples/tests/test_network_executor.py
+++ b/python/pyarmnn/examples/tests/test_network_executor.py
@@ -2,23 +2,35 @@
# SPDX-License-Identifier: MIT
import os
-
+import pytest
import cv2
+import numpy as np
from context import network_executor
+from context import network_executor_tflite
from context import cv_utils
-
-def test_execute_network(test_data_folder):
+@pytest.mark.parametrize("executor_name", ["armnn", "tflite"])
+def test_execute_network(test_data_folder, executor_name):
model_path = os.path.join(test_data_folder, "ssd_mobilenet_v1.tflite")
backends = ["CpuAcc", "CpuRef"]
+ if executor_name == "armnn":
+ executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
+ elif executor_name == "tflite":
+ delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
+ executor = network_executor_tflite.TFLiteNetworkExecutor(model_path, backends, delegate_path)
+ else:
+ raise f"unsupported executor_name: {executor_name}"
- executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
img = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
- input_tensors = cv_utils.preprocess(img, executor.input_binding_info, True)
+ resized_img = cv_utils.preprocess(img, executor.get_data_type(), executor.get_shape(), True)
- output_result = executor.run(input_tensors)
+ output_result = executor.run([resized_img])
# Ensure it detects a person
classes = output_result[1]
assert classes[0][0] == 0
+
+ # Unit tests for network executor class functions - specifically for ssd_mobilenet_v1.tflite network
+ assert executor.get_data_type() == np.uint8
+ assert executor.get_shape() == (1, 300, 300, 3)
diff --git a/python/pyarmnn/examples/tests/test_style_transfer.py b/python/pyarmnn/examples/tests/test_style_transfer.py
new file mode 100644
index 0000000000..fae4a427f0
--- /dev/null
+++ b/python/pyarmnn/examples/tests/test_style_transfer.py
@@ -0,0 +1,70 @@
+# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import os
+import cv2
+import numpy as np
+
+from context import style_transfer
+from context import cv_utils
+
+
+def test_style_transfer_postprocess(test_data_folder):
+ content_image = "messi5.jpg"
+ target_shape = (1,256,256,3)
+ keep_aspect_ratio = False
+ image = cv2.imread(os.path.join(test_data_folder, content_image))
+ original_shape = image.shape
+ preprocessed_image = cv_utils.preprocess(image, np.float32, target_shape, False, keep_aspect_ratio)
+ assert preprocessed_image.shape == target_shape
+
+ postprocess_image = style_transfer.style_transfer_postprocess(preprocessed_image, original_shape)
+ assert postprocess_image.shape == original_shape
+
+
+def test_style_transfer(test_data_folder):
+ style_predict_model_path = os.path.join(test_data_folder, "style_predict.tflite")
+ style_transfer_model_path = os.path.join(test_data_folder, "style_transfer.tflite")
+ backends = ["CpuAcc", "CpuRef"]
+ delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
+ image = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
+
+ style_transfer_executor = style_transfer.StyleTransfer(style_predict_model_path, style_transfer_model_path,
+ image, backends, delegate_path)
+
+ assert style_transfer_executor.get_style_predict_executor_shape() == (1, 256, 256, 3)
+
+def test_run_style_transfer(test_data_folder):
+ style_predict_model_path = os.path.join(test_data_folder, "style_predict.tflite")
+ style_transfer_model_path = os.path.join(test_data_folder, "style_transfer.tflite")
+ backends = ["CpuAcc", "CpuRef"]
+ delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
+ style_image = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
+ content_image = cv2.imread(os.path.join(test_data_folder, "basketball1.png"))
+
+ style_transfer_executor = style_transfer.StyleTransfer(style_predict_model_path, style_transfer_model_path,
+ style_image, backends, delegate_path)
+
+ stylized_image = style_transfer_executor.run_style_transfer(content_image)
+ assert stylized_image.shape == content_image.shape
+
+
+def test_create_stylized_detection(test_data_folder):
+ style_predict_model_path = os.path.join(test_data_folder, "style_predict.tflite")
+ style_transfer_model_path = os.path.join(test_data_folder, "style_transfer.tflite")
+ backends = ["CpuAcc", "CpuRef"]
+ delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
+
+ style_image = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
+ content_image = cv2.imread(os.path.join(test_data_folder, "basketball1.png"))
+ detections = [(0.0, [0.16745174, 0.15101701, 0.5371381, 0.74165875], 0.87597656)]
+ labels = {0: ('person', (50.888902345757494, 129.61878417939724, 207.2891028294508)),
+ 1: ('bicycle', (55.055339686943654, 55.828708219750574, 43.550389695374676)),
+ 2: ('car', (95.33096265662336, 194.872841553212, 218.58516479057758))}
+ style_transfer_executor = style_transfer.StyleTransfer(style_predict_model_path, style_transfer_model_path,
+ style_image, backends, delegate_path)
+
+ stylized_image = style_transfer.create_stylized_detection(style_transfer_executor, 'person', content_image,
+ detections, 720, labels)
+
+ assert stylized_image.shape == content_image.shape