aboutsummaryrefslogtreecommitdiff
path: root/driver_library/python/test/test_driver.py
diff options
context:
space:
mode:
Diffstat (limited to 'driver_library/python/test/test_driver.py')
-rw-r--r--driver_library/python/test/test_driver.py179
1 files changed, 179 insertions, 0 deletions
diff --git a/driver_library/python/test/test_driver.py b/driver_library/python/test/test_driver.py
new file mode 100644
index 0000000..5496aed
--- /dev/null
+++ b/driver_library/python/test/test_driver.py
@@ -0,0 +1,179 @@
+#
+# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-License-Identifier: Apache-2.0
+#
+import pytest
+import os
+import ethosu_driver as driver
+from ethosu_driver.inference_runner import read_npy_file_to_buf
+
+
+@pytest.fixture()
+def device(device_name):
+ device = driver.Device("/dev/{}".format(device_name))
+ yield device
+
+
+@pytest.fixture()
+def network_buffer(device, model_name, shared_data_folder):
+ network_file = os.path.join(shared_data_folder, model_name)
+ network_buffer = driver.Buffer(device, network_file)
+ yield network_buffer
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+def test_check_device_swig_ownership(device):
+ # Check to see that SWIG has ownership for parser. This instructs SWIG to take
+ # ownership of the return value. This allows the value to be automatically
+ # garbage-collected when it is no longer in use
+ assert device.thisown
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+def test_device_ping(device):
+ device.ping()
+
+
+@pytest.mark.parametrize('device_name', ['blabla'])
+def test_device_wrong_name(device_name):
+ with pytest.raises(RuntimeError) as err:
+ driver.Device("/dev/{}".format(device_name))
+ # Only check for part of the exception since the exception returns
+ # absolute path which will change on different machines.
+ assert 'Failed to open device' in str(err.value)
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+def test_driver_network_filenotfound_exception(device, shared_data_folder):
+
+ network_file = os.path.join(shared_data_folder, "some_unknown_model.tflite")
+
+ with pytest.raises(RuntimeError) as err:
+ network_buffer = driver.Buffer(device, network_file)
+
+ # Only check for part of the exception since the exception returns
+ # absolute path which will change on different machines.
+ assert 'Failed to open file:' in str(err.value)
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_check_buffer_swig_ownership(network_buffer):
+ # Check to see that SWIG has ownership for parser. This instructs SWIG to take
+ # ownership of the return value. This allows the value to be automatically
+ # garbage-collected when it is no longer in use
+ assert network_buffer.thisown
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_check_buffer_capacity(network_buffer):
+ assert network_buffer.capacity() > 0
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_check_buffer_size(network_buffer):
+ assert network_buffer.size() > 0
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_check_buffer_clear(network_buffer):
+ network_buffer.clear()
+ assert network_buffer.size() == 0
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_check_buffer_resize(network_buffer):
+ offset = 1
+ new_size = network_buffer.capacity() - offset
+ network_buffer.resize(new_size, offset)
+ assert network_buffer.size() == new_size
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_check_buffer_getFd(network_buffer):
+ assert network_buffer.getFd() >= 0
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_check_network_ifm_size(device, network_buffer):
+ network = driver.Network(device, network_buffer)
+ assert network.getIfmSize() > 0
+ assert network_buffer.thisown
+
+
+@pytest.mark.parametrize('device_name', [('ethosu0')])
+def test_check_network_buffer_none(device):
+
+ with pytest.raises(RuntimeError) as err:
+ driver.Network(device, None)
+
+ # Only check for part of the exception since the exception returns
+ # absolute path which will change on different machines.
+ assert 'Failed to create the network' in str(err.value)
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_check_network_ofm_size(device, network_buffer):
+ network = driver.Network(device, network_buffer)
+ assert network.getOfmSize() > 0
+
+
+def test_getMaxPmuEventCounters():
+ assert driver.Inference.getMaxPmuEventCounters() > 0
+
+
+@pytest.fixture()
+def inf(device_name, model_name, input_files, timeout, shared_data_folder):
+ # Prepate full path of model and inputs
+ full_path_model_file = os.path.join(shared_data_folder, model_name)
+ full_path_input_files = []
+ for input_file in input_files:
+ full_path_input_files.append(os.path.join(shared_data_folder, input_file))
+
+ ifms_data = []
+ for ifm_file in full_path_input_files:
+ ifms_data.append(read_npy_file_to_buf(ifm_file))
+
+ device = driver.open_device(device_name)
+ device.ping()
+ network = driver.load_model(device, full_path_model_file)
+ ofms = driver.allocate_buffers(device, network.getOfmDims())
+ ifms = driver.allocate_buffers(device, network.getIfmDims())
+
+ # ofm_buffers = runner.run(ifms_data,timeout, ethos_pmu_counters)
+ driver.populate_buffers(ifms_data, ifms)
+ ethos_pmu_counters = [1]
+ enable_cycle_counter = True
+ inf_inst = driver.Inference(network, ifms, ofms, ethos_pmu_counters, enable_cycle_counter)
+ inf_inst.wait(int(timeout))
+
+ yield inf_inst
+
+
+@pytest.mark.parametrize('device_name, model_name, timeout, input_files',
+ [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])])
+def test_inf_get_cycle_counter(inf):
+ total_cycles = inf.getCycleCounter()
+ assert total_cycles >= 0
+
+
+@pytest.mark.parametrize('device_name, model_name, timeout, input_files',
+ [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])])
+def test_inf_get_pmu_counters(inf):
+ inf_pmu_counter = inf.getPmuCounters()
+ assert len(inf_pmu_counter) > 0
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+def test_capabilities(device):
+ cap = device.capabilities()
+ assert cap.hwId
+ assert cap.hwCfg
+ assert cap.driver