diff options
Diffstat (limited to 'driver_library/python/test/test_driver_utilities.py')
-rw-r--r-- | driver_library/python/test/test_driver_utilities.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/driver_library/python/test/test_driver_utilities.py b/driver_library/python/test/test_driver_utilities.py new file mode 100644 index 0000000..fc8e921 --- /dev/null +++ b/driver_library/python/test/test_driver_utilities.py @@ -0,0 +1,77 @@ +# +# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-License-Identifier: Apache-2.0 +# +import pytest +import os +import ethosu_driver as driver +from ethosu_driver.inference_runner import read_npy_file_to_buf + + +@pytest.fixture() +def device(device_name): + device = driver.open_device(device_name) + yield device + + +@pytest.fixture() +def network(device, model_name, shared_data_folder): + network_file = os.path.join(shared_data_folder, model_name) + network = driver.load_model(device, network_file) + yield network + + +@pytest.mark.parametrize('device_name', ['blabla']) +def test_open_device_wrong_name(device_name): + with pytest.raises(RuntimeError) as err: + device = driver.open_device(device_name) + # Only check for part of the exception since the exception returns + # absolute path which will change on different machines. + assert 'Failed to open device' in str(err.value) + + +@pytest.mark.parametrize('device_name', ['ethosu0']) +def test_network_filenotfound_exception(device, shared_data_folder): + + network_file = os.path.join(shared_data_folder, "some_unknown_model.tflite") + + with pytest.raises(RuntimeError) as err: + driver.load_model(device, network_file) + + # Only check for part of the exception since the exception returns + # absolute path which will change on different machines. + assert 'Failed to open file:' in str(err.value) + + +@pytest.mark.parametrize('device_name', ['ethosu0']) +@pytest.mark.parametrize('model_name', ['model.tflite']) +def test_check_network_ifm_size(network): + assert network.getIfmSize() > 0 + + +@pytest.mark.parametrize('device_name', ['ethosu0']) +def test_allocate_buffers(device): + buffers = driver.allocate_buffers(device, [128, 256]) + assert len(buffers) == 2 + assert buffers[0].size() == 0 + assert buffers[0].capacity() == 128 + assert buffers[1].size() == 0 + assert buffers[1].capacity() == 256 + + +@pytest.mark.parametrize('device_name', ['ethosu0']) +@pytest.mark.parametrize('model_name', ['model.tflite']) +@pytest.mark.parametrize('ifms_file_list', [['model_ifm.npy']]) +def test_set_ifm_buffers(device, network, ifms_file_list, shared_data_folder): + full_path_input_files = [] + for input_file in ifms_file_list: + full_path_input_files.append(os.path.join(shared_data_folder, input_file)) + + ifms_data = [] + for ifm_file in full_path_input_files: + ifms_data.append(read_npy_file_to_buf(ifm_file)) + + ifms = driver.allocate_buffers(device, network.getIfmDims()) + driver.populate_buffers(ifms_data, ifms) + assert len(ifms) > 0 + |