aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/tests/test_network_executor.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/tests/test_network_executor.py')
-rw-r--r--python/pyarmnn/examples/tests/test_network_executor.py24
1 files changed, 18 insertions, 6 deletions
diff --git a/python/pyarmnn/examples/tests/test_network_executor.py b/python/pyarmnn/examples/tests/test_network_executor.py
index c124b11382..f266c16537 100644
--- a/python/pyarmnn/examples/tests/test_network_executor.py
+++ b/python/pyarmnn/examples/tests/test_network_executor.py
@@ -2,23 +2,35 @@
# SPDX-License-Identifier: MIT
import os
-
+import pytest
import cv2
+import numpy as np
from context import network_executor
+from context import network_executor_tflite
from context import cv_utils
-
-def test_execute_network(test_data_folder):
+@pytest.mark.parametrize("executor_name", ["armnn", "tflite"])
+def test_execute_network(test_data_folder, executor_name):
model_path = os.path.join(test_data_folder, "ssd_mobilenet_v1.tflite")
backends = ["CpuAcc", "CpuRef"]
+ if executor_name == "armnn":
+ executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
+ elif executor_name == "tflite":
+ delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
+ executor = network_executor_tflite.TFLiteNetworkExecutor(model_path, backends, delegate_path)
+ else:
+ raise f"unsupported executor_name: {executor_name}"
- executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
img = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
- input_tensors = cv_utils.preprocess(img, executor.input_binding_info, True)
+ resized_img = cv_utils.preprocess(img, executor.get_data_type(), executor.get_shape(), True)
- output_result = executor.run(input_tensors)
+ output_result = executor.run([resized_img])
# Ensure it detects a person
classes = output_result[1]
assert classes[0][0] == 0
+
+ # Unit tests for network executor class functions - specifically for ssd_mobilenet_v1.tflite network
+ assert executor.get_data_type() == np.uint8
+ assert executor.get_shape() == (1, 300, 300, 3)