diff options
author | Johan Alfven <johan.alfven@arm.com> | 2023-11-13 10:23:32 +0100 |
---|---|---|
committer | Johan Alfven <johan.alfven@arm.com> | 2023-11-13 15:08:12 +0100 |
commit | f418e832ffd5a10f549aa07a0c9c59406a374ffe (patch) | |
tree | b91d152f48c346bee172a0f554deb0570956c44f | |
parent | a8fda88bced0d11441467b6798885101d41ac8e9 (diff) | |
download | ethos-u-vela-f418e832ffd5a10f549aa07a0c9c59406a374ffe.tar.gz |
MLBEDSW-8317: Add semantic checks for Transpose
- Added semantic checks for Transpose
- Added unit tests for semantic checks
- Updated SUPPORTED_OPS.md
Change-Id: I3fcf13120f4b6811f8de27711996cdb9c19c9847
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
-rw-r--r-- | SUPPORTED_OPS.md | 4 | ||||
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 26 | ||||
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 22 |
3 files changed, 51 insertions, 1 deletions
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md index ceb02051..fd5e478b 100644 --- a/SUPPORTED_OPS.md +++ b/SUPPORTED_OPS.md @@ -19,7 +19,7 @@ limitations under the License. # Supported Ops This file was automatically generated by Vela using the `--supported-ops-report` parameter. -Vela version: `3.9.1.dev21+gb724cdb.d20231107` +Vela version: `3.10.0rc2.dev0+ga8fda88.d20231113` This file complies with [**Gitiles Markdown syntax**](https://gerrit.googlesource.com/gitiles/+/HEAD/Documentation/markdown.md) @@ -410,6 +410,8 @@ This is a list of constraints that the SUB operator must satisfy in order to be This is a list of constraints that the TRANSPOSE operator must satisfy in order to be scheduled on the NPU. +- Permutation array must be a 1D tensor with RANK(IFM) elements +- Permutation array must have constant values in the range [0, RANK(IFM)) - The following shape/permutations are supported for transpose: When ifm rank is 2: WxC -> CxW When ifm rank is 3: HxWxC -> WxHxC, 1xWxC -> 1xCxW, Hx1xC -> Cx1xH diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 7ca1bbda..bea004ae 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -647,3 +647,29 @@ def test_lstm_semantics(): op.inputs.pop() # Test restored valid configuration assert semantic_checker.is_operator_semantic_valid(op) + + +def test_transpose_semantics(): + # Test valid op + ifm = Tensor([2, 4], DataType.int8, "ifm") + perm = create_const_tensor("perm", [2], DataType.int32, [1, 0]) + ofm = Tensor([4, 2], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation size + perm = create_const_tensor("perm", [3], DataType.int32, [1, 0]) + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation values + perm = create_const_tensor("perm", [2], DataType.int32, [2, 0]) + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation values + perm = create_const_tensor("perm", [2], DataType.int32, [0, -1]) + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation values + perm = create_const_tensor("perm", [2], DataType.int32, [1, 0]) + perm.values = None + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index d9ace1e8..eff40bc5 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -209,6 +209,10 @@ class TFLiteSemantic: # Exp specific checks self.specific_constraints[Op.Exp].append(TFLiteSemantic.constraint_input_signed) + # Transpose specific checks + self.specific_constraints[Op.Transpose].append(TFLiteSemantic.constraint_transpose_permutation_size) + self.specific_constraints[Op.Transpose].append(TFLiteSemantic.constraint_transpose_permutation_values) + def is_operator_semantic_valid(self, op): ext_type = optype_to_builtintype(op.type) @@ -833,6 +837,24 @@ class TFLiteSemantic: extra = ", ".join(extra) return valid, f"Op has non-variable state tensor(s): {extra}" + @staticmethod + def constraint_transpose_permutation_size(op): + "Permutation array must be a 1D tensor with RANK(IFM) elements" + dims = len(op.inputs[0].shape) + perm = op.inputs[1] + valid = len(perm.shape) == 1 and perm.shape[0] == dims + return valid, f"Op has ifm_dimension={dims} and permutation shape {perm.shape}" + + @staticmethod + def constraint_transpose_permutation_values(op): + "Permutation array must have constant values in the range [0, RANK(IFM))" + dims = len(op.inputs[0].shape) + perm = op.inputs[1] + valid = False + if perm.values is not None: + valid = not any([val < 0 or val >= dims for val in perm.values]) + return valid, f"Op has ifm_dimension={dims} and permutation values are: {perm.values}" + def tflite_semantic_checker(nng): semantic_checker = TFLiteSemantic() |