diff options
Diffstat (limited to 'python/pyarmnn/test/test_tflite_parser.py')
-rw-r--r-- | python/pyarmnn/test/test_tflite_parser.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/python/pyarmnn/test/test_tflite_parser.py b/python/pyarmnn/test/test_tflite_parser.py index 344ec7ca13..8735eef3fb 100644 --- a/python/pyarmnn/test/test_tflite_parser.py +++ b/python/pyarmnn/test/test_tflite_parser.py @@ -7,6 +7,12 @@ import pyarmnn as ann import numpy as np +def test_TfLiteParserOptions_default_values(): + parserOptions = ann.TfLiteParserOptions() + assert parserOptions.m_InferAndValidate == False + assert parserOptions.m_StandInLayerForUnsupported == False + + @pytest.fixture() def parser(shared_data_folder): """ @@ -30,6 +36,42 @@ def test_check_tflite_parser_swig_ownership(parser): assert parser.thisown +def test_tflite_parser_with_optional_options(): + parserOptions = ann.TfLiteParserOptions() + parserOptions.m_InferAndValidate = True + parser = ann.ITfLiteParser(parserOptions) + assert parser.thisown + + +def create_with_opt() : + parserOptions = ann.TfLiteParserOptions() + parserOptions.m_InferAndValidate = True + return ann.ITfLiteParser(parserOptions) + +def test_tflite_parser_with_optional_options_out_of_scope(shared_data_folder): + parser = create_with_opt() + network = parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, "mock_model.tflite")) + + graphs_count = parser.GetSubgraphCount() + graph_id = graphs_count - 1 + + input_names = parser.GetSubgraphInputTensorNames(graph_id) + input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0]) + + output_names = parser.GetSubgraphOutputTensorNames(graph_id) + + preferred_backends = [ann.BackendId('CpuAcc'), ann.BackendId('CpuRef')] + + options = ann.CreationOptions() + runtime = ann.IRuntime(options) + + opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), ann.OptimizerOptions()) + assert 0 == len(messages) + + net_id, messages = runtime.LoadNetwork(opt_network) + assert "" == messages + + def test_tflite_get_sub_graph_count(parser): graphs_count = parser.GetSubgraphCount() assert graphs_count == 1 |