aboutsummaryrefslogtreecommitdiff
path: root/driver_library/python/test/test_driver.py
diff options
context:
space:
mode:
authorMikael Olsson <mikael.olsson@arm.com>2023-10-30 11:10:56 +0100
committerMikael Olsson <mikael.olsson@arm.com>2023-11-06 09:36:00 +0100
commitc081e5954cd92165b139488e76bdfef1402acee6 (patch)
tree32bc237c124e21f12287150cba040c87c8e8b7e3 /driver_library/python/test/test_driver.py
parent9c999fdd40c0bf2ae420f6f3bfe013dc6baa73c1 (diff)
downloadethos-u-linux-driver-stack-c081e5954cd92165b139488e76bdfef1402acee6.tar.gz
Change create network UAPI to take a user buffer
To not allow the buffer for a network instance to be changed after creation, the create network UAPI will now take the network model data as a user buffer. The content of the user buffer is copied into an internally allocated DMA buffer that cannot be accessed by the user. This breaks the current API so the Linux kernel NPU driver version and the driver library version have been given major version bumps. All the tests, documentation and other applications affected by the changes have been updated accordingly. Change-Id: I25c785d75a24794c3db632e4abe5cfbb1c7ac190 Signed-off-by: Mikael Olsson <mikael.olsson@arm.com>
Diffstat (limited to 'driver_library/python/test/test_driver.py')
-rw-r--r--driver_library/python/test/test_driver.py86
1 files changed, 52 insertions, 34 deletions
diff --git a/driver_library/python/test/test_driver.py b/driver_library/python/test/test_driver.py
index 0dd207f..e9cb5c8 100644
--- a/driver_library/python/test/test_driver.py
+++ b/driver_library/python/test/test_driver.py
@@ -15,11 +15,14 @@ def device(device_name):
@pytest.fixture()
-def network_buffer(device, model_name, shared_data_folder):
+def network_file(model_name, shared_data_folder):
network_file = os.path.join(shared_data_folder, model_name)
- network_buffer = driver.Buffer(device, network_file)
- yield network_buffer
+ yield network_file
+@pytest.fixture()
+def network(device, network_file):
+ network = driver.Network(device, network_file)
+ yield network
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_device_swig_ownership(device):
@@ -44,12 +47,33 @@ def test_device_wrong_name(device_name):
@pytest.mark.parametrize('device_name', ['ethosu0'])
-def test_driver_network_filenotfound_exception(device, shared_data_folder):
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_driver_network_from_bytearray(device, network_file):
+ network_data = None
+ with open(network_file, 'rb') as file:
+ network_data = file.read()
+ network = driver.Network(device, network_data)
- network_file = os.path.join(shared_data_folder, "some_unknown_model.tflite")
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+def test_driver_network_from_empty_bytearray(device):
with pytest.raises(RuntimeError) as err:
- network_buffer = driver.Buffer(device, network_file)
+ network = driver.Network(device, bytearray())
+
+ assert 'Failed to create the network, networkSize is zero' in str(err.value)
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_driver_network_from_file(device, network_file):
+ network = driver.Network(device, network_file)
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['some_unknown_model.tflite'])
+def test_driver_network_filenotfound_exception(device, network_file):
+ with pytest.raises(RuntimeError) as err:
+ network = driver.Network(device, network_file)
# Only check for part of the exception since the exception returns
# absolute path which will change on different machines.
@@ -58,57 +82,51 @@ def test_driver_network_filenotfound_exception(device, shared_data_folder):
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_buffer_swig_ownership(network_buffer):
+def test_check_network_swig_ownership(network):
# Check to see that SWIG has ownership for parser. This instructs SWIG to take
# ownership of the return value. This allows the value to be automatically
# garbage-collected when it is no longer in use
- assert network_buffer.thisown
+ assert network.thisown
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_buffer_size(network_buffer):
- assert network_buffer.size() > 0
+def test_check_network_ifm_size(device, network):
+ assert network.getIfmSize() > 0
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_buffer_clear(network_buffer):
- network_buffer.clear()
- for i in range(network_buffer.size()):
- assert network_buffer.data()[i] == 0
+def test_check_network_ofm_size(device, network):
+ assert network.getOfmSize() > 0
@pytest.mark.parametrize('device_name', ['ethosu0'])
-@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_buffer_getFd(network_buffer):
- assert network_buffer.getFd() >= 0
+def test_check_buffer_swig_ownership(device):
+ buffer = driver.Buffer(device, 1024)
+ assert buffer.thisown
@pytest.mark.parametrize('device_name', ['ethosu0'])
-@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_network_ifm_size(device, network_buffer):
- network = driver.Network(device, network_buffer)
- assert network.getIfmSize() > 0
- assert network_buffer.thisown
-
-
-@pytest.mark.parametrize('device_name', [('ethosu0')])
-def test_check_network_buffer_none(device):
+def test_check_buffer_getFd(device):
+ buffer = driver.Buffer(device, 1024)
+ assert buffer.getFd() >= 0
- with pytest.raises(RuntimeError) as err:
- driver.Network(device, None)
- # Only check for part of the exception since the exception returns
- # absolute path which will change on different machines.
- assert 'Failed to create the network' in str(err.value)
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+def test_check_buffer_size(device):
+ buffer = driver.Buffer(device, 1024)
+ assert buffer.size() == 1024
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_network_ofm_size(device, network_buffer):
- network = driver.Network(device, network_buffer)
- assert network.getOfmSize() > 0
+def test_check_buffer_clear(device, network_file):
+ buffer = driver.Buffer(device, network_file)
+
+ buffer.clear()
+ for i in range(buffer.size()):
+ assert buffer.data()[i] == 0
def test_getMaxPmuEventCounters():