aboutsummaryrefslogtreecommitdiff
path: root/driver_library/python/test/test_driver.py
diff options
context:
space:
mode:
Diffstat (limited to 'driver_library/python/test/test_driver.py')
-rw-r--r--driver_library/python/test/test_driver.py86
1 files changed, 52 insertions, 34 deletions
diff --git a/driver_library/python/test/test_driver.py b/driver_library/python/test/test_driver.py
index 0dd207f..e9cb5c8 100644
--- a/driver_library/python/test/test_driver.py
+++ b/driver_library/python/test/test_driver.py
@@ -15,11 +15,14 @@ def device(device_name):
@pytest.fixture()
-def network_buffer(device, model_name, shared_data_folder):
+def network_file(model_name, shared_data_folder):
network_file = os.path.join(shared_data_folder, model_name)
- network_buffer = driver.Buffer(device, network_file)
- yield network_buffer
+ yield network_file
+@pytest.fixture()
+def network(device, network_file):
+ network = driver.Network(device, network_file)
+ yield network
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_device_swig_ownership(device):
@@ -44,12 +47,33 @@ def test_device_wrong_name(device_name):
@pytest.mark.parametrize('device_name', ['ethosu0'])
-def test_driver_network_filenotfound_exception(device, shared_data_folder):
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_driver_network_from_bytearray(device, network_file):
+ network_data = None
+ with open(network_file, 'rb') as file:
+ network_data = file.read()
+ network = driver.Network(device, network_data)
- network_file = os.path.join(shared_data_folder, "some_unknown_model.tflite")
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+def test_driver_network_from_empty_bytearray(device):
with pytest.raises(RuntimeError) as err:
- network_buffer = driver.Buffer(device, network_file)
+ network = driver.Network(device, bytearray())
+
+ assert 'Failed to create the network, networkSize is zero' in str(err.value)
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['model.tflite'])
+def test_driver_network_from_file(device, network_file):
+ network = driver.Network(device, network_file)
+
+
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+@pytest.mark.parametrize('model_name', ['some_unknown_model.tflite'])
+def test_driver_network_filenotfound_exception(device, network_file):
+ with pytest.raises(RuntimeError) as err:
+ network = driver.Network(device, network_file)
# Only check for part of the exception since the exception returns
# absolute path which will change on different machines.
@@ -58,57 +82,51 @@ def test_driver_network_filenotfound_exception(device, shared_data_folder):
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_buffer_swig_ownership(network_buffer):
+def test_check_network_swig_ownership(network):
# Check to see that SWIG has ownership for parser. This instructs SWIG to take
# ownership of the return value. This allows the value to be automatically
# garbage-collected when it is no longer in use
- assert network_buffer.thisown
+ assert network.thisown
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_buffer_size(network_buffer):
- assert network_buffer.size() > 0
+def test_check_network_ifm_size(device, network):
+ assert network.getIfmSize() > 0
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_buffer_clear(network_buffer):
- network_buffer.clear()
- for i in range(network_buffer.size()):
- assert network_buffer.data()[i] == 0
+def test_check_network_ofm_size(device, network):
+ assert network.getOfmSize() > 0
@pytest.mark.parametrize('device_name', ['ethosu0'])
-@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_buffer_getFd(network_buffer):
- assert network_buffer.getFd() >= 0
+def test_check_buffer_swig_ownership(device):
+ buffer = driver.Buffer(device, 1024)
+ assert buffer.thisown
@pytest.mark.parametrize('device_name', ['ethosu0'])
-@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_network_ifm_size(device, network_buffer):
- network = driver.Network(device, network_buffer)
- assert network.getIfmSize() > 0
- assert network_buffer.thisown
-
-
-@pytest.mark.parametrize('device_name', [('ethosu0')])
-def test_check_network_buffer_none(device):
+def test_check_buffer_getFd(device):
+ buffer = driver.Buffer(device, 1024)
+ assert buffer.getFd() >= 0
- with pytest.raises(RuntimeError) as err:
- driver.Network(device, None)
- # Only check for part of the exception since the exception returns
- # absolute path which will change on different machines.
- assert 'Failed to create the network' in str(err.value)
+@pytest.mark.parametrize('device_name', ['ethosu0'])
+def test_check_buffer_size(device):
+ buffer = driver.Buffer(device, 1024)
+ assert buffer.size() == 1024
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
-def test_check_network_ofm_size(device, network_buffer):
- network = driver.Network(device, network_buffer)
- assert network.getOfmSize() > 0
+def test_check_buffer_clear(device, network_file):
+ buffer = driver.Buffer(device, network_file)
+
+ buffer.clear()
+ for i in range(buffer.size()):
+ assert buffer.data()[i] == 0
def test_getMaxPmuEventCounters():