aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/tests/test_network_executor.py
blob: f266c16537bd296cba471051f07ff9e0987641e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
import pytest
import cv2
import numpy as np

from context import network_executor
from context import network_executor_tflite
from context import cv_utils

@pytest.mark.parametrize("executor_name", ["armnn", "tflite"])
def test_execute_network(test_data_folder, executor_name):
    """Run ssd_mobilenet_v1.tflite on a sample image with the chosen executor.

    Builds either an ``ArmnnNetworkExecutor`` or a ``TFLiteNetworkExecutor``
    (selected by ``executor_name``), feeds it the preprocessed ``messi5.jpg``
    image and checks the first reported class as well as the data type and
    shape the executor reports.

    Args:
        test_data_folder: Directory holding the model, the sample image and,
            for the "tflite" case, ``libarmnnDelegate.so``.
        executor_name: Either "armnn" or "tflite".

    Raises:
        ValueError: If ``executor_name`` is not one of the supported names.
    """
    model_path = os.path.join(test_data_folder, "ssd_mobilenet_v1.tflite")
    backends = ["CpuAcc", "CpuRef"]
    if executor_name == "armnn":
        executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
    elif executor_name == "tflite":
        # The TFLite executor additionally takes the path to the delegate library.
        delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
        executor = network_executor_tflite.TFLiteNetworkExecutor(model_path, backends, delegate_path)
    else:
        # Raising a bare string is itself a TypeError in Python 3; use a real exception.
        raise ValueError(f"unsupported executor_name: {executor_name}")

    img_path = os.path.join(test_data_folder, "messi5.jpg")
    img = cv2.imread(img_path)
    # cv2.imread signals an unreadable/missing file by returning None rather than raising.
    assert img is not None, f"could not read test image: {img_path}"
    resized_img = cv_utils.preprocess(img, executor.get_data_type(), executor.get_shape(), True)

    output_result = executor.run([resized_img])

    # Ensure it detects a person
    classes = output_result[1]
    assert classes[0][0] == 0

    # Unit tests for network executor class functions - specifically for ssd_mobilenet_v1.tflite network
    assert executor.get_data_type() == np.uint8
    assert executor.get_shape() == (1, 300, 300, 3)