aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test/test_tflite_parser.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/test/test_tflite_parser.py')
-rw-r--r--python/pyarmnn/test/test_tflite_parser.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/python/pyarmnn/test/test_tflite_parser.py b/python/pyarmnn/test/test_tflite_parser.py
index 344ec7ca13..8735eef3fb 100644
--- a/python/pyarmnn/test/test_tflite_parser.py
+++ b/python/pyarmnn/test/test_tflite_parser.py
@@ -7,6 +7,12 @@ import pyarmnn as ann
import numpy as np
+def test_TfLiteParserOptions_default_values():
+ parserOptions = ann.TfLiteParserOptions()
+ assert parserOptions.m_InferAndValidate == False
+ assert parserOptions.m_StandInLayerForUnsupported == False
+
+
@pytest.fixture()
def parser(shared_data_folder):
"""
@@ -30,6 +36,42 @@ def test_check_tflite_parser_swig_ownership(parser):
assert parser.thisown
+def test_tflite_parser_with_optional_options():
+ parserOptions = ann.TfLiteParserOptions()
+ parserOptions.m_InferAndValidate = True
+ parser = ann.ITfLiteParser(parserOptions)
+ assert parser.thisown
+
+
+def create_with_opt() :
+ parserOptions = ann.TfLiteParserOptions()
+ parserOptions.m_InferAndValidate = True
+ return ann.ITfLiteParser(parserOptions)
+
+def test_tflite_parser_with_optional_options_out_of_scope(shared_data_folder):
+ parser = create_with_opt()
+ network = parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, "mock_model.tflite"))
+
+ graphs_count = parser.GetSubgraphCount()
+ graph_id = graphs_count - 1
+
+ input_names = parser.GetSubgraphInputTensorNames(graph_id)
+ input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+
+ output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+
+ preferred_backends = [ann.BackendId('CpuAcc'), ann.BackendId('CpuRef')]
+
+ options = ann.CreationOptions()
+ runtime = ann.IRuntime(options)
+
+ opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())
+ assert 0 == len(messages)
+
+ net_id, messages = runtime.LoadNetwork(opt_network)
+ assert "" == messages
+
+
def test_tflite_get_sub_graph_count(parser):
graphs_count = parser.GetSubgraphCount()
assert graphs_count == 1