diff options
Diffstat (limited to 'python/pyarmnn/examples/tests/test_network_executor.py')
-rw-r--r-- | python/pyarmnn/examples/tests/test_network_executor.py | 24 |
1 files changed, 18 insertions, 6 deletions
diff --git a/python/pyarmnn/examples/tests/test_network_executor.py b/python/pyarmnn/examples/tests/test_network_executor.py index c124b11382..f266c16537 100644 --- a/python/pyarmnn/examples/tests/test_network_executor.py +++ b/python/pyarmnn/examples/tests/test_network_executor.py @@ -2,23 +2,35 @@ # SPDX-License-Identifier: MIT import os - +import pytest import cv2 +import numpy as np from context import network_executor +from context import network_executor_tflite from context import cv_utils - -def test_execute_network(test_data_folder): +@pytest.mark.parametrize("executor_name", ["armnn", "tflite"]) +def test_execute_network(test_data_folder, executor_name): model_path = os.path.join(test_data_folder, "ssd_mobilenet_v1.tflite") backends = ["CpuAcc", "CpuRef"] + if executor_name == "armnn": + executor = network_executor.ArmnnNetworkExecutor(model_path, backends) + elif executor_name == "tflite": + delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so") + executor = network_executor_tflite.TFLiteNetworkExecutor(model_path, backends, delegate_path) + else: + raise f"unsupported executor_name: {executor_name}" - executor = network_executor.ArmnnNetworkExecutor(model_path, backends) img = cv2.imread(os.path.join(test_data_folder, "messi5.jpg")) - input_tensors = cv_utils.preprocess(img, executor.input_binding_info, True) + resized_img = cv_utils.preprocess(img, executor.get_data_type(), executor.get_shape(), True) - output_result = executor.run(input_tensors) + output_result = executor.run([resized_img]) # Ensure it detects a person classes = output_result[1] assert classes[0][0] == 0 + + # Unit tests for network executor class functions - specifically for ssd_mobilenet_v1.tflite network + assert executor.get_data_type() == np.uint8 + assert executor.get_shape() == (1, 300, 300, 3) |