aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_tflite_graph.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-10-25 18:12:34 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-11-10 16:47:22 +0000
commite40a7adadd254e29d71af38f69a0a20ff4871eef (patch)
tree9a57ddf406846785683673565359d9bd6ba3cf0b /tests/test_nn_tensorflow_tflite_graph.py
parent720839a2dc6d4d75cd7aa77f83fcd49bcf114ba6 (diff)
downloadmlia-e40a7adadd254e29d71af38f69a0a20ff4871eef.tar.gz
MLIA-411 Report Cortex-A operator compatibility
Check input model for Arm NN TensorFlow Lite Delegate 22.08 support. Change-Id: I1253c4c0b294c5283e08f0a39561b922ef0f62e6
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_graph.py')
-rw-r--r--tests/test_nn_tensorflow_tflite_graph.py81
1 files changed, 81 insertions, 0 deletions
diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py
new file mode 100644
index 0000000..cd1fad6
--- /dev/null
+++ b/tests/test_nn_tensorflow_tflite_graph.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the tflite_graph module."""
+import json
+from pathlib import Path
+
+from mlia.nn.tensorflow.tflite_graph import Op
+from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import TensorInfo
+from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
+from mlia.nn.tensorflow.tflite_graph import TFL_OP
+from mlia.nn.tensorflow.tflite_graph import TFL_TYPE
+
+
+def test_tensor_info() -> None:
+ """Test class 'TensorInfo'."""
+ expected = {
+ "name": "Test",
+ "type": TFL_TYPE.INT8.name,
+ "shape": (1,),
+ "is_variable": False,
+ }
+ info = TensorInfo(**expected)
+ assert vars(info) == expected
+
+ expected = {
+ "name": "Test2",
+ "type": TFL_TYPE.FLOAT32.name,
+ "shape": [2, 3],
+ "is_variable": True,
+ }
+ tensor_dict = {
+ "name": [ord(c) for c in expected["name"]],
+ "type": TFL_TYPE[expected["type"]],
+ "shape": expected["shape"],
+ "is_variable": expected["is_variable"],
+ }
+ info = TensorInfo.from_dict(tensor_dict)
+ assert vars(info) == expected
+
+ json_repr = json.loads(repr(info))
+ assert vars(info) == json_repr
+
+ assert str(info)
+
+
+def test_op() -> None:
+ """Test class 'Op'."""
+ expected = {
+ "type": TFL_OP.CONV_2D.name,
+ "builtin_options": {},
+ "inputs": [],
+ "outputs": [],
+ "custom_type": None,
+ }
+ oper = Op(**expected)
+ assert vars(oper) == expected
+
+ expected["builtin_options"] = {"some_random_option": 3.14}
+ oper = Op(**expected)
+ assert vars(oper) == expected
+
+ activation_func = TFL_ACTIVATION_FUNCTION.RELU
+ expected["builtin_options"] = {"fused_activation_function": activation_func.value}
+ oper = Op(**expected)
+ assert oper.builtin_options
+ assert oper.builtin_options["fused_activation_function"] == activation_func.name
+
+ assert str(oper)
+ assert repr(oper)
+
+
+def test_parse_subgraphs(test_tflite_model: Path) -> None:
+ """Test function 'parse_subgraphs'."""
+ model = parse_subgraphs(test_tflite_model)
+ assert len(model) == 1
+ assert len(model[0]) == 5
+ for oper in model[0]:
+ assert TFL_OP[oper.type] in TFL_OP
+ assert len(oper.inputs) > 0
+ assert len(oper.outputs) > 0