aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_nng_mapping.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_nng_mapping.py')
-rw-r--r--ethosu/vela/test/test_nng_mapping.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_nng_mapping.py b/ethosu/vela/test/test_nng_mapping.py
new file mode 100644
index 00000000..08d77fea
--- /dev/null
+++ b/ethosu/vela/test/test_nng_mapping.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Unit tests for the mapping of TFLite or TOSA to NNG
+import pytest
+
+from ethosu.vela.tflite_mapping import builtin_operator_map
+from ethosu.vela.tosa_mapping import tosa_operator_map
+
+
+class TestNNGMapping:
+ """Ensure the mappings from TFLite to NNG are consistent."""
+
+ @pytest.mark.parametrize(
+ "operator_map",
+ ((builtin_operator_map), (tosa_operator_map)),
+ ids=("test_tflite_indices_match_nng", "test_tosa_indices_match_nng"),
+ )
+ def test_op_indices_match(self, operator_map):
+ """Ensure TFLite/TOSA indices and NNG indices are consistent for each operator."""
+ for map_op in operator_map.values():
+ op_type = map_op[0]
+ map_op_indices = map_op[-1] # TFLite/TOSA indices in last element of tuple
+
+ nng_indices = op_type.info.indices
+
+ for idx in range(3):
+ assert len(map_op_indices[idx]) == len(nng_indices[idx])