aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/common/tests/test_network_executor.py
blob: e27b3820784ca1d74a3729c0d73412b9c84889c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os

import cv2

from context import network_executor
from context import cv_utils


def test_execute_network(test_data_folder):
    """Run the detection model on a sample image and check the top class.

    Builds an ``ArmnnNetworkExecutor`` for ``detect.tflite``, loads
    ``messi5.jpg``, preprocesses it with the executor's input binding info,
    runs inference and asserts that the first entry of the second output
    (treated by this test as the class ids) equals 0.

    Args:
        test_data_folder: Directory containing the model and image files
            (supplied by a pytest fixture defined outside this file).
    """
    model_path = os.path.join(test_data_folder, "detect.tflite")
    image_path = os.path.join(test_data_folder, "messi5.jpg")
    backends = ["CpuAcc", "CpuRef"]

    executor = network_executor.ArmnnNetworkExecutor(model_path, backends)

    # cv2.imread returns None instead of raising when the file is missing or
    # cannot be decoded; fail here with a clear message rather than with an
    # obscure error inside preprocessing.
    img = cv2.imread(image_path)
    assert img is not None, "Failed to load test image: {}".format(image_path)

    input_tensors = cv_utils.preprocess(img, executor.input_binding_info)

    output_result = executor.run(input_tensors)

    # Ensure it detects a person
    classes = output_result[1]
    assert classes[0][0] == 0