From f418e832ffd5a10f549aa07a0c9c59406a374ffe Mon Sep 17 00:00:00 2001 From: Johan Alfven Date: Mon, 13 Nov 2023 10:23:32 +0100 Subject: MLBEDSW-8317: Add semantic checks for Transpose - Added semantic checks for Transpose - Added unit tests for semantic checks - Updated SUPPORTED_OPS.md Change-Id: I3fcf13120f4b6811f8de27711996cdb9c19c9847 Signed-off-by: Johan Alfven --- SUPPORTED_OPS.md | 4 +++- ethosu/vela/test/test_tflite_model_semantic.py | 26 ++++++++++++++++++++++++++ ethosu/vela/tflite_model_semantic.py | 22 ++++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md index ceb0205..fd5e478 100644 --- a/SUPPORTED_OPS.md +++ b/SUPPORTED_OPS.md @@ -19,7 +19,7 @@ limitations under the License. # Supported Ops This file was automatically generated by Vela using the `--supported-ops-report` parameter. -Vela version: `3.9.1.dev21+gb724cdb.d20231107` +Vela version: `3.10.0rc2.dev0+ga8fda88.d20231113` This file complies with [**Gitiles Markdown syntax**](https://gerrit.googlesource.com/gitiles/+/HEAD/Documentation/markdown.md) @@ -410,6 +410,8 @@ This is a list of constraints that the SUB operator must satisfy in order to be This is a list of constraints that the TRANSPOSE operator must satisfy in order to be scheduled on the NPU. +- Permutation array must be a 1D tensor with RANK(IFM) elements +- Permutation array must have constant values in the range [0, RANK(IFM)) - The following shape/permutations are supported for transpose: When ifm rank is 2: WxC -> CxW When ifm rank is 3: HxWxC -> WxHxC, 1xWxC -> 1xCxW, Hx1xC -> Cx1xH diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 7ca1bbd..bea004a 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -647,3 +647,29 @@ def test_lstm_semantics(): op.inputs.pop() # Test restored valid configuration assert semantic_checker.is_operator_semantic_valid(op) + + +def test_transpose_semantics(): + # Test valid op + ifm = Tensor([2, 4], DataType.int8, "ifm") + perm = create_const_tensor("perm", [2], DataType.int32, [1, 0]) + ofm = Tensor([4, 2], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation size + perm = create_const_tensor("perm", [3], DataType.int32, [1, 0]) + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation values + perm = create_const_tensor("perm", [2], DataType.int32, [2, 0]) + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation values + perm = create_const_tensor("perm", [2], DataType.int32, [0, -1]) + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation values + perm = create_const_tensor("perm", [2], DataType.int32, [1, 0]) + perm.values = None + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index d9ace1e..eff40bc 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -209,6 +209,10 @@ class TFLiteSemantic: # Exp specific checks self.specific_constraints[Op.Exp].append(TFLiteSemantic.constraint_input_signed) + # Transpose specific checks + self.specific_constraints[Op.Transpose].append(TFLiteSemantic.constraint_transpose_permutation_size) + self.specific_constraints[Op.Transpose].append(TFLiteSemantic.constraint_transpose_permutation_values) + def is_operator_semantic_valid(self, op): ext_type = optype_to_builtintype(op.type) @@ -833,6 +837,24 @@ class TFLiteSemantic: extra = ", ".join(extra) return valid, f"Op has non-variable state tensor(s): {extra}" + @staticmethod + def constraint_transpose_permutation_size(op): + "Permutation array must be a 1D tensor with RANK(IFM) elements" + dims = len(op.inputs[0].shape) + perm = op.inputs[1] + valid = len(perm.shape) == 1 and perm.shape[0] == dims + return valid, f"Op has ifm_dimension={dims} and permutation shape {perm.shape}" + + @staticmethod + def constraint_transpose_permutation_values(op): + "Permutation array must have constant values in the range [0, RANK(IFM))" + dims = len(op.inputs[0].shape) + perm = op.inputs[1] + valid = False + if perm.values is not None: + valid = not any([val < 0 or val >= dims for val in perm.values]) + return valid, f"Op has ifm_dimension={dims} and permutation values are: {perm.values}" + def tflite_semantic_checker(nng): semantic_checker = TFLiteSemantic() -- cgit v1.2.1