#
# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
#
import pytest
import os
import ethosu_driver as driver
from ethosu_driver.inference_runner import read_npy_file_to_buf


@pytest.fixture()
def device(device_name):
    """Open the device named by the ``device_name`` test parameter.

    ``device_name`` is supplied by ``@pytest.mark.parametrize`` on the tests
    that request this fixture.
    """
    device = driver.open_device(device_name)
    yield device


@pytest.fixture()
def network(device, model_name, shared_data_folder):
    """Load the model file ``model_name`` from ``shared_data_folder``.

    The model is loaded onto the device provided by the ``device`` fixture.
    """
    network_file = os.path.join(shared_data_folder, model_name)
    network = driver.load_model(device, network_file)
    yield network


@pytest.mark.parametrize('device_name', ['blabla'])
def test_open_device_wrong_name(device_name):
    """Opening a device with an unknown name raises ``RuntimeError``."""
    with pytest.raises(RuntimeError) as err:
        driver.open_device(device_name)

    # Only check for part of the exception since the exception returns
    # absolute path which will change on different machines.
    # The check sits outside the ``with`` block: statements placed after the
    # raising call inside the block would never run.
    assert 'Failed to open device' in str(err.value)


@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_network_filenotfound_exception(device, shared_data_folder):
    """Loading a model file that does not exist raises ``RuntimeError``."""
    network_file = os.path.join(shared_data_folder, "some_unknown_model.tflite")

    with pytest.raises(RuntimeError) as err:
        driver.load_model(device, network_file)

    # Only check for part of the exception since the exception returns
    # absolute path which will change on different machines.
    assert 'Failed to open file:' in str(err.value)


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(network):
    """A loaded network reports a positive IFM size."""
    assert network.getIfmSize() > 0


@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_allocate_buffers(device):
    """Allocated buffers start with size 0 and the requested capacities."""
    buffers = driver.allocate_buffers(device, [128, 256])

    assert len(buffers) == 2
    assert buffers[0].size() == 0
    assert buffers[0].capacity() == 128
    assert buffers[1].size() == 0
    assert buffers[1].capacity() == 256


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
@pytest.mark.parametrize('ifms_file_list', [['model_ifm.npy']])
def test_set_ifm_buffers(device, network, ifms_file_list, shared_data_folder):
    """IFM buffers can be allocated from the network dims and populated."""
    # Resolve every input file against the shared data folder first, then
    # read them, preserving the order given in ``ifms_file_list``.
    full_path_input_files = [
        os.path.join(shared_data_folder, input_file)
        for input_file in ifms_file_list
    ]
    ifms_data = [
        read_npy_file_to_buf(ifm_file) for ifm_file in full_path_input_files
    ]

    ifms = driver.allocate_buffers(device, network.getIfmDims())
    driver.populate_buffers(ifms_data, ifms)

    assert len(ifms) > 0