diff options
Diffstat (limited to 'driver_library/python/test/test_driver.py')
-rw-r--r-- | driver_library/python/test/test_driver.py | 86 |
1 files changed, 52 insertions, 34 deletions
diff --git a/driver_library/python/test/test_driver.py b/driver_library/python/test/test_driver.py index 0dd207f..e9cb5c8 100644 --- a/driver_library/python/test/test_driver.py +++ b/driver_library/python/test/test_driver.py @@ -15,11 +15,14 @@ def device(device_name): @pytest.fixture() -def network_buffer(device, model_name, shared_data_folder): +def network_file(model_name, shared_data_folder): network_file = os.path.join(shared_data_folder, model_name) - network_buffer = driver.Buffer(device, network_file) - yield network_buffer + yield network_file +@pytest.fixture() +def network(device, network_file): + network = driver.Network(device, network_file) + yield network @pytest.mark.parametrize('device_name', ['ethosu0']) def test_check_device_swig_ownership(device): @@ -44,12 +47,33 @@ def test_device_wrong_name(device_name): @pytest.mark.parametrize('device_name', ['ethosu0']) -def test_driver_network_filenotfound_exception(device, shared_data_folder): +@pytest.mark.parametrize('model_name', ['model.tflite']) +def test_driver_network_from_bytearray(device, network_file): + network_data = None + with open(network_file, 'rb') as file: + network_data = file.read() + network = driver.Network(device, network_data) - network_file = os.path.join(shared_data_folder, "some_unknown_model.tflite") +@pytest.mark.parametrize('device_name', ['ethosu0']) +def test_driver_network_from_empty_bytearray(device): with pytest.raises(RuntimeError) as err: - network_buffer = driver.Buffer(device, network_file) + network = driver.Network(device, bytearray()) + + assert 'Failed to create the network, networkSize is zero' in str(err.value) + + +@pytest.mark.parametrize('device_name', ['ethosu0']) +@pytest.mark.parametrize('model_name', ['model.tflite']) +def test_driver_network_from_file(device, network_file): + network = driver.Network(device, network_file) + + +@pytest.mark.parametrize('device_name', ['ethosu0']) +@pytest.mark.parametrize('model_name', ['some_unknown_model.tflite']) +def test_driver_network_filenotfound_exception(device, network_file): + with pytest.raises(RuntimeError) as err: + network = driver.Network(device, network_file) # Only check for part of the exception since the exception returns # absolute path which will change on different machines. @@ -58,57 +82,51 @@ def test_driver_network_filenotfound_exception(device, shared_data_folder): @pytest.mark.parametrize('device_name', ['ethosu0']) @pytest.mark.parametrize('model_name', ['model.tflite']) -def test_check_buffer_swig_ownership(network_buffer): +def test_check_network_swig_ownership(network): # Check to see that SWIG has ownership for parser. This instructs SWIG to take # ownership of the return value. This allows the value to be automatically # garbage-collected when it is no longer in use - assert network_buffer.thisown + assert network.thisown @pytest.mark.parametrize('device_name', ['ethosu0']) @pytest.mark.parametrize('model_name', ['model.tflite']) -def test_check_buffer_size(network_buffer): - assert network_buffer.size() > 0 +def test_check_network_ifm_size(device, network): + assert network.getIfmSize() > 0 @pytest.mark.parametrize('device_name', ['ethosu0']) @pytest.mark.parametrize('model_name', ['model.tflite']) -def test_check_buffer_clear(network_buffer): - network_buffer.clear() - for i in range(network_buffer.size()): - assert network_buffer.data()[i] == 0 +def test_check_network_ofm_size(device, network): + assert network.getOfmSize() > 0 @pytest.mark.parametrize('device_name', ['ethosu0']) -@pytest.mark.parametrize('model_name', ['model.tflite']) -def test_check_buffer_getFd(network_buffer): - assert network_buffer.getFd() >= 0 +def test_check_buffer_swig_ownership(device): + buffer = driver.Buffer(device, 1024) + assert buffer.thisown @pytest.mark.parametrize('device_name', ['ethosu0']) -@pytest.mark.parametrize('model_name', ['model.tflite']) -def test_check_network_ifm_size(device, network_buffer): - network = driver.Network(device, network_buffer) - assert network.getIfmSize() > 0 - assert network_buffer.thisown - - -@pytest.mark.parametrize('device_name', [('ethosu0')]) -def test_check_network_buffer_none(device): +def test_check_buffer_getFd(device): + buffer = driver.Buffer(device, 1024) + assert buffer.getFd() >= 0 - with pytest.raises(RuntimeError) as err: - driver.Network(device, None) - # Only check for part of the exception since the exception returns - # absolute path which will change on different machines. - assert 'Failed to create the network' in str(err.value) +@pytest.mark.parametrize('device_name', ['ethosu0']) +def test_check_buffer_size(device): + buffer = driver.Buffer(device, 1024) + assert buffer.size() == 1024 @pytest.mark.parametrize('device_name', ['ethosu0']) @pytest.mark.parametrize('model_name', ['model.tflite']) -def test_check_network_ofm_size(device, network_buffer): - network = driver.Network(device, network_buffer) - assert network.getOfmSize() > 0 +def test_check_buffer_clear(device, network_file): + buffer = driver.Buffer(device, network_file) + + buffer.clear() + for i in range(buffer.size()): + assert buffer.data()[i] == 0 def test_getMaxPmuEventCounters(): |