aboutsummaryrefslogtreecommitdiff
path: root/driver_library/python/test/test_driver_utilities.py
blob: fc8e92113fe9af0718883100cb1932f75dd5ad7e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#
# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
import pytest
import os
import ethosu_driver as driver
from ethosu_driver.inference_runner import read_npy_file_to_buf


@pytest.fixture()
def device(device_name):
    """Yield the object returned by ``driver.open_device(device_name)``.

    ``device_name`` is resolved by pytest from the requesting test; the tests
    in this file supply it through ``pytest.mark.parametrize``. Nothing runs
    after the ``yield``, so this fixture performs no explicit teardown.
    """
    yield driver.open_device(device_name)


@pytest.fixture()
def network(device, model_name, shared_data_folder):
    """Yield the result of loading ``model_name`` from ``shared_data_folder``.

    The model path is built by joining ``shared_data_folder`` and
    ``model_name`` and is passed, together with the ``device`` fixture value,
    to ``driver.load_model``. Nothing runs after the ``yield``, so this
    fixture performs no explicit teardown.
    """
    model_path = os.path.join(shared_data_folder, model_name)
    yield driver.load_model(device, model_path)


@pytest.mark.parametrize('device_name', ['blabla'])
def test_open_device_wrong_name(device_name):
    """Expect ``driver.open_device`` to raise ``RuntimeError`` for 'blabla'.

    The test passes when the call raises ``RuntimeError`` and the exception
    text contains 'Failed to open device'.
    """
    with pytest.raises(RuntimeError) as err:
        # The call is expected to raise, so its return value is deliberately
        # not bound to a name.
        driver.open_device(device_name)
    # Only check for part of the exception text: the full message includes an
    # absolute path, which differs from machine to machine.
    assert 'Failed to open device' in str(err.value)


@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_network_filenotfound_exception(device, shared_data_folder):
    """Expect ``driver.load_model`` to raise ``RuntimeError`` for an unknown file.

    The path points at "some_unknown_model.tflite" inside
    ``shared_data_folder``; the test passes when loading it raises
    ``RuntimeError`` whose text contains 'Failed to open file:'.
    """
    missing_model = os.path.join(shared_data_folder, "some_unknown_model.tflite")
    with pytest.raises(RuntimeError) as raised:
        driver.load_model(device, missing_model)

    # The full message carries an absolute path that varies between machines,
    # so only the fixed prefix is compared.
    message = str(raised.value)
    assert 'Failed to open file:' in message


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(network):
    """Check that the loaded network reports an IFM size greater than zero."""
    ifm_size = network.getIfmSize()
    assert ifm_size > 0


@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_allocate_buffers(device):
    """Check the buffers ``driver.allocate_buffers`` returns for [128, 256].

    One buffer is expected per requested entry; each must report a size of 0
    and a capacity equal to the value requested at the same position.
    """
    requested = [128, 256]
    allocated = driver.allocate_buffers(device, requested)
    assert len(allocated) == 2

    # Walk the buffers by index, in request order, checking size before
    # capacity for each one.
    for position, expected_capacity in enumerate(requested):
        current = allocated[position]
        assert current.size() == 0
        assert current.capacity() == expected_capacity


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
@pytest.mark.parametrize('ifms_file_list', [['model_ifm.npy']])
def test_set_ifm_buffers(device, network, ifms_file_list, shared_data_folder):
    """Populate buffers allocated for the network's IFMs from ``.npy`` files.

    Each name in ``ifms_file_list`` is resolved against ``shared_data_folder``
    and read with ``read_npy_file_to_buf``. Buffers are then allocated from
    ``network.getIfmDims()`` and filled through ``driver.populate_buffers``.
    The only assertion is that at least one buffer was allocated.
    """
    # Resolve and read every input file, preserving the order of the list.
    input_payloads = [
        read_npy_file_to_buf(os.path.join(shared_data_folder, file_name))
        for file_name in ifms_file_list
    ]

    ifm_buffers = driver.allocate_buffers(device, network.getIfmDims())
    driver.populate_buffers(input_payloads, ifm_buffers)
    assert len(ifm_buffers) > 0